From 02c91e03f975c2a6a05a9d5327057bb6b3c4a66f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 1 Oct 2017 18:42:45 +0900 Subject: [PATCH 001/712] [SPARK-22063][R] Fixes lint check failures in R by latest commit sha1 ID of lint-r ## What changes were proposed in this pull request? Currently, we set lintr to jimhester/lintra769c0b (see [this](https://github.com/apache/spark/commit/7d1175011c976756efcd4e4e4f70a8fd6f287026) and [SPARK-14074](https://issues.apache.org/jira/browse/SPARK-14074)). I first tested and checked lintr-1.0.1 but it looks many important fixes are missing (for example, checking 100 length). So, I instead tried the latest commit, https://github.com/jimhester/lintr/commit/5431140ffea65071f1327625d4a8de9688fa7e72, in my local and fixed the check failures. It looks it has fixed many bugs and now finds many instances that I have observed and thought should be caught time to time, here I filed [the results](https://gist.github.com/HyukjinKwon/4f59ddcc7b6487a02da81800baca533c). The downside looks it now takes about 7ish mins, (it was 2ish mins before) in my local. ## How was this patch tested? Manually, `./dev/lint-r` after manually updating the lintr package. Author: hyukjinkwon Author: zuotingbing Closes #19290 from HyukjinKwon/upgrade-r-lint. --- R/pkg/.lintr | 2 +- R/pkg/R/DataFrame.R | 30 ++-- R/pkg/R/RDD.R | 6 +- R/pkg/R/WindowSpec.R | 2 +- R/pkg/R/column.R | 2 + R/pkg/R/context.R | 2 +- R/pkg/R/deserialize.R | 2 +- R/pkg/R/functions.R | 79 ++++++----- R/pkg/R/generics.R | 4 +- R/pkg/R/group.R | 4 +- R/pkg/R/mllib_classification.R | 137 +++++++++++-------- R/pkg/R/mllib_clustering.R | 15 +- R/pkg/R/mllib_regression.R | 62 +++++---- R/pkg/R/mllib_tree.R | 36 +++-- R/pkg/R/pairRDD.R | 4 +- R/pkg/R/schema.R | 2 +- R/pkg/R/stats.R | 14 +- R/pkg/R/utils.R | 4 +- R/pkg/inst/worker/worker.R | 2 +- R/pkg/tests/fulltests/test_binary_function.R | 2 +- R/pkg/tests/fulltests/test_rdd.R | 6 +- R/pkg/tests/fulltests/test_sparkSQL.R | 14 +- dev/lint-r.R | 4 +- 23 files changed, 242 insertions(+), 193 deletions(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index ae50b28ec6166..c83ad2adfe0ef 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0728141fa483e..176bb3b8a8d0c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1923,13 +1923,15 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param i,subset (Optional) a logical expression to filter on rows. #' For extract operator [[ and replacement operator [[<-, the indexing parameter for #' a single Column. -#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame. +#' @param j,select expression for the single Column or a list of columns to select from the +#' SparkDataFrame. #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column. #' Otherwise, a SparkDataFrame will always be returned. #' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}. 
#' If \code{NULL}, the specified Column is dropped. #' @param ... currently not used. -#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns. +#' @return A new SparkDataFrame containing only the rows that meet the condition with selected +#' columns. #' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method @@ -2608,12 +2610,12 @@ setMethod("merge", } else { # if by or both by.x and by.y have length 0, use Cartesian Product joinRes <- crossJoin(x, y) - return (joinRes) + return(joinRes) } # sets alias for making colnames unique in dataframes 'x' and 'y' - colsX <- generateAliasesForIntersectedCols(x, by, suffixes[1]) - colsY <- generateAliasesForIntersectedCols(y, by, suffixes[2]) + colsX <- genAliasesForIntersectedCols(x, by, suffixes[1]) + colsY <- genAliasesForIntersectedCols(y, by, suffixes[2]) # selects columns with their aliases from dataframes # in case same column names are present in both data frames @@ -2661,9 +2663,8 @@ setMethod("merge", #' @param intersectedColNames a list of intersected column names of the SparkDataFrame #' @param suffix a suffix for the column name #' @return list of columns -#' -#' @note generateAliasesForIntersectedCols since 1.6.0 -generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { +#' @noRd +genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { allColNames <- names(x) # sets alias for making colnames unique in dataframe 'x' cols <- lapply(allColNames, function(colName) { @@ -2671,7 +2672,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { if (colName %in% intersectedColNames) { newJoin <- paste(colName, suffix, sep = "") if (newJoin %in% allColNames){ - stop ("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", + stop("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.") } col <- alias(col, newJoin) @@ -3058,7 +3059,8 @@ setMethod("describe", #' summary(select(df, "age", "height")) #' } #' @note summary(SparkDataFrame) since 1.5.0 -#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for previous defaults. +#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for +#' previous defaults. #' @seealso \link{describe} setMethod("summary", signature(object = "SparkDataFrame"), @@ -3765,8 +3767,8 @@ setMethod("checkpoint", #' #' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. #' -#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to -#' direct application of \link{agg}. +#' If grouping expression is missing \code{cube} creates a single global aggregate and is +#' equivalent to direct application of \link{agg}. #' #' @param x a SparkDataFrame. #' @param ... character name(s) or Column(s) to group on. @@ -3800,8 +3802,8 @@ setMethod("cube", #' #' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns. #' -#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to -#' direct application of \link{agg}. +#' If grouping expression is missing \code{rollup} creates a single global aggregate and is +#' equivalent to direct application of \link{agg}. #' #' @param x a SparkDataFrame. #' @param ... character name(s) or Column(s) to group on. 
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 15ca212acf87f..6e89b4bb4d964 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -131,7 +131,7 @@ PipelinedRDD <- function(prev, func) { # Return the serialization mode for an RDD. setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) # For normal RDDs we can directly read the serializedMode -setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode) # For pipelined RDDs if jrdd_val is set then serializedMode should exist # if not we return the defaultSerialization mode of "byte" as we don't know the serialization # mode at this point in time. @@ -145,7 +145,7 @@ setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), }) # The jrdd accessor function. -setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd) setMethod("getJRDD", signature(rdd = "PipelinedRDD"), function(rdd, serializedMode = "byte") { if (!is.null(rdd@env$jrdd_val)) { @@ -893,7 +893,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- stats::rpois(1, fraction) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index 81beac9ea9925..debc7cbde55e7 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -73,7 +73,7 @@ setMethod("show", "WindowSpec", setMethod("partitionBy", signature(x = "WindowSpec"), function(x, col, ...) { - stopifnot (class(col) %in% c("character", "Column")) + stopifnot(class(col) %in% c("character", "Column")) if (class(col) == "character") { windowSpec(callJMethod(x@sws, "partitionBy", col, list(...))) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index a5c2ea81f2490..3095adb918b67 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -238,8 +238,10 @@ setMethod("between", signature(x = "Column"), #' @param x a Column. #' @param dataType a character object describing the target data type. #' See +# nolint start #' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{ #' Spark Data Types} for available data types. +# nolint end #' @rdname cast #' @name cast #' @family colum_func diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8349b57a30a93..443c2ff8f9ace 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -329,7 +329,7 @@ spark.addFile <- function(path, recursive = FALSE) { #' spark.getSparkFilesRootDirectory() #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 -spark.getSparkFilesRootDirectory <- function() { +spark.getSparkFilesRootDirectory <- function() { # nolint if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { # Running on driver. 
callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 0e99b171cabeb..a90f7d381026b 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -43,7 +43,7 @@ readObject <- function(con) { } readTypedObject <- function(con, type) { - switch (type, + switch(type, "i" = readInt(con), "c" = readString(con), "b" = readBoolean(con), diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f286263c2162..0143a3e63ba61 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -38,7 +38,8 @@ NULL #' #' Date time functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. +#' @param x Column to compute on. In \code{window}, it must be a time Column of +#' \code{TimestampType}. #' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse #' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string #' to use to specify the truncation method. For example, "year", "yyyy", "yy" for @@ -90,8 +91,8 @@ NULL #' #' Math functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, -#' this is the number of bits to shift. +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and +#' \code{shiftRightUnsigned}, this is the number of bits to shift. #' @param y Column to compute on. #' @param ... additional argument(s). #' @name column_math_functions @@ -480,7 +481,7 @@ setMethod("ceiling", setMethod("coalesce", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -676,7 +677,7 @@ setMethod("crc32", setMethod("hash", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -1310,9 +1311,9 @@ setMethod("round", #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, -#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left -#' of the decimal point when \code{scale} < 0. +#' @param scale round to \code{scale} digits to the right of the decimal point when +#' \code{scale} > 0, the nearest even number when \code{scale} = 0, and \code{scale} digits +#' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method #' @export @@ -2005,8 +2006,9 @@ setMethod("months_between", signature(y = "Column"), }) #' @details -#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if -#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column +#' (\code{x}) if the first column is NaN. Both inputs should be floating point columns +#' (DoubleType or FloatType). #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method @@ -2061,7 +2063,7 @@ setMethod("approxCountDistinct", setMethod("countDistinct", signature(x = "Column"), function(x, ...) 
{ - jcols <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2090,7 +2092,7 @@ setMethod("countDistinct", setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2110,7 +2112,7 @@ setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2130,7 +2132,7 @@ setMethod("least", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2406,8 +2408,8 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), }) #' @details -#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, -#' it will return a long value else it will return an integer value. +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long +#' value, it will return a long value else it will return an integer value. #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method @@ -2505,9 +2507,10 @@ setMethod("format_string", signature(format = "character", x = "Column"), }) #' @details -#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a -#' string representing the timestamp of that moment in the current system time zone in the JVM in the -#' given format. See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) +#' to a string representing the timestamp of that moment in the current system time zone in the JVM +#' in the given format. +#' See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ #' Customizing Formats} for available options. #' #' @rdname column_datetime_functions @@ -2634,8 +2637,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), }) #' @details -#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples -#' from U[0.0, 1.0]. +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) +#' samples from U[0.0, 1.0]. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2664,8 +2667,8 @@ setMethod("rand", signature(seed = "numeric"), }) #' @details -#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from -#' the standard normal distribution. +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples +#' from the standard normal distribution. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -2831,8 +2834,8 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), }) #' @details -#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. -#' For unmatched expressions null is returned. +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result +#' expressions. 
For unmatched expressions null is returned. #' #' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. @@ -2859,8 +2862,8 @@ setMethod("when", signature(condition = "Column", value = "ANY"), }) #' @details -#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. -#' Otherwise \code{no} is returned for unmatched conditions. +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are +#' satisfied. Otherwise \code{no} is returned for unmatched conditions. #' #' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. @@ -2990,7 +2993,8 @@ setMethod("ntile", }) #' @details -#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window +#' partition. #' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). #' This is equivalent to the \code{PERCENT_RANK} function in SQL. #' The method should be used with no argument. @@ -3160,7 +3164,8 @@ setMethod("posexplode", }) #' @details -#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. +#' \code{create_array}: Creates a new array column. The input columns must all have the same data +#' type. #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method @@ -3169,7 +3174,7 @@ setMethod("posexplode", setMethod("create_array", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3178,8 +3183,8 @@ setMethod("create_array", }) #' @details -#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, -#' e.g. (key1, value1, key2, value2, ...). +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value +#' pairs, e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' @@ -3190,7 +3195,7 @@ setMethod("create_array", setMethod("create_map", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3352,9 +3357,9 @@ setMethod("not", }) #' @details -#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or not, -#' returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} in SQL -#' and \code{grouping} function in Scala. +#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or +#' not, returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} +#' in SQL and \code{grouping} function in Scala. #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method @@ -3412,7 +3417,7 @@ setMethod("grouping_bit", setMethod("grouping_id", signature(x = "Column"), function(x, ...) 
{ - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0fe8f0453b064..4e427489f6860 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -385,7 +385,7 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @return A SparkDataFrame. #' @rdname summarize #' @export -setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias #' @@ -731,7 +731,7 @@ setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select #' @export -setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) +setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr #' @export diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 0a7be0e993975..54ef9f07d6fae 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -133,8 +133,8 @@ setMethod("summarize", # Aggregate Functions by name methods <- c("avg", "max", "mean", "min", "sum") -# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", -# "variance", "var_samp", "var_pop" +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", +# "stddev_pop", "variance", "var_samp", "var_pop" #' Pivot a column of the GroupedData and perform the specified aggregation. #' diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 15af8298ba484..7cd072a1d6f89 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -58,22 +58,25 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. -#' @param standardization Whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. +#' @param standardization Whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. #' @param threshold The threshold in binary classification applied to the linear model prediction. #' This threshold can be any real number, where Inf will make all predictions 0.0 #' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. 
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.svmLinear} returns a fitted linear SVM model. #' @rdname spark.svmLinear @@ -175,62 +178,80 @@ function(object, path, overwrite = FALSE) { #' Logistic Regression Model #' -#' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary logistic regression -#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. -#' Users can print, make predictions on the produced model and save the model to the input path. +#' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary +#' logistic regression with pivoting; "multinomial": Multinomial logistic (softmax) regression +#' without pivoting, similar to glmnet. Users can print, make predictions on the produced model +#' and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param regParam the regularization parameter. -#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. -#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination -#' of L1 and L2. Default is 0.0 which is an L2 penalty. +#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 +#' penalty. For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, +#' the penalty is a combination of L1 and L2. Default is 0.0 which is an +#' L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param family the name of family which is a description of the label distribution to be used in the model. +#' @param family the name of family which is a description of the label distribution to be used +#' in the model. #' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". #' Else, set to "multinomial".} #' \item{"binomial": Binary logistic regression with pivoting.} -#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} +#' \item{"multinomial": Multinomial logistic (softmax) regression without +#' pivoting.} #' } -#' @param standardization whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. -#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 -#' is > threshold, then predict 1, else 0. 
A high threshold encourages the model to predict 0 -#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with -#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of -#' predicting each class. Array must have length equal to the number of classes, with values > 0, -#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. +#' @param standardization whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. Default is TRUE, same as +#' glmnet. +#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of +#' class label 1 is > threshold, then predict 1, else 0. A high threshold +#' encourages the model to predict 0 more often; a low threshold encourages the +#' model to predict 1 more often. Note: Setting this with threshold p is +#' equivalent to setting thresholds c(1-p, p). In multiclass (or binary) +#' classification to adjust the probability of predicting each class. Array must +#' have length equal to the number of classes, with values > 0, excepting that +#' at most one value may be 0. The class with largest value p/t is predicted, +#' where p is the original probability of that class and t is the class's +#' threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. -#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. This is an expert parameter. Default +#' value should be good for most cases. +#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. #' It is a R matrix. -#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. 
#' It is a R matrix. -#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained optimization. -#' The bounds vector size must be equal to 1 for binomial regression, or the number -#' of classes for multinomial regression. -#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. -#' The bound vector size must be equal to 1 for binomial regression, or the number +#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bounds vector size must be equal to 1 for binomial regression, +#' or the number #' of classes for multinomial regression. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bound vector size must be equal to 1 for binomial regression, +#' or the number of classes for multinomial regression. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit @@ -412,11 +433,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @param seed seed parameter for weights initialization. #' @param initialWeights initialWeights parameter for weights initialization, it should be a #' numeric vector. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp @@ -452,11 +474,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } layers <- as.integer(na.omit(layers)) if (length(layers) <= 1) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } if (!is.null(seed)) { seed <- as.character(as.integer(seed)) @@ -538,11 +560,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @param formula a symbolic description of the model to be fitted. 
Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param smoothing smoothing parameter. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 97c9fa1b45840..a25bf81c6d977 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -60,9 +60,9 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @param maxIter maximum iteration number. #' @param seed the random seed. #' @param minDivisibleClusterSize The minimum number of points (if greater than or equal to 1.0) -#' or the minimum proportion of points (if less than 1.0) of a divisible cluster. -#' Note that it is an expert parameter. The default value should be good enough -#' for most cases. +#' or the minimum proportion of points (if less than 1.0) of a +#' divisible cluster. Note that it is an expert parameter. The +#' default value should be good enough for most cases. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.bisectingKmeans} returns a fitted bisecting k-means model. #' @rdname spark.bisectingKmeans @@ -325,10 +325,11 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' Note that the response variable of formula is empty in spark.kmeans. #' @param k number of centers. #' @param maxIter maximum iteration number. -#' @param initMode the initialization algorithm choosen to fit the model. +#' @param initMode the initialization algorithm chosen to fit the model. #' @param seed the random seed for cluster initialization. #' @param initSteps the number of steps for the k-means|| initialization mode. -#' This is an advanced setting, the default of 2 is almost always enough. Must be > 0. +#' This is an advanced setting, the default of 2 is almost always enough. +#' Must be > 0. #' @param tol convergence tolerance of iterations. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.kmeans} returns a fitted k-means model. 
@@ -548,8 +549,8 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' \item{\code{topics}}{top 10 terms and their weights of all topics} #' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file #' used as training set} -#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set, -#' given the current parameter estimates: +#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the +#' training set, given the current parameter estimates: #' log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) #' It is only for distributed LDA model (i.e., optimizer = "em")} #' \item{\code{logPrior}}{Log probability of the current parameter estimate: diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index ebaeae970218a..f734a0865ec3b 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -58,8 +58,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' Note that there are two ways to specify the tweedie family. #' \itemize{ #' \item Set \code{family = "tweedie"} and specify the var.power and link.power; -#' \item When package \code{statmod} is loaded, the tweedie family is specified using the -#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' \item When package \code{statmod} is loaded, the tweedie family is specified +#' using the family definition therein, i.e., \code{tweedie(var.power, link.power)}. #' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. @@ -71,13 +71,15 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' applicable to the Tweedie family. #' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -197,13 +199,15 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @param var.power the index of the power variance function in the Tweedie family. #' @param link.power the index of the power link function in the Tweedie family. 
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -233,11 +237,11 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' @param object a fitted generalized linear model. #' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes -#' coefficients, standard error of coefficients, t value and p value), +#' The list of components includes at least the \code{coefficients} (coefficients matrix, +#' which includes coefficients, standard error of coefficients, t value and p value), #' \code{null.deviance} (null/residual degrees of freedom), \code{aic} (AIC) -#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, -#' the coefficients matrix only provides coefficients. +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in +#' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm #' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 @@ -457,15 +461,17 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this +#' param could be adjusted to a larger size. This is an expert parameter. +#' Default value should be good for most cases. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. 
Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. #' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 33c4653f4c184..89a58bf0aadae 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -132,10 +132,12 @@ print.summary.decisionTree <- function(x) { #' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and #' \code{write.ml}/\code{read.ml} to save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ #' GBT Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ #' GBT Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -164,11 +166,12 @@ print.summary.decisionTree <- function(x) { #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.gbt,SparkDataFrame,formula-method #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. @@ -352,10 +355,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ #' Random Forest Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ #' Random Forest Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -382,11 +387,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. 
If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -567,10 +573,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ #' Decision Tree Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ #' Decision Tree Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -592,11 +600,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.decisionTree,SparkDataFrame,formula-method #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. @@ -671,7 +680,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). +#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of +#' trees). 
#' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method #' @export diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 8fa21be3076b5..9c2e57d3067db 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -860,7 +860,7 @@ setMethod("subtractByKey", other, numPartitions = numPartitions), filterFunction), - function (v) { v[[1]] }) + function(v) { v[[1]] }) }) #' Return a subset of this RDD sampled by key. @@ -925,7 +925,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- stats::rpois(1, frac) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index d1ed6833d5d02..65f418740c643 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -155,7 +155,7 @@ checkType <- function(type) { } else { # Check complex types firstChar <- substr(type, 1, 1) - switch (firstChar, + switch(firstChar, a = { # Array type m <- regexec("^array<(.+)>$", type) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 9a9fa84044ce6..c8af798830b30 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -29,9 +29,9 @@ setOldClass("jobj") #' @param col1 name of the first column. Distinct items will make the first item of each row. #' @param col2 name of the second column. Distinct items will make the column names of the output. #' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of \code{col1} and the column names will be the distinct values -#' of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". Pairs -#' that have no occurrences will have zero as their counts. +#' will be the distinct values of \code{col1} and the column names will be the distinct +#' values of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". +#' Pairs that have no occurrences will have zero as their counts. #' #' @rdname crosstab #' @name crosstab @@ -53,8 +53,8 @@ setMethod("crosstab", }) #' @details -#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two numerical -#' columns of \emph{one} SparkDataFrame. +#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two +#' numerical columns of \emph{one} SparkDataFrame. #' #' @param colName1 the name of the first column #' @param colName2 the name of the second column @@ -159,8 +159,8 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @param relativeError The relative target precision to achieve (>= 0). If set to zero, #' the exact quantiles are computed, which could be very expensive. #' Note that values greater than 1 are accepted but give the same result as 1. -#' @return The approximate quantiles at the given probabilities. If the input is a single column name, -#' the output is a list of approximate quantiles in that column; If the input is +#' @return The approximate quantiles at the given probabilities. If the input is a single column +#' name, the output is a list of approximate quantiles in that column; If the input is #' multiple column names, the output should be a list, and each element in it is a list of #' numeric values which represents the approximate quantiles in corresponding column. 
#' diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 91483a4d23d9b..4b716995f2c46 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -625,7 +625,7 @@ appendPartitionLengths <- function(x, other) { x <- lapplyPartition(x, appendLength) other <- lapplyPartition(other, appendLength) } - list (x, other) + list(x, other) } # Perform zip or cartesian between elements from two RDDs in each partition @@ -657,7 +657,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if (lengthOfValues > 1) { - values <- part[ (lengthOfKeys + 1) : (len - 1) ] + values <- part[(lengthOfKeys + 1) : (len - 1)] } else { values <- list() } diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 03e7450147865..00789d815bba8 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -68,7 +68,7 @@ compute <- function(mode, partition, serializer, deserializer, key, } else { output <- computeFunc(partition, inputData) } - return (output) + return(output) } outputResult <- function(serializer, output, outputCon) { diff --git a/R/pkg/tests/fulltests/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R index 442bed509bb1d..c5d240f3e7344 100644 --- a/R/pkg/tests/fulltests/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -73,7 +73,7 @@ test_that("zipPartitions() on RDDs", { rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 actual <- collectRDD(zipPartitions(rdd1, rdd2, rdd3, - func = function(x, y, z) { list(list(x, y, z))} )) + func = function(x, y, z) { list(list(x, y, z))})) expect_equal(actual, list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) diff --git a/R/pkg/tests/fulltests/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R index 6ee1fceffd822..0c702ea897f7c 100644 --- a/R/pkg/tests/fulltests/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -698,14 +698,14 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { - numPairsRdd <- map(rdd, function(x) { list (x, x) }) + numPairsRdd <- map(rdd, function(x) { list(x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) - numPairs <- lapply(nums, function(x) { list (x, x) }) + numPairs <- lapply(nums, function(x) { list(x, x) }) expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) - numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + numPairsRdd2 <- map(rdd2, function(x) { list(x, x) }) sortedRdd2 <- sortByKey(numPairsRdd2) actual <- collectRDD(sortedRdd2) expect_equal(actual, numPairs) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4e62be9b4d619..7f781f2f66a7f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -560,9 +560,9 @@ test_that("Collect DataFrame with complex types", { expect_equal(nrow(ldf), 3) expect_equal(ncol(ldf), 3) expect_equal(names(ldf), c("c1", "c2", "c3")) - expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) - expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) - expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list(7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list(7.0, 
8.0, 9.0))) # MapType schema <- structType(structField("name", "string"), @@ -1524,7 +1524,7 @@ test_that("column functions", { expect_equal(ncol(s), 1) expect_equal(nrow(s), 3) expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 }))) } # passing option @@ -2710,7 +2710,7 @@ test_that("freqItems() on a DataFrame", { input <- 1:1000 rdf <- data.frame(numbers = input, letters = as.character(input), negDoubles = input * -1.0, stringsAsFactors = F) - rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + rdf[input %% 3 == 0, ] <- c(1, "1", -1) df <- createDataFrame(rdf) multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) expect_true(1 %in% multiColResults$numbers[[1]]) @@ -3064,7 +3064,7 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { - df <- createDataFrame ( + df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) expected <- collect(df) @@ -3135,7 +3135,7 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { actual <- df3Collect[order(df3Collect$a), ] expect_identical(actual$avg, expected$avg) - irisDF <- suppressWarnings(createDataFrame (iris)) + irisDF <- suppressWarnings(createDataFrame(iris)) schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double")) # Groups by `Sepal_Length` and computes the average for `Sepal_Width` df4 <- gapply( diff --git a/dev/lint-r.R b/dev/lint-r.R index 87ee36d5c9b68..a4261d266bbc0 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -26,8 +26,8 @@ if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { # Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. -if ("lintr" %in% row.names(installed.packages()) == FALSE) { - devtools::install_github("jimhester/lintr@a769c0b") +if ("lintr" %in% row.names(installed.packages()) == FALSE) { + devtools::install_github("jimhester/lintr@5431140") } library(lintr) From 3ca367083e196e6487207211e6c49d4bbfe31288 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 1 Oct 2017 10:49:22 -0700 Subject: [PATCH 002/712] [SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns at one pass ## What changes were proposed in this pull request? SPARK-21690 makes one-pass `Imputer` by parallelizing the computation of all input columns. When we transform dataset with `ImputerModel`, we do `withColumn` on all input columns sequentially. We can also do this on all input columns at once by adding a `withColumns` API to `Dataset`. The new `withColumns` API is for internal use only now. ## How was this patch tested? Existing tests for `ImputerModel`'s change. Added tests for `withColumns` API. Author: Liang-Chi Hsieh Closes #19229 from viirya/SPARK-22001. 
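For illustration, a minimal usage sketch of the new internal API, mirroring the `DataFrameSuite` test added below (`testData` is that suite's fixture with columns `key` and `value`; `withColumns` is `private[spark]`, so this only compiles inside Spark itself):

```scala
import org.apache.spark.sql.functions.col

// Both sequences must have the same length, and duplicate output names are rejected.
// The two new columns are resolved and added in a single pass over the plan.
val df = testData.toDF().withColumns(
  Seq("newCol1", "newCol2"),
  Seq(col("key") + 1, col("key") + 2))
// Resulting schema: key, value, newCol1, newCol2
```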
--- .../org/apache/spark/ml/feature/Imputer.scala | 10 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 42 ++++++++++----- .../org/apache/spark/sql/DataFrameSuite.scala | 52 +++++++++++++++++++ 3 files changed, 86 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 1f36eced3d08f..4663f16b5f5dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -223,20 +223,18 @@ class ImputerModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - var outputDF = dataset val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq - $(inputCols).zip($(outputCols)).zip(surrogates).foreach { + val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType val ic = col(inputCol) - outputDF = outputDF.withColumn(outputCol, - when(ic.isNull, surrogate) + when(ic.isNull, surrogate) .when(ic === $(missingValue), surrogate) .otherwise(ic) - .cast(inputType)) + .cast(inputType) } - outputDF.toDF() + dataset.withColumns($(outputCols), newCols).toDF() } override def transformSchema(schema: StructType): StructType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ab0c4126bcbdd..f2a76a506eb6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2083,22 +2083,40 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): DataFrame = { + def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) + + /** + * Returns a new Dataset by adding columns or replacing the existing columns that has + * the same names. 
+ */ + private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = { + require(colNames.size == cols.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of columns: ${cols.size}") + SchemaUtils.checkColumnNameDuplication( + colNames, + "in given column names", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + val resolver = sparkSession.sessionState.analyzer.resolver val output = queryExecution.analyzed.output - val shouldReplace = output.exists(f => resolver(f.name, colName)) - if (shouldReplace) { - val columns = output.map { field => - if (resolver(field.name, colName)) { - col.as(colName) - } else { - Column(field) - } + + val columnMap = colNames.zip(cols).toMap + + val replacedAndExistingColumns = output.map { field => + columnMap.find { case (colName, _) => + resolver(field.name, colName) + } match { + case Some((colName: String, col: Column)) => col.as(colName) + case _ => Column(field) } - select(columns : _*) - } else { - select(Column("*"), col.as(colName)) } + + val newColumns = columnMap.filter { case (colName, col) => + !output.exists(f => resolver(f.name, colName)) + }.map { case (colName, col) => col.as(colName) } + + select(replacedAndExistingColumns ++ newColumns : _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0e2f2e5a193e1..672deeac597f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -641,6 +641,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) } + test("withColumns") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2")) + + val err = intercept[IllegalArgumentException] { + testData.toDF().withColumns(Seq("newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert( + err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2")) + + val err2 = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err2.getMessage.contains("Found duplicate column(s)")) + } + + test("withColumns: case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1")) + + val err = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err.getMessage.contains("Found duplicate column(s)")) + } + } + test("replace column using withColumn") { val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) @@ -649,6 +692,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(2) :: Row(3) :: Row(4) :: Nil) } + test("replace column using withColumns") { + val df2 = sparkContext.parallelize(Array((1, 2), 
(2, 3), (3, 4))).toDF("x", "y") + val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"), + Seq(df2("x") + 1, df2("y"), df2("y") + 1)) + checkAnswer( + df3.select("x", "newCol1", "newCol2"), + Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil) + } + test("drop column using drop") { val df = testData.drop("key") checkAnswer( From 405c0e99e7697bfa88aa4abc9a55ce5e043e48b1 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 2 Oct 2017 08:07:56 +0100 Subject: [PATCH 003/712] [SPARK-22173][WEB-UI] Table CSS style needs to be adjusted in History Page and in Executors Page. ## What changes were proposed in this pull request? There is a problem with table CSS style. 1. At present, table CSS style is too crowded, and the table width cannot adapt itself. 2. Table CSS style is different from job page, stage page, task page, master page, worker page, etc. The Spark web UI needs to be consistent. fix before: ![01](https://user-images.githubusercontent.com/26266482/31041261-c6766c3a-a5c4-11e7-97a7-96bd51ef12bd.png) ![02](https://user-images.githubusercontent.com/26266482/31041266-d75b6a32-a5c4-11e7-8071-e3bbbba39b80.png) ---------------------------------------------------------------------------------------------------------- fix after: ![1](https://user-images.githubusercontent.com/26266482/31041162-808a5a3e-a5c3-11e7-8d92-d763b500ce53.png) ![2](https://user-images.githubusercontent.com/26266482/31041166-86e583e0-a5c3-11e7-949c-11c370db9e27.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19397 from guoxiaolongzte/SPARK-22173. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 ++-- .../main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index af14717633409..6399dccc1676a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val content =
    {providerConfig.map { case (k, v) =>
      {k}: {v}
    }}
@@ -58,7 +58,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { if (allAppsSize > 0) { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index d63381c78bc3b..7b2767f0be3cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -82,7 +82,7 @@ private[ui] class ExecutorsPage(
++ ++ ++ From 8fab7995d36c7bc4524393b20a4e524dbf6bbf62 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 2 Oct 2017 11:46:51 -0700 Subject: [PATCH 004/712] [SPARK-22167][R][BUILD] sparkr packaging issue allow zinc ## What changes were proposed in this pull request? When zinc is running the pwd might be in the root of the project. A quick solution to this is to not go a level up incase we are in the root rather than root/core/. If we are in the root everything works fine, if we are in core add a script which goes and runs the level up ## How was this patch tested? set -x in the SparkR install scripts. Author: Holden Karau Closes #19402 from holdenk/SPARK-22167-sparkr-packaging-issue-allow-zinc. --- R/install-dev.sh | 1 + core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/R/install-dev.sh b/R/install-dev.sh index d613552718307..9fbc999f2e805 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -28,6 +28,7 @@ set -o pipefail set -e +set -x FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" LIB_DIR="$FWDIR/lib" diff --git a/core/pom.xml b/core/pom.xml index 09669149d8123..54f7a34a6c37e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -499,7 +499,7 @@ - ..${file.separator}R${file.separator}install-dev${script.extension} + ${project.basedir}${file.separator}..${file.separator}R${file.separator}install-dev${script.extension} From e5431f2cfddc8e96194827a2123b92716c7a1467 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 2 Oct 2017 15:00:26 -0700 Subject: [PATCH 005/712] [SPARK-22158][SQL] convertMetastore should not ignore table property ## What changes were proposed in this pull request? From the beginning, convertMetastoreOrc ignores table properties and use an empty map instead. This PR fixes that. For the diff, please see [this](https://github.com/apache/spark/pull/19382/files?w=1). convertMetastoreParquet also ignore. ```scala val options = Map[String, String]() ``` - [SPARK-14070: HiveMetastoreCatalog.scala](https://github.com/apache/spark/pull/11891/files#diff-ee66e11b56c21364760a5ed2b783f863R650) - [Master branch: HiveStrategies.scala](https://github.com/apache/spark/blob/master/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala#L197 ) ## How was this patch tested? Pass the Jenkins with an updated test suite. Author: Dongjoon Hyun Closes #19382 from dongjoon-hyun/SPARK-22158. 
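To make the user-visible effect concrete, a hedged sketch distilled from the `HiveDDLSuite` test added below (the `LOCATION` path is a placeholder):

```scala
// With convertMetastoreParquet/Orc enabled, the table's storage properties (e.g. 'compression')
// are now passed through to the native Parquet/ORC writer instead of being dropped.
sql(
  """
    |CREATE TABLE t(id int) USING hive
    |OPTIONS(fileFormat 'parquet', compression 'GZIP')
    |LOCATION '/tmp/spark-22158-demo'
  """.stripMargin)
sql("INSERT INTO t SELECT 1")
// After this fix, the part files written for `t` are GZIP-compressed Parquet.
```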
--- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 54 ++++++++++++++++--- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 805b3171cdaab..3592b8f4846d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -189,12 +189,12 @@ case class RelationConversions( private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { - val options = Map(ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = Map[String, String]() + val options = relation.tableMeta.storage.properties sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 668da5fb47323..02e26bbe876a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -23,6 +23,8 @@ import java.net.URI import scala.language.existentials import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER +import org.apache.parquet.hadoop.ParquetFileReader import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException @@ -32,6 +34,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAl import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -1455,12 +1458,8 @@ class HiveDDLSuite sql("INSERT INTO t SELECT 1") checkAnswer(spark.table("t"), Row(1)) // Check if this is compressed as ZLIB. 
- val maybeOrcFile = path.listFiles().find(!_.getName.endsWith(".crc")) - assert(maybeOrcFile.isDefined) - val orcFilePath = maybeOrcFile.get.toPath.toString - val expectedCompressionKind = - OrcFileOperator.getFileReader(orcFilePath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + val maybeOrcFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeOrcFile, "orc", "ZLIB") sql("CREATE TABLE t2 USING HIVE AS SELECT 1 AS c1, 'a' AS c2") val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t2")) @@ -2009,4 +2008,47 @@ class HiveDDLSuite } } } + + private def assertCompression(maybeFile: Option[File], format: String, compression: String) = { + assert(maybeFile.isDefined) + + val actualCompression = format match { + case "orc" => + OrcFileOperator.getFileReader(maybeFile.get.toPath.toString).get.getCompression.name + + case "parquet" => + val footer = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, new Path(maybeFile.get.getPath), NO_FILTER) + footer.getBlocks.get(0).getColumns.get(0).getCodec.toString + } + + assert(compression === actualCompression) + } + + Seq(("orc", "ZLIB"), ("parquet", "GZIP")).foreach { case (fileFormat, compression) => + test(s"SPARK-22158 convertMetastore should not ignore table property - $fileFormat") { + withSQLConf(CONVERT_METASTORE_ORC.key -> "true", CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) USING hive + |OPTIONS(fileFormat '$fileFormat', compression '$compression') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains(fileFormat)) + assert(table.storage.properties.get("compression") == Some(compression)) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeFile, fileFormat, compression) + } + } + } + } + } } From 4329eb2e73181819bb712f57ca9c7feac0d640ea Mon Sep 17 00:00:00 2001 From: Gene Pang Date: Mon, 2 Oct 2017 15:09:11 -0700 Subject: [PATCH 006/712] [SPARK-16944][Mesos] Improve data locality when launching new executors when dynamic allocation is enabled ## What changes were proposed in this pull request? Improve the Spark-Mesos coarse-grained scheduler to consider the preferred locations when dynamic allocation is enabled. ## How was this patch tested? Added a unittest, and performed manual testing on AWS. Author: Gene Pang Closes #18098 from gpang/mesos_data_locality. 
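For context, a hedged sketch of the configuration this change keys off when deciding whether to decline a Mesos offer (the values are illustrative only, not defaults proposed by this patch):

```scala
import org.apache.spark.SparkConf

// With dynamic allocation on the coarse-grained Mesos backend, offers from hosts that are
// not among the tasks' preferred locations are skipped until spark.locality.wait elapses.
val conf = new SparkConf()
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.shuffle.service.enabled", "true") // normally required with dynamic allocation
  .set("spark.locality.wait", "3s")             // how long to hold out for a preferred host
```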
--- .../spark/internal/config/package.scala | 4 ++ .../spark/scheduler/TaskSetManager.scala | 6 +- .../MesosCoarseGrainedSchedulerBackend.scala | 52 ++++++++++++++-- ...osCoarseGrainedSchedulerBackendSuite.scala | 62 +++++++++++++++++++ .../spark/scheduler/cluster/mesos/Utils.scala | 6 ++ 5 files changed, 123 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 44a2815b81a73..d85b6a0200b8d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("3s") + private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index bb867416a4fac..3bdede6743d1b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap @@ -980,7 +980,7 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3s") + val defaultWait = conf.get(config.LOCALITY_WAIT) val localityWaitKey = level match { case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" @@ -989,7 +989,7 @@ private[spark] class TaskSetManager( } if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait.toString) } else { 0L } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 26699873145b4..80c0a041b7322 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -99,6 +99,14 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private var totalCoresAcquired = 0 private var totalGpusAcquired = 0 + // The amount of time to wait for locality scheduling + private val localityWait = conf.get(config.LOCALITY_WAIT) + // The start of the waiting, for data local scheduling + private var localityWaitStartTime = System.currentTimeMillis() + // If true, the scheduler is in the process of launching executors to reach the requested + // executor limit + private var 
launchingExecutors = false + // SlaveID -> Slave // This map accumulates entries for the duration of the job. Slaves are never deleted, because // we need to maintain e.g. failure state and connection state. @@ -311,6 +319,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( return } + if (numExecutors >= executorLimit) { + logDebug("Executor limit reached. numExecutors: " + numExecutors + + " executorLimit: " + executorLimit) + offers.asScala.map(_.getId).foreach(d.declineOffer) + launchingExecutors = false + return + } else { + if (!launchingExecutors) { + launchingExecutors = true + localityWaitStartTime = System.currentTimeMillis() + } + } + logDebug(s"Received ${offers.size} resource offers.") val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => @@ -413,7 +434,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val offerId = offer.getId.getValue val resources = remainingResources(offerId) - if (canLaunchTask(slaveId, resources)) { + if (canLaunchTask(slaveId, offer.getHostname, resources)) { // Create a task launchTasks = true val taskId = newMesosTaskId() @@ -477,7 +498,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse ++ gpuResourcesToUse) } - private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + private def canLaunchTask(slaveId: String, offerHostname: String, + resources: JList[Resource]): Boolean = { val offerMem = getResource(resources, "mem") val offerCPUs = getResource(resources, "cpus").toInt val cpus = executorCores(offerCPUs) @@ -489,9 +511,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpus <= offerCPUs && cpus + totalCoresAcquired <= maxCores && mem <= offerMem && - numExecutors() < executorLimit && + numExecutors < executorLimit && slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES && - meetsPortRequirements + meetsPortRequirements && + satisfiesLocality(offerHostname) } private def executorCores(offerCPUs: Int): Int = { @@ -500,6 +523,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( ) } + private def satisfiesLocality(offerHostname: String): Boolean = { + if (!Utils.isDynamicAllocationEnabled(conf) || hostToLocalTaskCount.isEmpty) { + return true + } + + // Check the locality information + val currentHosts = slaves.values.filter(_.taskIDs.nonEmpty).map(_.hostname).toSet + val allDesiredHosts = hostToLocalTaskCount.keys.toSet + // Try to match locality for hosts which do not have executors yet, to potentially + // increase coverage. + val remainingHosts = allDesiredHosts -- currentHosts + if (!remainingHosts.contains(offerHostname) && + (System.currentTimeMillis() - localityWaitStartTime <= localityWait)) { + logDebug("Skipping host and waiting for locality. host: " + offerHostname) + return false + } + return true + } + override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue val slaveId = status.getSlaveId.getValue @@ -646,6 +688,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // since at coarse grain it depends on the amount of slaves available. logInfo("Capping the total amount of executors to " + requestedTotal) executorLimitOption = Some(requestedTotal) + // Update the locality wait start time to continue trying for locality. 
+ localityWaitStartTime = System.currentTimeMillis() true } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f6bae01c3af59..6c40792112f49 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -604,6 +604,55 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(backend.isReady) } + test("supports data locality with dynamic allocation") { + setBackend(Map( + "spark.dynamicAllocation.enabled" -> "true", + "spark.dynamicAllocation.testing" -> "true", + "spark.locality.wait" -> "1s")) + + assert(backend.getExecutorIds().isEmpty) + + backend.requestTotalExecutors(2, 2, Map("hosts10" -> 1, "hosts11" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(1, false) + offerResourcesAndVerify(2, false) + + // Offer local resource + offerResourcesAndVerify(10, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be accepted + offerResourcesAndVerify(1, true) + + // Update total executors + backend.requestTotalExecutors(3, 3, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Update total executors + backend.requestTotalExecutors(4, 4, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1, + "hosts13" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Offer local resource + offerResourcesAndVerify(13, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be accepted + offerResourcesAndVerify(2, true) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { @@ -631,6 +680,19 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.resourceOffers(driver, mesosOffers.asJava) } + private def offerResourcesAndVerify(id: Int, expectAccept: Boolean): Unit = { + offerResources(List(Resources(backend.executorMemory(sc), 1)), id) + if (expectAccept) { + val numExecutors = backend.getExecutorIds().size + val launchedTasks = verifyTaskLaunched(driver, s"o$id") + assert(s"s$id" == launchedTasks.head.getSlaveId.getValue) + registerMockExecutor(launchedTasks.head.getTaskId.getValue, s"s$id", 1) + assert(backend.getExecutorIds().size == numExecutors + 1) + } else { + verifyTaskNotLaunched(driver, s"o$id") + } + } + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { TaskStatus.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId).build()) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 2a67cbc913ffe..833db0c1ff334 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ 
b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -84,6 +84,12 @@ object Utils { captor.getValue.asScala.toList } + def verifyTaskNotLaunched(driver: SchedulerDriver, offerId: String): Unit = { + verify(driver, times(0)).launchTasks( + Matchers.eq(Collections.singleton(createOfferId(offerId))), + Matchers.any(classOf[java.util.Collection[TaskInfo]])) + } + def createOfferId(offerId: String): OfferID = { OfferID.newBuilder().setValue(offerId).build() } From fa225da7463e384529da14706e44f4a09772e5c1 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 2 Oct 2017 15:25:33 -0700 Subject: [PATCH 007/712] [SPARK-22176][SQL] Fix overflow issue in Dataset.show ## What changes were proposed in this pull request? This pr fixed an overflow issue below in `Dataset.show`: ``` scala> Seq((1, 2), (3, 4)).toDF("a", "b").show(Int.MaxValue) org.apache.spark.sql.AnalysisException: The limit expression must be equal to or greater than 0, but got -2147483648;; GlobalLimit -2147483648 +- LocalLimit -2147483648 +- Project [_1#27218 AS a#27221, _2#27219 AS b#27222] +- LocalRelation [_1#27218, _2#27219] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:89) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.org$apache$spark$sql$catalyst$analysis$CheckAnalysis$$checkLimitClause(CheckAnalysis.scala:70) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:234) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:80) at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:127) ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #19401 from maropu/MaxValueInShowString. 
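The failure is plain integer overflow; a minimal arithmetic sketch (not the actual `showString` code) of why clamping to `Int.MaxValue - 1` helps, given that `showString` calls `take(numRows + 1)` to detect whether more rows exist:

```scala
val requested = Int.MaxValue
val before = requested.max(0) + 1                        // wraps around to -2147483648
val after  = requested.max(0).min(Int.MaxValue - 1) + 1  // 2147483647, still a valid limit
println(before)  // -2147483648, which the limit operators reject
println(after)   // 2147483647
```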
--- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f2a76a506eb6f..b70dfc05330f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,7 +237,7 @@ class Dataset[T] private[sql]( */ private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0) + val numRows = _numRows.max(0).min(Int.MaxValue - 1) val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 672deeac597f1..dd8f54b690f64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1045,6 +1045,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(0) === expectedAnswer) } + test("showString(Int.MaxValue)") { + val df = Seq((1, 2), (3, 4)).toDF("a", "b") + val expectedAnswer = """+---+---+ + || a| b| + |+---+---+ + || 1| 2| + || 3| 4| + |+---+---+ + |""".stripMargin + assert(df.showString(Int.MaxValue) === expectedAnswer) + } + test("showString(0), vertical = true") { val expectedAnswer = "(0 rows)\n" assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) From 4c5158eec9101ef105274df6b488e292a56156a2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 3 Oct 2017 12:38:13 -0700 Subject: [PATCH 008/712] [SPARK-21644][SQL] LocalLimit.maxRows is defined incorrectly ## What changes were proposed in this pull request? The definition of `maxRows` in `LocalLimit` operator was simply wrong. This patch introduces a new `maxRowsPerPartition` method and uses that in pruning. The patch also adds more documentation on why we need local limit vs global limit. Note that this previously has never been a bug because the way the code is structured, but future use of the maxRows could lead to bugs. ## How was this patch tested? Should be covered by existing test cases. Closes #18851 Author: gatorsmile Author: Reynold Xin Closes #19393 from gatorsmile/pr-18851. 
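As a quick illustration of the distinction (plain arithmetic, not planner code): a partition-local limit only bounds each partition's output, so it cannot be used as a bound on the total row count.

```scala
// LocalLimit(10) over a child with 4 partitions may emit up to 40 rows in total,
// so 10 is a valid maxRowsPerPartition for that node but not a valid maxRows.
val limitPerPartition = 10L
val numPartitions = 4
val totalUpperBound = limitPerPartition * numPartitions  // 40
```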
--- .../sql/catalyst/optimizer/Optimizer.scala | 29 ++++++----- .../catalyst/plans/logical/LogicalPlan.scala | 5 ++ .../plans/logical/basicLogicalOperators.scala | 49 ++++++++++++++++++- .../execution/basicPhysicalOperators.scala | 3 ++ 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b9fa39d6dad4c..bc2d4a824cb49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -305,13 +305,20 @@ object LimitPushDown extends Rule[LogicalPlan] { } } - private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { - (limitExp, plan.maxRows) match { - case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows => + private def maybePushLocalLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRowsPerPartition) match { + case (IntegerLiteral(newLimit), Some(childMaxRows)) if newLimit < childMaxRows => + // If the child has a cap on max rows per partition and the cap is larger than + // the new limit, put a new LocalLimit there. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case (_, None) => + // If the child has no cap, put the new LocalLimit. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) - case _ => plan + + case _ => + // Otherwise, don't put a new LocalLimit. + plan } } @@ -323,7 +330,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to // pushdown Limit. case LocalLimit(exp, Union(children)) => - LocalLimit(exp, Union(children.map(maybePushLimit(exp, _)))) + LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _)))) // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side // because we need to ensure that rows from the limited side still have an opportunity to match @@ -335,19 +342,19 @@ object LimitPushDown extends Rule[LogicalPlan] { // - If neither side is limited, limit the side that is estimated to be bigger. 
case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { - case RightOuter => join.copy(right = maybePushLimit(exp, right)) - case LeftOuter => join.copy(left = maybePushLimit(exp, left)) + case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) + case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) case FullOuter => (left.maxRows, right.maxRows) match { case (None, None) => if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { - join.copy(left = maybePushLimit(exp, left)) + join.copy(left = maybePushLocalLimit(exp, left)) } else { - join.copy(right = maybePushLimit(exp, right)) + join.copy(right = maybePushLocalLimit(exp, right)) } case (Some(_), Some(_)) => join - case (Some(_), None) => join.copy(left = maybePushLimit(exp, left)) - case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right)) + case (Some(_), None) => join.copy(left = maybePushLocalLimit(exp, left)) + case (None, Some(_)) => join.copy(right = maybePushLocalLimit(exp, right)) } case _ => join diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 68aae720e026a..14188829db2af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -97,6 +97,11 @@ abstract class LogicalPlan */ def maxRows: Option[Long] = None + /** + * Returns the maximum number of rows this plan may compute on each partition. + */ + def maxRowsPerPartition: Option[Long] = maxRows + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f443cd5a69de3..80243d3d356ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -191,6 +191,9 @@ object Union { } } +/** + * Logical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + */ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { @@ -200,6 +203,17 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } } + /** + * Note the definition has assumption about how union is implemented physically. + */ + override def maxRowsPerPartition: Option[Long] = { + if (children.exists(_.maxRowsPerPartition.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRowsPerPartition).sum) + } + } + // updating nullability to make all the children consistent override def output: Seq[Attribute] = children.map(_.output).transpose.map(attrs => @@ -669,6 +683,27 @@ case class Pivot( } } +/** + * A constructor for creating a logical limit, which is split into two separate logical nodes: + * a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]]. + * + * This muds the water for clean logical/physical separation, and is done for better limit pushdown. 
+ * In distributed query processing, a non-terminal global limit is actually an expensive operation + * because it requires coordination (in Spark this is done using a shuffle). + * + * In most cases when we want to push down limit, it is often better to only push some partition + * local limit. Consider the following: + * + * GlobalLimit(Union(A, B)) + * + * It is better to do + * GlobalLimit(Union(LocalLimit(A), LocalLimit(B))) + * + * than + * Union(GlobalLimit(A), GlobalLimit(B)). + * + * So we introduced LocalLimit and GlobalLimit in the logical plan node for limit pushdown. + */ object Limit { def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = { GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) @@ -682,6 +717,11 @@ object Limit { } } +/** + * A global (coordinated) limit. This operator can emit at most `limitExpr` number in total. + * + * See [[Limit]] for more information. + */ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { @@ -692,9 +732,16 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } } +/** + * A partition-local (non-coordinated) limit. This operator can emit at most `limitExpr` number + * of tuples on each physical partition. + * + * See [[Limit]] for more information. + */ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def maxRows: Option[Long] = { + + override def maxRowsPerPartition: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 8389e2f3d5be9..63cd1691f4cd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -554,6 +554,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) /** * Physical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + * + * If we change how this is implemented physically, we'd need to update + * [[org.apache.spark.sql.catalyst.plans.logical.Union.maxRowsPerPartition]]. */ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { override def output: Seq[Attribute] = From e65b6b7ca1a7cff1b91ad2262bb7941e6bf057cd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 3 Oct 2017 12:40:22 -0700 Subject: [PATCH 009/712] [SPARK-22178][SQL] Refresh Persistent Views by REFRESH TABLE Command ## What changes were proposed in this pull request? The underlying tables of persistent views are not refreshed when users issue the REFRESH TABLE command against the persistent views. ## How was this patch tested? Added a test case Author: gatorsmile Closes #19405 from gatorsmile/refreshView. 
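A usage sketch of the intended behavior, adapted from the test change below (table and view names are placeholders):

```scala
spark.range(0, 100).write.saveAsTable("view_table")
spark.sql("CREATE VIEW view_refresh AS SELECT * FROM view_table WHERE id > -1")

// ... files backing view_table change outside of Spark ...

// REFRESH TABLE on the persistent view now recursively refreshes the underlying
// table's cached metadata/data instead of only touching the view entry.
spark.sql("REFRESH TABLE view_refresh")
spark.sql("SELECT count(*) FROM view_refresh").show()
```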
--- .../apache/spark/sql/internal/CatalogImpl.scala | 15 +++++++++++---- .../spark/sql/hive/HiveMetadataCacheSuite.scala | 14 +++++++++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 142b005850a49..fdd25330c5e67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -474,13 +474,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def refreshTable(tableName: String): Unit = { val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) - // Temp tables: refresh (or invalidate) any metadata/data cached in the plan recursively. - // Non-temp tables: refresh the metadata cache. - sessionCatalog.refreshTable(tableIdent) + val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val table = sparkSession.table(tableIdent) + + if (tableMetadata.tableType == CatalogTableType.VIEW) { + // Temp or persistent views: refresh (or invalidate) any metadata/data cached + // in the plan recursively. + table.queryExecution.analyzed.foreach(_.refresh()) + } else { + // Non-temp tables: refresh the metadata cache. + sessionCatalog.refreshTable(tableIdent) + } // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val table = sparkSession.table(tableIdent) if (isCached(table)) { // Uncache the logicalPlan. sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 0c28a1b609bb8..e71aba72c31fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -31,14 +31,22 @@ import org.apache.spark.sql.test.SQLTestUtils class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-16337 temporary view refresh") { - withTempView("view_refresh") { + checkRefreshView(isTemp = true) + } + + test("view refresh") { + checkRefreshView(isTemp = false) + } + + private def checkRefreshView(isTemp: Boolean) { + withView("view_refresh") { withTable("view_table") { // Create a Parquet directory spark.range(start = 0, end = 100, step = 1, numPartitions = 3) .write.saveAsTable("view_table") - // Read the table in - spark.table("view_table").filter("id > -1").createOrReplaceTempView("view_refresh") + val temp = if (isTemp) "TEMPORARY" else "" + spark.sql(s"CREATE $temp VIEW view_refresh AS SELECT * FROM view_table WHERE id > -1") assert(sql("select count(*) from view_refresh").first().getLong(0) == 100) // Delete a file using the Hadoop file system interface since the path returned by From e36ec38d89472df0dfe12222b6af54cd6eea8e98 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Tue, 3 Oct 2017 16:53:32 -0700 Subject: [PATCH 010/712] [SPARK-20466][CORE] HadoopRDD#addLocalConfiguration throws NPE ## What changes were proposed in this pull request? Fix for SPARK-20466, full description of the issue in the JIRA. To summarize, `HadoopRDD` uses a metadata cache to cache `JobConf` objects. 
The cache uses soft-references, which means the JVM can delete entries from the cache whenever there is GC pressure. `HadoopRDD#getJobConf` had a bug where it would check if the cache contained the `JobConf`, if it did it would get the `JobConf` from the cache and return it. This doesn't work when soft-references are used as the JVM can delete the entry between the existence check and the get call. ## How was this patch tested? Haven't thought of a good way to test this yet given the issue only occurs sometimes, and happens during high GC pressure. Was thinking of using mocks to verify `#getJobConf` is doing the right thing. I deleted the method `HadoopRDD#containsCachedMetadata` so that we don't hit this issue again. Author: Sahil Takiar Closes #19413 from sahilTakiar/master. --- .../org/apache/spark/rdd/HadoopRDD.scala | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 76ea8b86c53d2..23b344230e490 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -157,20 +157,25 @@ class HadoopRDD[K, V]( if (conf.isInstanceOf[JobConf]) { logDebug("Re-using user-broadcasted JobConf") conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - logDebug("Re-using cached JobConf") - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). - HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { - logDebug("Creating new JobConf and caching it for later re-use") - val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf + Option(HadoopRDD.getCachedMetadata(jobConfCacheKey)) + .map { conf => + logDebug("Re-using cached JobConf") + conf.asInstanceOf[JobConf] + } + .getOrElse { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in + // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary + // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, + // HADOOP-10456). 
+ HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } } } } @@ -360,8 +365,6 @@ private[spark] object HadoopRDD extends Logging { */ def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) From 5f694334534e4425fb9e8abf5b7e3e5efdfcef50 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 3 Oct 2017 21:27:58 -0700 Subject: [PATCH 011/712] [SPARK-22171][SQL] Describe Table Extended Failed when Table Owner is Empty ## What changes were proposed in this pull request? Users could hit `java.lang.NullPointerException` when the tables were created by Hive and the table's owner is `null` that are got from Hive metastore. `DESC EXTENDED` failed with the error: > SQLExecutionException: java.lang.NullPointerException at scala.collection.immutable.StringOps$.length$extension(StringOps.scala:47) at scala.collection.immutable.StringOps.length(StringOps.scala:47) at scala.collection.IndexedSeqOptimized$class.isEmpty(IndexedSeqOptimized.scala:27) at scala.collection.immutable.StringOps.isEmpty(StringOps.scala:29) at scala.collection.TraversableOnce$class.nonEmpty(TraversableOnce.scala:111) at scala.collection.immutable.StringOps.nonEmpty(StringOps.scala:29) at org.apache.spark.sql.catalyst.catalog.CatalogTable.toLinkedHashMap(interface.scala:300) at org.apache.spark.sql.execution.command.DescribeTableCommand.describeFormattedTableInfo(tables.scala:565) at org.apache.spark.sql.execution.command.DescribeTableCommand.run(tables.scala:543) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:66) at ## How was this patch tested? Added a unit test case Author: gatorsmile Closes #19395 from gatorsmile/desc. 
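The defensive pattern used in the fix, as a small hedged sketch (the helper name is illustrative): values coming back from the Hive metastore's Java API can be null, so they are normalized before any String methods are called on them.

```scala
// Option(...) turns a possible Java null into None, so getOrElse("") yields a safe
// empty string and later calls like owner.nonEmpty can no longer throw an NPE.
def safeOwner(ownerFromHive: String): String =
  Option(ownerFromHive).getOrElse("")

safeOwner(null)     // ""
safeOwner("alice")  // "alice"
```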
--- .../sql/catalyst/catalog/interface.scala | 2 +- .../sql/catalyst/analysis/CatalogSuite.scala | 37 +++++++++++++++++++ .../sql/hive/client/HiveClientImpl.scala | 2 +- 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 1965144e81197..fe2af910a0ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -307,7 +307,7 @@ case class CatalogTable( identifier.database.foreach(map.put("Database", _)) map.put("Table", identifier.table) - if (owner.nonEmpty) map.put("Owner", owner) + if (owner != null && owner.nonEmpty) map.put("Owner", owner) map.put("Created Time", new Date(createTime).toString) map.put("Last Access", new Date(lastAccessTime).toString) map.put("Created By", "Spark " + createVersion) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala new file mode 100644 index 0000000000000..d670053ba1b5d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.types.StructType + + +class CatalogSuite extends AnalysisTest { + + test("desc table when owner is set to null") { + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + owner = null, + schema = new StructType().add("col1", "int").add("col2", "string"), + provider = Some("parquet")) + table.toLinkedHashMap + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c4e48c9360db7..66165c7228bca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -461,7 +461,7 @@ private[hive] class HiveClientImpl( // in table properties. This means, if we have bucket spec in both hive metastore and // table properties, we will trust the one in table properties. 
bucketSpec = bucketSpec, - owner = h.getOwner, + owner = Option(h.getOwner).getOrElse(""), createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( From 3099c574c56cab86c3fcf759864f89151643f837 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 3 Oct 2017 21:42:51 -0700 Subject: [PATCH 012/712] [SPARK-22136][SS] Implement stream-stream outer joins. ## What changes were proposed in this pull request? Allow one-sided outer joins between two streams when a watermark is defined. ## How was this patch tested? new unit tests Author: Jose Torres Closes #19327 from joseph-torres/outerjoin. --- .../analysis/StreamingJoinHelper.scala | 286 +++++++++++++++++ .../UnsupportedOperationChecker.scala | 53 +++- .../analysis/StreamingJoinHelperSuite.scala | 140 ++++++++ .../analysis/UnsupportedOperationsSuite.scala | 108 ++++++- .../StreamingSymmetricHashJoinExec.scala | 152 +++++++-- .../StreamingSymmetricHashJoinHelper.scala | 241 +------------- .../state/SymmetricHashJoinStateManager.scala | 200 +++++++++--- .../SymmetricHashJoinStateManagerSuite.scala | 6 +- .../sql/streaming/StreamingJoinSuite.scala | 298 +++++++++++------- 9 files changed, 1051 insertions(+), 433 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala new file mode 100644 index 0000000000000..072dc954879ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Helper object for stream joins. 
See [[StreamingSymmetricHashJoinExec]] in SQL for more details. + */ +object StreamingJoinHelper extends PredicateHelper with Logging { + + /** + * Check the provided logical plan to see if its join keys contain a watermark attribute. + * + * Will return false if the plan is not an equijoin. + * @param plan the logical plan to check + */ + def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { + plan match { + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + (leftKeys ++ rightKeys).exists { + case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) + case _ => false + } + case _ => false + } + } + + /** + * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) + * given the join condition and the event time watermark. This is how it works. + * - The condition is split into conjunctive predicates, and we find the predicates of the + * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). + * - We canoncalize the predicate and solve it with the event time watermark value to find the + * value of the state watermark. + * This function is supposed to make best-effort attempt to get the state watermark. If there is + * any error, it will return None. + * + * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark + * is to be calculated + * @param attributesWithEventWatermark attributes of the other side which has a watermark column + * @param joinCondition join condition + * @param eventWatermark watermark defined on the input event data + * @return state value watermark in milliseconds, is possible. + */ + def getStateValueWatermark( + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + joinCondition: Option[Expression], + eventWatermark: Option[Long]): Option[Long] = { + + // If condition or event time watermark is not provided, then cannot calculate state watermark + if (joinCondition.isEmpty || eventWatermark.isEmpty) return None + + // If there is not watermark attribute, then cannot define state watermark + if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None + + def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { + try { + getStateWatermarkFromLessThenPredicate( + l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) + } catch { + case NonFatal(e) => + logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) + None + } + } + + val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => + + // The generated the state watermark cleanup expression is inclusive of the state watermark. + // If state watermark is W, all state where timestamp <= W will be cleaned up. + // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean + // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. 
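The effect of that -1 adjustment is visible in StreamingJoinHelperSuite later in this patch: with a 10000 ms watermark on rightTime, the strict comparison keeps the full value while the non-strict one is tightened by a millisecond. The two assertions below are copied from the suite for illustration, where watermarkFrom is the suite's helper around getStateValueWatermark:

    assert(watermarkFrom("leftTime > rightTime") === Some(10000))
    assert(watermarkFrom("leftTime >= rightTime") === Some(9999))  // inclusive cleanup, hence the -1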
+ val stateWatermark = predicate match { + case LessThan(l, r) => getStateWatermarkSafely(l, r) + case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) + case GreaterThan(l, r) => getStateWatermarkSafely(r, l) + case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) + case _ => None + } + if (stateWatermark.nonEmpty) { + logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") + } + stateWatermark + } + allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) + } + + /** + * Extract the state value watermark (milliseconds) from the condition + * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for + * leftTime using the watermark on the rightTime. Example: + * + * Input: rightTime-with-watermark + c1 < leftTime + c2 + * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 + * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime + * With watermark value: watermark-value + c1 + (-c2) < leftTime + */ + private def getStateWatermarkFromLessThenPredicate( + leftExpr: Expression, + rightExpr: Expression, + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + eventWatermark: Option[Long]): Option[Long] = { + + val attributesInCondition = AttributeSet( + leftExpr.collect { case a: AttributeReference => a } ++ + rightExpr.collect { case a: AttributeReference => a } + ) + if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || + attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { + // If more than attributes present in condition from one side, then it cannot be solved + return None + } + + def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { + e.collectLeaves().collectFirst { + case a @ AttributeReference(_, _, _, _) + if attributesToFindStateWatermarkFor.contains(a) => a + }.nonEmpty + } + + // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 + val allOnLeftExpr = Subtract(leftExpr, rightExpr) + logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") + + // Canonicalization step 2: extract commutative terms + // rightTime-with-watermark, c1, -leftTime, -c2 + val terms = ExpressionSet(collectTerms(allOnLeftExpr)) + logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) + + // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor + val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) + + // Verify there is only one correct constraint term and of the correct type + if (constraintTerms.size > 1) { + logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + if (constraintTerms.isEmpty) { + logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + val constraintTerm = constraintTerms.head + if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { + // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` + // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. 
+ // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark + // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about + // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, + // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. + return None + } + + // Replace watermark attribute with watermark value, and generate the resolved expression + // from the other terms. That is, + // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) + logDebug(s"Constraint term from join condition:\t$constraintTerm") + val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => + term.transform { + case a @ AttributeReference(_, _, _, metadata) + if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => + Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) + } + }.reduceLeft(Add) + + // Calculate the constraint value + logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") + val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] + Some((Double2double(constraintValue) / 1000.0).toLong) + } + + /** + * Collect all the terms present in an expression after converting it into the form + * a + b + c + d where each term be either an attribute or a literal casted to long, + * optionally wrapped in a unary minus. + */ + private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { + var invalid = false + + /** Wrap a term with UnaryMinus if its needs to be negated. */ + def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { + if (minus) UnaryMinus(expr) else expr + } + + /** + * Recursively split the expression into its leaf terms contains attributes or literals. + * Returns terms only of the forms: + * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), + * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) + * Multiply(Literal), UnaryMinus(Multiply(Literal)) + * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) + * + * Note: + * - If term needs to be negated for making it a commutative term, + * then it will be wrapped in UnaryMinus(...) + * - Each terms will be representing timestamp value or time interval in microseconds, + * typed as doubles. 
+ */ + def collect(expr: Expression, negate: Boolean): Seq[Expression] = { + expr match { + case Add(left, right) => + collect(left, negate) ++ collect(right, negate) + case Subtract(left, right) => + collect(left, negate) ++ collect(right, !negate) + case TimeAdd(left, right, _) => + collect(left, negate) ++ collect(right, negate) + case TimeSub(left, right, _) => + collect(left, negate) ++ collect(right, !negate) + case UnaryMinus(child) => + collect(child, !negate) + case CheckOverflow(child, _) => + collect(child, negate) + case Cast(child, dataType, _) => + dataType match { + case _: NumericType | _: TimestampType => collect(child, negate) + case _ => + invalid = true + Seq.empty + } + case a: AttributeReference => + val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a + Seq(negateIfNeeded(castedRef, negate)) + case lit: Literal => + // If literal of type calendar interval, then explicitly convert to millis + // Convert other number like literal to doubles representing millis (by x1000) + val castedLit = lit.dataType match { + case CalendarIntervalType => + val calendarInterval = lit.value.asInstanceOf[CalendarInterval] + if (calendarInterval.months > 0) { + invalid = true + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom " + + s"as imprecise intervals like months and years cannot be used for" + + s"watermark calculation. Use interval in terms of day instead.") + Literal(0.0) + } else { + Literal(calendarInterval.microseconds.toDouble) + } + case DoubleType => + Multiply(lit, Literal(1000000.0)) + case _: NumericType => + Multiply(Cast(lit, DoubleType), Literal(1000000.0)) + case _: TimestampType => + Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) + } + Seq(negateIfNeeded(castedLit, negate)) + case a @ _ => + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") + invalid = true + Seq.empty + } + } + + val terms = collect(exprToCollectFrom, negate = false) + if (!invalid) terms else Seq.empty + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d1d705691b076..dee6fbe9d1514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -217,7 +218,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, _) => + case Join(left, right, joinType, condition) => joinType match { @@ -233,16 +234,52 @@ object UnsupportedOperationChecker { throwError("Full outer joins with streaming DataFrames/Datasets are not supported") } - 
case LeftOuter | LeftSemi | LeftAnti => + case LeftSemi | LeftAnti => if (right.isStreaming) { - throwError("Left outer/semi/anti joins with a streaming DataFrame/Dataset " + - "on the right is not supported") + throwError("Left semi/anti joins with a streaming DataFrame/Dataset " + + "on the right are not supported") } + // We support streaming left outer joins with static on the right always, and with + // stream on both sides under the appropriate conditions. + case LeftOuter => + if (!left.isStreaming && right.isStreaming) { + throwError("Left outer join with a streaming DataFrame/Dataset " + + "on the right and a static DataFrame/Dataset on the left is not supported") + } else if (left.isStreaming && right.isStreaming) { + val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + left.outputSet, right.outputSet, condition, Some(1000000)).isDefined + + if (!watermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } + } + + // We support streaming right outer joins with static on the left always, and with + // stream on both sides under the appropriate conditions. case RightOuter => - if (left.isStreaming) { - throwError("Right outer join with a streaming DataFrame/Dataset on the left is " + - "not supported") + if (left.isStreaming && !right.isStreaming) { + throwError("Right outer join with a streaming DataFrame/Dataset on the left and " + + "a static DataFrame/DataSet on the right not supported") + } else if (left.isStreaming && right.isStreaming) { + val isWatermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + // Check if the nullable side has a watermark, and there's a range condition which + // implies a state value watermark on the first side. + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + right.outputSet, left.outputSet, condition, Some(1000000)).isDefined + + if (!isWatermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } } case NaturalJoin(_) | UsingJoin(_, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala new file mode 100644 index 0000000000000..8cf41a02320d2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter, LeafNode, LocalRelation} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, TimestampType} + +class StreamingJoinHelperSuite extends AnalysisTest { + + test("extract watermark from time condition") { + val attributesToFindConstraintFor = Seq( + AttributeReference("leftTime", TimestampType)(), + AttributeReference("leftOther", IntegerType)()) + val metadataWithWatermark = new MetadataBuilder() + .putLong(EventTimeWatermark.delayKey, 1000) + .build() + val attributesWithWatermark = Seq( + AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), + AttributeReference("rightOther", IntegerType)()) + + case class DummyLeafNode() extends LeafNode { + override def output: Seq[Attribute] = + attributesToFindConstraintFor ++ attributesWithWatermark + } + + def watermarkFrom( + conditionStr: String, + rightWatermark: Option[Long] = Some(10000)): Option[Long] = { + val conditionExpr = Some(conditionStr).map { str => + val plan = + Filter( + CatalystSqlParser.parseExpression(str), + DummyLeafNode()) + val optimized = SimpleTestOptimizer.execute(SimpleAnalyzer.execute(plan)) + optimized.asInstanceOf[Filter].condition + } + StreamingJoinHelper.getStateValueWatermark( + AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), + conditionExpr, rightWatermark) + } + + // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, + // then cannot define constraint on leftTime. 
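A worked trace of the first assertion below, following the canonicalization steps in StreamingJoinHelper, may help (the arithmetic is added here for illustration and is not part of the patch):

    // watermarkFrom("leftTime > rightTime"), with a 10000 ms watermark on rightTime:
    //   canonical form:        rightTime + (-leftTime) < 0
    //   substitute watermark:  10000 * 1000 micros + (-leftTime) < 0
    //   solve for leftTime:    leftTime > 10,000,000 micros  =>  Some(10000) ms
    // The reversed condition "leftTime < rightTime" leaves +leftTime (no UnaryMinus) as the
    // constraint term, which the helper rejects, so it yields None.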
+ assert(watermarkFrom("leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) + assert(watermarkFrom("leftTime < rightTime") === None) + assert(watermarkFrom("leftTime <= rightTime") === None) + assert(watermarkFrom("rightTime > leftTime") === None) + assert(watermarkFrom("rightTime >= leftTime") === None) + assert(watermarkFrom("rightTime < leftTime") === Some(10000)) + assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) + + // Test type conversions + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) + + // Test with timestamp type + calendar interval on either side of equation + // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. + assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) + assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) + assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) + assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) + assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") + === Some(12000)) + + // Test with casted long type + constants on either side of equation + // Note: long type and constants commute, so more combinations to test. 
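One detail worth spelling out for the constant cases that follow: a bare numeric literal in these conditions is treated as seconds, since collectTerms multiplies numeric literals by 1,000,000 to express them in microseconds. A trace of the first assertion below, added here for illustration:

    // CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1
    //   terms after canonicalization:  rightTime, 1 * 1,000,000 (the literal 1, read as seconds), -leftTime
    //   substitute the 10000 ms watermark:  10,000,000 + 1,000,000 < leftTime (in micros)
    //   result:  Some(11000) ms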
+ // -- Constants on the right + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") + === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") + === Some(10100)) + assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") + === Some(10200)) + // -- Constants on the left + assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) + assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) + assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") + === Some(7000)) + assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") + === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") + === Some(10100)) + // -- Constants on both sides, mixed types + assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") + === Some(13000)) + + // Test multiple conditions, should return minimum watermark + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === + Some(7000)) // first condition wins + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === + Some(6000)) // second condition wins + + // Test invalid comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes + assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed + + // Test static comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) + + // Test non-positive results + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 10") === Some(0)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 100") === Some(-90000)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 11f48a39c1e25..e5057c451d5b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import 
org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -417,9 +418,57 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBinaryOperationInStreamingPlan( "left outer join", _.join(_, joinType = LeftOuter), - streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + streamStreamSupported = false, + expectedMsg = "outer join") + + // Left outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Left outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute > rightTimeWithWatermark + 10)) + }, + OutputMode.Append()) + + // Left outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute < rightTimeWithWatermark + 10)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Left semi joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -427,7 +476,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftSemi), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Left anti joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -435,14 +484,63 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftAnti), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Right outer joins: stream-* not allowed 
testBinaryOperationInStreamingPlan( "right outer join", _.join(_, joinType = RightOuter), + streamBatchSupported = false, streamStreamSupported = false, - streamBatchSupported = false) + expectedMsg = "outer join") + + // Right outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Right outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 < attribute)) + }, + OutputMode.Append()) + + // Right outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 > attribute)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Cogroup: only batch-batch is allowed testBinaryOperationInStreamingPlan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 44f1fa58599d2..9bd2127a28ff6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import 
org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -146,7 +146,14 @@ case class StreamingSymmetricHashJoinExec( stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } - require(joinType == Inner, s"${getClass.getSimpleName} should not take $joinType as the JoinType") + private def throwBadJoinTypeException(): Nothing = { + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $joinType as the JoinType") + } + + require( + joinType == Inner || joinType == LeftOuter || joinType == RightOuter, + s"${getClass.getSimpleName} should not take $joinType as the JoinType") require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) private val storeConf = new StateStoreConf(sqlContext.conf) @@ -157,11 +164,18 @@ case class StreamingSymmetricHashJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = joinType match { + case _: InnerLike => left.output ++ right.output + case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case _ => throwBadJoinTypeException() + } override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning)) + case RightOuter => PartitioningCollection(Seq(right.outputPartitioning)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -207,31 +221,108 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(inputRow).withRight(matchedRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(matchedRow).withRight(inputRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input) } // Filter the joined rows based on the given condition. - val outputFilterFunction = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _ - val filteredOutputIter = - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row => - numOutputRows += 1 - row - } + val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ + + // We need to save the time that the inner join output iterator completes, since outer join + // output counts as both update and removal time. 
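For orientation, the user-level query shape that reaches this new outer-join path looks roughly like the sketch below; the stream, column, and interval names are hypothetical, and the essential ingredients are a watermark on the nullable side plus a time-range join condition:

    import org.apache.spark.sql.functions.expr

    val impressions = impressionsStream.withWatermark("impressionTime", "2 hours")
    val clicks = clicksStream.withWatermark("clickTime", "3 hours")

    val joined = impressions.join(
      clicks,
      expr("clickAdId = impressionAdId AND " +
        "clickTime >= impressionTime AND clickTime <= impressionTime + interval 1 hour"),
      "leftOuter")

Unmatched impressions are emitted with null click columns once the watermark guarantees no match can still arrive, which is what the removeOldState plus null-join logic below implements.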
+ var innerOutputCompletionTimeNs: Long = 0 + def onInnerOutputCompletion = { + innerOutputCompletionTimeNs = System.nanoTime + } + val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion) + + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists( + rightValue => { + outputFilterFunction( + joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + }) + } + + def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists( + leftValue => { + outputFilterFunction( + joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + }) + } + + val outputIter: Iterator[InternalRow] = joinType match { + case Inner => + filteredInnerOutputIter + case LeftOuter => + // We generate the outer join input by: + // * Getting an iterator over the rows that have aged out on the left side. These rows are + // candidates for being null joined. Note that to avoid doing two passes, this iterator + // removes the rows from the state manager as they're processed. + // * Checking whether the current row matches a key in the right side state, and that key + // has any value which satisfies the filter function when joined. If it doesn't, + // we know we can join with null, since there was never (including this batch) a match + // within the watermark period. If it does, there must have been a match at some point, so + // we know we can't join with null. + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + val removedRowIter = leftSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithRightSideState(pair)) + .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) + + filteredInnerOutputIter ++ outerOutputIter + case RightOuter => + // See comments for left outer case. + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val removedRowIter = rightSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithLeftSideState(pair)) + .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) + + filteredInnerOutputIter ++ outerOutputIter + case _ => throwBadJoinTypeException() + } + + val outputIterWithMetrics = outputIter.map { row => + numOutputRows += 1 + row + } // Function to remove old state after all the input has been consumed and output generated def onOutputCompletion = { + // All processing time counts as update time. allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0) - // Remove old state if needed + // Processing time between inner output completion and here comes from the outer portion of a + // join, and thus counts as removal time as we remove old state from one side while iterating. + if (innerOutputCompletionTimeNs != 0) { + allRemovalsTimeMs += + math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0) + } + allRemovalsTimeMs += timeTakenMs { - leftSideJoiner.removeOldState() - rightSideJoiner.removeOldState() + // Remove any remaining state rows which aren't needed because they're below the watermark. + // + // For inner joins, we have to remove unnecessary state rows from both sides if possible. 
+ // For outer joins, we have already removed unnecessary state rows from the outer side + // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we + // have to remove unnecessary state rows from the other side (e.g., right side for the left + // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal + // needs to be done greedily by immediately consuming the returned iterator. + val cleanupIter = joinType match { + case Inner => + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case LeftOuter => rightSideJoiner.removeOldState() + case RightOuter => leftSideJoiner.removeOldState() + case _ => throwBadJoinTypeException() + } + while (cleanupIter.hasNext) { + cleanupIter.next() + } } // Commit all state changes and update state store metrics @@ -251,7 +342,8 @@ case class StreamingSymmetricHashJoinExec( } } - CompletionIterator[InternalRow, Iterator[InternalRow]](filteredOutputIter, onOutputCompletion) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterWithMetrics, onOutputCompletion) } /** @@ -324,14 +416,32 @@ case class StreamingSymmetricHashJoinExec( } } - /** Remove old buffered state rows using watermarks for state keys and values */ - def removeOldState(): Unit = { + /** + * Get an iterator over the values stored in this joiner's state manager for the given key. + * + * Should not be interleaved with mutations. + */ + def get(key: UnsafeRow): Iterator[UnsafeRow] = { + joinStateManager.get(key) + } + + /** + * Builds an iterator over old state key-value pairs, removing them lazily as they're produced. + * + * @note This iterator must be consumed fully before any other operations are made + * against this joiner's join state manager. For efficiency reasons, the intermediate states of + * the iterator leave the state manager in an undefined state. + * + * We do this to avoid requiring either two passes or full materialization when + * processing the rows for outer join. 
+ */ + def removeOldState(): Iterator[UnsafeRowPair] = { stateWatermarkPredicate match { case Some(JoinStateKeyWatermarkPredicate(expr)) => joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc) case Some(JoinStateValueWatermarkPredicate(expr)) => joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc) - case _ => + case _ => Iterator.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index e50274a1baba1..64c7189f72ac3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression @@ -34,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval /** * Helper object for [[StreamingSymmetricHashJoinExec]]. See that object for more details. */ -object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { +object StreamingSymmetricHashJoinHelper extends Logging { sealed trait JoinSide case object LeftSide extends JoinSide { override def toString(): String = "left" } @@ -111,7 +112,7 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { expr.map(JoinStateKeyWatermarkPredicate.apply _) } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs - val stateValueWatermark = getStateValueWatermark( + val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark( attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, @@ -132,242 +133,6 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate) } - /** - * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) - * given the join condition and the event time watermark. This is how it works. - * - The condition is split into conjunctive predicates, and we find the predicates of the - * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). - * - We canoncalize the predicate and solve it with the event time watermark value to find the - * value of the state watermark. - * This function is supposed to make best-effort attempt to get the state watermark. If there is - * any error, it will return None. 
- * - * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark - * is to be calculated - * @param attributesWithEventWatermark attributes of the other side which has a watermark column - * @param joinCondition join condition - * @param eventWatermark watermark defined on the input event data - * @return state value watermark in milliseconds, is possible. - */ - def getStateValueWatermark( - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - joinCondition: Option[Expression], - eventWatermark: Option[Long]): Option[Long] = { - - // If condition or event time watermark is not provided, then cannot calculate state watermark - if (joinCondition.isEmpty || eventWatermark.isEmpty) return None - - // If there is not watermark attribute, then cannot define state watermark - if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None - - def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { - try { - getStateWatermarkFromLessThenPredicate( - l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) - } catch { - case NonFatal(e) => - logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) - None - } - } - - val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => - - // The generated the state watermark cleanup expression is inclusive of the state watermark. - // If state watermark is W, all state where timestamp <= W will be cleaned up. - // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean - // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. - val stateWatermark = predicate match { - case LessThan(l, r) => getStateWatermarkSafely(l, r) - case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) - case GreaterThan(l, r) => getStateWatermarkSafely(r, l) - case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) - case _ => None - } - if (stateWatermark.nonEmpty) { - logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") - } - stateWatermark - } - allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) - } - - /** - * Extract the state value watermark (milliseconds) from the condition - * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for - * leftTime using the watermark on the rightTime. 
Example: - * - * Input: rightTime-with-watermark + c1 < leftTime + c2 - * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 - * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime - * With watermark value: watermark-value + c1 + (-c2) < leftTime - */ - private def getStateWatermarkFromLessThenPredicate( - leftExpr: Expression, - rightExpr: Expression, - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - eventWatermark: Option[Long]): Option[Long] = { - - val attributesInCondition = AttributeSet( - leftExpr.collect { case a: AttributeReference => a } ++ - rightExpr.collect { case a: AttributeReference => a } - ) - if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || - attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { - // If more than attributes present in condition from one side, then it cannot be solved - return None - } - - def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { - e.collectLeaves().collectFirst { - case a @ AttributeReference(_, TimestampType, _, _) - if attributesToFindStateWatermarkFor.contains(a) => a - }.nonEmpty - } - - // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 - val allOnLeftExpr = Subtract(leftExpr, rightExpr) - logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") - - // Canonicalization step 2: extract commutative terms - // rightTime-with-watermark, c1, -leftTime, -c2 - val terms = ExpressionSet(collectTerms(allOnLeftExpr)) - logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) - - - - // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor - val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) - - // Verify there is only one correct constraint term and of the correct type - if (constraintTerms.size > 1) { - logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - if (constraintTerms.isEmpty) { - logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - val constraintTerm = constraintTerms.head - if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { - // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` - // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. - // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark - // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about - // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, - // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. - return None - } - - // Replace watermark attribute with watermark value, and generate the resolved expression - // from the other terms. 
That is, - // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) - logDebug(s"Constraint term from join condition:\t$constraintTerm") - val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => - term.transform { - case a @ AttributeReference(_, TimestampType, _, metadata) - if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => - Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) - } - }.reduceLeft(Add) - - // Calculate the constraint value - logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") - val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] - Some((Double2double(constraintValue) / 1000.0).toLong) - } - - /** - * Collect all the terms present in an expression after converting it into the form - * a + b + c + d where each term be either an attribute or a literal casted to long, - * optionally wrapped in a unary minus. - */ - private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { - var invalid = false - - /** Wrap a term with UnaryMinus if its needs to be negated. */ - def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { - if (minus) UnaryMinus(expr) else expr - } - - /** - * Recursively split the expression into its leaf terms contains attributes or literals. - * Returns terms only of the forms: - * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), - * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) - * Multiply(Literal), UnaryMinus(Multiply(Literal)) - * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) - * - * Note: - * - If term needs to be negated for making it a commutative term, - * then it will be wrapped in UnaryMinus(...) - * - Each terms will be representing timestamp value or time interval in microseconds, - * typed as doubles. - */ - def collect(expr: Expression, negate: Boolean): Seq[Expression] = { - expr match { - case Add(left, right) => - collect(left, negate) ++ collect(right, negate) - case Subtract(left, right) => - collect(left, negate) ++ collect(right, !negate) - case TimeAdd(left, right, _) => - collect(left, negate) ++ collect(right, negate) - case TimeSub(left, right, _) => - collect(left, negate) ++ collect(right, !negate) - case UnaryMinus(child) => - collect(child, !negate) - case CheckOverflow(child, _) => - collect(child, negate) - case Cast(child, dataType, _) => - dataType match { - case _: NumericType | _: TimestampType => collect(child, negate) - case _ => - invalid = true - Seq.empty - } - case a: AttributeReference => - val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a - Seq(negateIfNeeded(castedRef, negate)) - case lit: Literal => - // If literal of type calendar interval, then explicitly convert to millis - // Convert other number like literal to doubles representing millis (by x1000) - val castedLit = lit.dataType match { - case CalendarIntervalType => - val calendarInterval = lit.value.asInstanceOf[CalendarInterval] - if (calendarInterval.months > 0) { - invalid = true - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom " + - s"as imprecise intervals like months and years cannot be used for" + - s"watermark calculation. 
Use interval in terms of day instead.") - Literal(0.0) - } else { - Literal(calendarInterval.microseconds.toDouble) - } - case DoubleType => - Multiply(lit, Literal(1000000.0)) - case _: NumericType => - Multiply(Cast(lit, DoubleType), Literal(1000000.0)) - case _: TimestampType => - Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) - } - Seq(negateIfNeeded(castedLit, negate)) - case a @ _ => - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") - invalid = true - Seq.empty - } - } - - val terms = collect(exprToCollectFrom, negate = false) - if (!invalid) terms else Seq.empty - } - /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 37648710dfc2a..d256fb578d921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -76,7 +76,7 @@ class SymmetricHashJoinStateManager( /** Get all the values of a key */ def get(key: UnsafeRow): Iterator[UnsafeRow] = { val numValues = keyToNumValues.get(key) - keyWithIndexToValue.getAll(key, numValues) + keyWithIndexToValue.getAll(key, numValues).map(_.value) } /** Append a new value to the key */ @@ -87,70 +87,163 @@ class SymmetricHashJoinStateManager( } /** - * Remove using a predicate on keys. See class docs for more context and implement details. + * Remove using a predicate on keys. + * + * This produces an iterator over the (key, value) pairs satisfying condition(key), where the + * underlying store is updated as a side-effect of producing next. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. */ - def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator - - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - if (condition(keyToNumValue.key)) { - keyToNumValues.remove(keyToNumValue.key) - keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue) + def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKeyToNumValue: KeyAndNumValues = null + private var currentValues: Iterator[KeyWithIndexAndValue] = null + + private def currentKey = currentKeyToNumValue.key + + private val reusedPair = new UnsafeRowPair() + + private def getAndRemoveValue() = { + val keyWithIndexAndValue = currentValues.next() + keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex) + reusedPair.withRows(currentKey, keyWithIndexAndValue.value) + } + + override def getNext(): UnsafeRowPair = { + // If there are more values for the current key, remove and return the next one. 
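Because removal now happens as a side effect of iteration, a caller interested only in the cleanup has to drain the returned iterator; this is the pattern the updated SymmetricHashJoinStateManagerSuite at the end of this patch uses, sketched here with hypothetical manager and predicate values:

    // Drain the iterator purely for its state-mutation side effect.
    val removed = manager.removeByKeyCondition(predicate)
    while (removed.hasNext) removed.next()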
+ if (currentValues != null && currentValues.hasNext) { + return getAndRemoveValue() + } + + // If there weren't any values left, try and find the next key that satisfies the removal + // condition and has values. + while (allKeyToNumValues.hasNext) { + currentKeyToNumValue = allKeyToNumValues.next() + if (removalCondition(currentKey)) { + currentValues = keyWithIndexToValue.getAll( + currentKey, currentKeyToNumValue.numValue) + keyToNumValues.remove(currentKey) + + if (currentValues.hasNext) { + return getAndRemoveValue() + } + } + } + + // We only reach here if there were no satisfying keys left, which means we're done. + finished = true + return null } + + override def close: Unit = {} } } /** - * Remove using a predicate on values. See class docs for more context and implementation details. + * Remove using a predicate on values. + * + * At a high level, this produces an iterator over the (key, value) pairs such that value + * satisfies the predicate, where producing an element removes the value from the state store + * and producing all elements with a given key updates it accordingly. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. */ - def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator + def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - val key = keyToNumValue.key + // Reuse this object to avoid creation+GC overhead. + private val reusedPair = new UnsafeRowPair() - var numValues: Long = keyToNumValue.numValue - var index: Long = 0L - var valueRemoved: Boolean = false - var valueForIndex: UnsafeRow = null + private val allKeyToNumValues = keyToNumValues.iterator - while (index < numValues) { - if (valueForIndex == null) { - valueForIndex = keyWithIndexToValue.get(key, index) + private var currentKey: UnsafeRow = null + private var numValues: Long = 0L + private var index: Long = 0L + private var valueRemoved: Boolean = false + + // Push the data for the current key to the numValues store, and reset the tracking variables + // to their empty state. + private def updateNumValueForCurrentKey(): Unit = { + if (valueRemoved) { + if (numValues >= 1) { + keyToNumValues.put(currentKey, numValues) + } else { + keyToNumValues.remove(currentKey) + } } - if (condition(valueForIndex)) { - if (numValues > 1) { - val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1) - keyWithIndexToValue.put(key, index, valueAtMaxIndex) - keyWithIndexToValue.remove(key, numValues - 1) - valueForIndex = valueAtMaxIndex + + currentKey = null + numValues = 0 + index = 0 + valueRemoved = false + } + + // Find the next value satisfying the condition, updating `currentKey` and `numValues` if + // needed. Returns null when no value can be found. + private def findNextValueForIndex(): UnsafeRow = { + // Loop across all values for the current key, and then all other keys, until we find a + // value satisfying the removal condition. + def hasMoreValuesForCurrentKey = currentKey != null && index < numValues + def hasMoreKeys = allKeyToNumValues.hasNext + while (hasMoreValuesForCurrentKey || hasMoreKeys) { + if (hasMoreValuesForCurrentKey) { + // First search the values for the current key. 
+ val currentValue = keyWithIndexToValue.get(currentKey, index) + if (removalCondition(currentValue)) { + return currentValue + } else { + index += 1 + } + } else if (hasMoreKeys) { + // If we can't find a value for the current key, cleanup and start looking at the next. + // This will also happen the first time the iterator is called. + updateNumValueForCurrentKey() + + val currentKeyToNumValue = allKeyToNumValues.next() + currentKey = currentKeyToNumValue.key + numValues = currentKeyToNumValue.numValue } else { - keyWithIndexToValue.remove(key, 0) - valueForIndex = null + // Should be unreachable, but in any case means a value couldn't be found. + return null } - numValues -= 1 - valueRemoved = true - } else { - valueForIndex = null - index += 1 } + + // We tried and failed to find the next value. + return null } - if (valueRemoved) { - if (numValues >= 1) { - keyToNumValues.put(key, numValues) + + override def getNext(): UnsafeRowPair = { + val currentValue = findNextValueForIndex() + + // If there's no value, clean up and finish. There aren't any more available. + if (currentValue == null) { + updateNumValueForCurrentKey() + finished = true + return null + } + + // The backing store is arraylike - we as the caller are responsible for filling back in + // any hole. So we swap the last element into the hole and decrement numValues to shorten. + // clean + if (numValues > 1) { + val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) + keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex) + keyWithIndexToValue.remove(currentKey, numValues - 1) } else { - keyToNumValues.remove(key) + keyWithIndexToValue.remove(currentKey, 0) } + numValues -= 1 + valueRemoved = true + + return reusedPair.withRows(currentKey, currentValue) } - } - } - def iterator(): Iterator[UnsafeRowPair] = { - val pair = new UnsafeRowPair() - keyWithIndexToValue.iterator.map { x => - pair.withRows(x.key, x.value) + override def close: Unit = {} } } @@ -309,19 +402,24 @@ class SymmetricHashJoinStateManager( stateStore.get(keyWithIndexRow(key, valueIndex)) } - /** Get all the values for key and all indices. */ - def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = { + /** + * Get all values and indices for the provided key. + * Should not return null. 
+ */ + def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { + val keyWithIndexAndValue = new KeyWithIndexAndValue() var index = 0 - new NextIterator[UnsafeRow] { - override protected def getNext(): UnsafeRow = { + new NextIterator[KeyWithIndexAndValue] { + override protected def getNext(): KeyWithIndexAndValue = { if (index >= numValues) { finished = true null } else { val keyWithIndex = keyWithIndexRow(key, index) val value = stateStore.get(keyWithIndex) + keyWithIndexAndValue.withNew(key, index, value) index += 1 - value + keyWithIndexAndValue } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index ffa4c3c22a194..d44af1d14c27a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -137,14 +137,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter BoundReference( 1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable), Literal(threshold)) - manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + val iter = manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + while (iter.hasNext) iter.next() } /** Remove values where `time <= threshold` */ def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) - manager.removeByValueCondition( + val iter = manager.removeByValueCondition( GeneratePredicate.generate(expr, inputValueAttribs).eval _) + while (iter.hasNext) iter.next() } def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 533e1165fd59c..a6593b71e51de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -24,8 +24,9 @@ import scala.util.Random import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} @@ -35,7 +36,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { +class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -322,111 +323,6 @@ class 
StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } - testQuietly("extract watermark from time condition") { - val attributesToFindConstraintFor = Seq( - AttributeReference("leftTime", TimestampType)(), - AttributeReference("leftOther", IntegerType)()) - val metadataWithWatermark = new MetadataBuilder() - .putLong(EventTimeWatermark.delayKey, 1000) - .build() - val attributesWithWatermark = Seq( - AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), - AttributeReference("rightOther", IntegerType)()) - - def watermarkFrom( - conditionStr: String, - rightWatermark: Option[Long] = Some(10000)): Option[Long] = { - val conditionExpr = Some(conditionStr).map { str => - val plan = - Filter( - spark.sessionState.sqlParser.parseExpression(str), - LogicalRDD( - attributesToFindConstraintFor ++ attributesWithWatermark, - spark.sparkContext.emptyRDD)(spark)) - plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition - } - StreamingSymmetricHashJoinHelper.getStateValueWatermark( - AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), - conditionExpr, rightWatermark) - } - - // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, - // then cannot define constraint on leftTime. - assert(watermarkFrom("leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) - assert(watermarkFrom("leftTime < rightTime") === None) - assert(watermarkFrom("leftTime <= rightTime") === None) - assert(watermarkFrom("rightTime > leftTime") === None) - assert(watermarkFrom("rightTime >= leftTime") === None) - assert(watermarkFrom("rightTime < leftTime") === Some(10000)) - assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) - - // Test type conversions - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) - - // Test with timestamp type + calendar interval on either side of equation - // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. - assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) - assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) - assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) - assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) - assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") - === Some(12000)) - - // Test with casted long type + constants on either side of equation - // Note: long type and constants commute, so more combinations to test. 
- // -- Constants on the right - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") - === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") - === Some(10100)) - assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") - === Some(10200)) - // -- Constants on the left - assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) - assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) - assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") - === Some(7000)) - assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") - === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") - === Some(10100)) - // -- Constants on both sides, mixed types - assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") - === Some(13000)) - - // Test multiple conditions, should return minimum watermark - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === - Some(7000)) // first condition wins - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === - Some(6000)) // second condition wins - - // Test invalid comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes - assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed - - // Test static comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) - } - test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ @@ -470,3 +366,189 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo } } } + +class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + import testImplicits._ + import org.apache.spark.sql.functions._ + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = { + val input = MemoryStream[Int] + val df = input.toDF + .select( + 'value as "key", + 'value.cast("timestamp") as s"${prefix}Time", + ('value * multiplier) as 
s"${prefix}Value") + .withWatermark(s"${prefix}Time", "10 seconds") + + return (input, df) + } + + private def setupWindowedJoin(joinType: String): + (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val (input1, df1) = setupStream("left", 2) + val (input2, df2) = setupStream("right", 3) + val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue) + val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + (input1, input2, joined) + } + + test("windowed left outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + test("windowed right outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. 
+ AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + Seq( + ("left_outer", Row(3, null, 5, null)), + ("right_outer", Row(null, 2, null, 5)) + ).foreach { case (joinType: String, outerResult) => + test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") { + import org.apache.spark.sql.functions._ + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = + df1.join( + df2, + expr("leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"), + joinType) + .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + testStream(joined)( + AddData(leftInput, (1, 5), (3, 5)), + CheckAnswer(), + AddData(rightInput, (1, 10), (2, 5)), + CheckLastBatch((1, 1, 5, 10)), + AddData(rightInput, (1, 11)), + CheckLastBatch(), // no match as left time is too low + assertNumStateRows(total = 5, updated = 1), + + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 7), (1, 30)), + CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + assertNumStateRows(total = 7, updated = 2), + AddData(rightInput, (0, 30)), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 1), + AddData(rightInput, (0, 30)), + CheckLastBatch(outerResult), + assertNumStateRows(total = 3, updated = 1) + ) + } + } + + // When the join condition isn't true, the outer null rows must be generated, even if the join + // keys themselves have a match. 
+ test("left outer join with non-key condition violated on left") { + val (leftInput, simpleLeftDf) = setupStream("left", 2) + val (rightInput, simpleRightDf) = setupStream("right", 3) + + val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue) + val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue) + + val joined = left.join( + right, + left("key") === right("key") && left("window") === right("window") && + 'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000), + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + // leftValue <= 10 should generate outer join rows even though it matches right keys + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 1, 2, 3), + CheckLastBatch(), + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 20), + CheckLastBatch( + Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows + AddData(leftInput, 40, 41), + AddData(rightInput, 40, 41), + CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), + AddData(leftInput, 70), + AddData(rightInput, 71), + CheckLastBatch(), + assertNumStateRows(total = 6, updated = 2), + AddData(rightInput, 70), + CheckLastBatch((70, 80, 140, 210)), + assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left + AddData(leftInput, 101, 102, 103), + AddData(rightInput, 101, 102, 103), + CheckLastBatch(), + AddData(leftInput, 1000), + AddData(rightInput, 1001), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 1000), + CheckLastBatch( + Row(1000, 1010, 2000, 3000), + Row(101, 110, 202, null), + Row(102, 110, 204, null), + Row(103, 110, 206, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } +} + From d54670192a6acd892d13b511dfb62390be6ad39c Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 4 Oct 2017 07:11:00 +0100 Subject: [PATCH 013/712] [SPARK-22193][SQL] Minor typo fix ## What changes were proposed in this pull request? [SPARK-22193][SQL] Minor typo fix in SortMergeJoinExec. Nothing major, but it bothered me going into.Hence fixing ## How was this patch tested? existing tests Author: Rekha Joshi Author: rjoshi2 Closes #19422 from rekhajoshm/SPARK-22193. 
--- .../spark/sql/execution/joins/SortMergeJoinExec.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 14de2dc23e3c0..4e02803552e82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -402,7 +402,7 @@ case class SortMergeJoinExec( } } - private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { + private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) => s""" |if (comp == 0) { @@ -463,7 +463,7 @@ case class SortMergeJoinExec( | continue; | } | if (!$matches.isEmpty()) { - | ${genComparision(ctx, leftKeyVars, matchedKeyVars)} + | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { | return true; | } @@ -484,7 +484,7 @@ case class SortMergeJoinExec( | } | ${rightKeyVars.map(_.code).mkString("\n")} | } - | ${genComparision(ctx, leftKeyVars, rightKeyVars)} + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} | if (comp > 0) { | $rightRow = null; | } else if (comp < 0) { From 64df08b64779bab629a8a90a3797d8bd70f61703 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Oct 2017 15:06:44 +0800 Subject: [PATCH 014/712] [SPARK-20783][SQL] Create ColumnVector to abstract existing compressed column (batch method) ## What changes were proposed in this pull request? This PR abstracts data compressed by `CompressibleColumnAccessor` using `ColumnVector` in batch method. When `ColumnAccessor.decompress` is called, `ColumnVector` will have uncompressed data. This batch decompress does not use `InternalRow` to reduce the number of memory accesses. As first step of this implementation, this JIRA supports primitive data types. Another PR will support array and other data types. This implementation decompress data in batch into uncompressed column batch, as rxin suggested at [here](https://github.com/apache/spark/pull/18468#issuecomment-316914076). Another implementation uses adapter approach [as cloud-fan suggested](https://github.com/apache/spark/pull/18468). ## How was this patch tested? Added test suites Author: Kazuaki Ishizaki Closes #18704 from kiszk/SPARK-20783a. 
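As an illustrative sketch only (not part of the diff that follows), the new batch path added here might be driven roughly as below for a primitive column; the helper name `decompressIntColumn`, the `IntegerType` column, and the row count are placeholder assumptions:

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.execution.columnar.ColumnAccessor
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.IntegerType

// Hypothetical caller (it would need to live under org.apache.spark.sql to see the
// private[sql] ColumnAccessor object): decode a whole compressed column in one pass
// instead of extracting it row by row through InternalRow.
def decompressIntColumn(compressed: ByteBuffer, numRows: Int): OnHeapColumnVector = {
  val accessor = ColumnAccessor(IntegerType, compressed)     // builds a NativeColumnAccessor
  val vector = new OnHeapColumnVector(numRows, IntegerType)  // uncompressed destination
  ColumnAccessor.decompress(accessor, vector, numRows)       // batch API added in this PR
  vector
}
```

Each compression scheme's `Decoder.decompress` fills the vector (including its null slots) directly from the compressed buffer, which is what removes the per-row `InternalRow` accesses mentioned above.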
--- .../execution/columnar/ColumnDictionary.java | 58 +++ .../vectorized/OffHeapColumnVector.java | 18 + .../vectorized/OnHeapColumnVector.java | 18 + .../vectorized/WritableColumnVector.java | 76 ++-- .../execution/columnar/ColumnAccessor.scala | 16 +- .../sql/execution/columnar/ColumnType.scala | 33 ++ .../CompressibleColumnAccessor.scala | 4 + .../compression/CompressionScheme.scala | 3 + .../compression/compressionSchemes.scala | 340 +++++++++++++++++- .../compression/BooleanBitSetSuite.scala | 52 +++ .../compression/DictionaryEncodingSuite.scala | 72 +++- .../compression/IntegralDeltaSuite.scala | 72 ++++ .../PassThroughEncodingSuite.scala | 189 ++++++++++ .../compression/RunLengthEncodingSuite.scala | 89 ++++- .../TestCompressibleColumnBuilder.scala | 9 +- .../vectorized/ColumnVectorSuite.scala | 183 +++++++++- .../vectorized/ColumnarBatchSuite.scala | 4 +- 17 files changed, 1192 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java new file mode 100644 index 0000000000000..f1785853a94ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.columnar; + +import org.apache.spark.sql.execution.vectorized.Dictionary; + +public final class ColumnDictionary implements Dictionary { + private int[] intDictionary; + private long[] longDictionary; + + public ColumnDictionary(int[] dictionary) { + this.intDictionary = dictionary; + } + + public ColumnDictionary(long[] dictionary) { + this.longDictionary = dictionary; + } + + @Override + public int decodeToInt(int id) { + return intDictionary[id]; + } + + @Override + public long decodeToLong(int id) { + return longDictionary[id]; + } + + @Override + public float decodeToFloat(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support float"); + } + + @Override + public double decodeToDouble(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support double"); + } + + @Override + public byte[] decodeToBinary(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support String"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 8cbc895506d91..a7522ebf5821a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -228,6 +228,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { null, data + 2 * rowId, count * 2); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 2, count * 2); + } + @Override public short getShort(int rowId) { if (dictionary == null) { @@ -268,6 +274,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { null, data + 4 * rowId, count * 4); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { @@ -334,6 +346,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { null, data + 8 * rowId, count * 8); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 2725a29eeabe8..166a39e0fabd9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -233,6 +233,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { System.arraycopy(src, srcIndex, shortData, rowId, count); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, + Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + } + 
@Override public short getShort(int rowId) { if (dictionary == null) { @@ -272,6 +278,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { System.arraycopy(src, srcIndex, intData, rowId, count); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, + Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; @@ -332,6 +344,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { System.arraycopy(src, srcIndex, longData, rowId, count); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, + Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 163f2511e5f73..da72954ddc448 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -113,138 +113,156 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { protected abstract void reserveInternal(int capacity); /** - * Sets the value at rowId to null/not null. + * Sets null/not null to the value at rowId. */ public abstract void putNotNull(int rowId); public abstract void putNull(int rowId); /** - * Sets the values from [rowId, rowId + count) to null/not null. + * Sets null/not null to the values at [rowId, rowId + count). */ public abstract void putNulls(int rowId, int count); public abstract void putNotNulls(int rowId, int count); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putBoolean(int rowId, boolean value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBooleans(int rowId, int count, boolean value); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putByte(int rowId, byte value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBytes(int rowId, int count, byte value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putShort(int rowId, short value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). 
*/ public abstract void putShorts(int rowId, int count, short value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets values from [src[srcIndex], src[srcIndex + count * 2]) to [rowId, rowId + count) + * The data in src must be 2-byte platform native endian shorts. + */ + public abstract void putShorts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets `value` to the value at rowId. */ public abstract void putInt(int rowId, int value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putInts(int rowId, int count, int value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putInts(int rowId, int count, int[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be 4-byte platform native endian ints. + */ + public abstract void putInts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) * The data in src must be 4-byte little endian ints. */ public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putLong(int rowId, long value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putLongs(int rowId, int count, long value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be 8-byte platform native endian longs. + */ + public abstract void putLongs(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src + srcIndex, src + srcIndex + count * 8) to [rowId, rowId + count) * The data in src must be 8-byte little endian longs. */ public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putFloat(int rowId, float value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). 
*/ public abstract void putFloats(int rowId, int count, float value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted floats. + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be ieee formatted floats in platform native endian. */ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putDouble(int rowId, double value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putDoubles(int rowId, int count, double value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted doubles. + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be ieee formatted doubles in platform native endian. */ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -254,7 +272,7 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract void putArray(int rowId, int offset, int length); /** - * Sets the value at rowId to `value`. + * Sets values from [value + offset, value + offset + count) to the values at rowId. 
*/ public abstract int putByteArray(int rowId, byte[] value, int offset, int count); public final int putByteArray(int rowId, byte[] value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 6241b79d9affc..24c8ac81420cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -24,6 +24,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ /** @@ -62,6 +63,9 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( } protected def underlyingBuffer = buffer + + def getByteBuffer: ByteBuffer = + buffer.duplicate.order(ByteOrder.nativeOrder()) } private[columnar] class NullColumnAccessor(buffer: ByteBuffer) @@ -122,7 +126,7 @@ private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[columnar] object ColumnAccessor { +private[sql] object ColumnAccessor { @tailrec def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) @@ -149,4 +153,14 @@ private[columnar] object ColumnAccessor { throw new Exception(s"not support type: $other") } } + + def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector, numRows: Int): + Unit = { + if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) { + val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]] + nativeAccessor.decompress(columnVector, numRows) + } else { + throw new RuntimeException("Not support non-primitive type now") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 5cfb003e4f150..e9b150fd86095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -43,6 +43,12 @@ import org.apache.spark.unsafe.types.UTF8String * WARNING: This only works with HeapByteBuffer */ private[columnar] object ByteBufferHelper { + def getShort(buffer: ByteBuffer): Short = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.getShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -66,6 +72,33 @@ private[columnar] object ByteBufferHelper { buffer.position(pos + 8) Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) } + + def putShort(buffer: ByteBuffer, value: Short): Unit = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.putShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putInt(buffer: ByteBuffer, value: Int): Unit = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.putInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putLong(buffer: ByteBuffer, value: Long): Unit = 
{ + val pos = buffer.position() + buffer.position(pos + 8) + Platform.putLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def copyMemory(src: ByteBuffer, dst: ByteBuffer, len: Int): Unit = { + val srcPos = src.position() + val dstPos = dst.position() + src.position(srcPos + len) + dst.position(dstPos + len) + Platform.copyMemory(src.array(), Platform.BYTE_ARRAY_OFFSET + srcPos, + dst.array(), Platform.BYTE_ARRAY_OFFSET + dstPos, len) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index e1d13ad0e94e5..774011f1e3de8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { @@ -36,4 +37,7 @@ private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends Colu override def extractSingle(row: InternalRow, ordinal: Int): Unit = { decoder.next(row, ordinal) } + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = + decoder.decompress(columnVector, capacity) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 6e4f1c5b80684..f8aeba44257d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait Encoder[T <: AtomicType] { @@ -41,6 +42,8 @@ private[columnar] trait Decoder[T <: AtomicType] { def next(row: InternalRow, ordinal: Int): Unit def hasNext: Boolean + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit } private[columnar] trait CompressionScheme { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index ee99c90a751d9..bf00ad997c76e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer +import java.nio.ByteOrder import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow 
import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ @@ -61,6 +63,101 @@ private[columnar] case object PassThrough extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + private def putBooleans( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + for (i <- 0 until len) { + columnVector.putBoolean(pos + i, (buffer.get(bufferPos + i) != 0)) + } + } + + private def putBytes( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putBytes(pos, len, buffer.array, bufferPos) + } + + private def putShorts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putShorts(pos, len, buffer.array, bufferPos) + } + + private def putInts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putInts(pos, len, buffer.array, bufferPos) + } + + private def putLongs( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putLongs(pos, len, buffer.array, bufferPos) + } + + private def putFloats( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putFloats(pos, len, buffer.array, bufferPos) + } + + private def putDoubles( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putDoubles(pos, len, buffer.array, bufferPos) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + unitSize: Int, + putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity + var pos = 0 + var seenNulls = 0 + var bufferPos = buffer.position + while (pos < capacity) { + if (pos != nextNullIndex) { + val len = nextNullIndex - pos + assert(len * unitSize < Int.MaxValue) + putFunction(columnVector, pos, bufferPos, len) + bufferPos += len * unitSize + pos += len + } else { + seenNulls += 1 + nextNullIndex = if (seenNulls < nullCount) { + ByteBufferHelper.getInt(nullsBuffer) + } else { + capacity + } + columnVector.putNull(pos) + pos += 1 + } + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBooleans) + case _: ByteType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBytes) + case _: ShortType => + val unitSize = 2 + decompress0(columnVector, capacity, unitSize, putShorts) + case _: IntegerType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, putInts) + case _: LongType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putLongs) + case _: FloatType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, putFloats) + case _: DoubleType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putDoubles) + } + } } } @@ -169,6 +266,94 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { } override def hasNext: Boolean = valueCount < run || buffer.hasRemaining + + private def putBoolean(columnVector: WritableColumnVector, pos: Int, value: Long): 
Unit = { + columnVector.putBoolean(pos, value == 1) + } + + private def getByte(buffer: ByteBuffer): Long = { + buffer.get().toLong + } + + private def putByte(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putByte(pos, value.toByte) + } + + private def getShort(buffer: ByteBuffer): Long = { + buffer.getShort().toLong + } + + private def putShort(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putShort(pos, value.toShort) + } + + private def getInt(buffer: ByteBuffer): Long = { + buffer.getInt().toLong + } + + private def putInt(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putInt(pos, value.toInt) + } + + private def getLong(buffer: ByteBuffer): Long = { + buffer.getLong() + } + + private def putLong(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putLong(pos, value) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + getFunction: (ByteBuffer) => Long, + putFunction: (WritableColumnVector, Int, Long) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + var runLocal = 0 + var valueCountLocal = 0 + var currentValueLocal: Long = 0 + + while (valueCountLocal < runLocal || (pos < capacity)) { + if (pos != nextNullIndex) { + if (valueCountLocal == runLocal) { + currentValueLocal = getFunction(buffer) + runLocal = ByteBufferHelper.getInt(buffer) + valueCountLocal = 1 + } else { + valueCountLocal += 1 + } + putFunction(columnVector, pos, currentValueLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + decompress0(columnVector, capacity, getByte, putBoolean) + case _: ByteType => + decompress0(columnVector, capacity, getByte, putByte) + case _: ShortType => + decompress0(columnVector, capacity, getShort, putShort) + case _: IntegerType => + decompress0(columnVector, capacity, getInt, putInt) + case _: LongType => + decompress0(columnVector, capacity, getLong, putLong) + case _ => throw new IllegalStateException("Not supported type in RunLengthEncoding.") + } + } } } @@ -266,11 +451,32 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) - extends compression.Decoder[T] { - - private val dictionary: Array[Any] = { - val elementNum = ByteBufferHelper.getInt(buffer) - Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) + extends compression.Decoder[T] { + val elementNum = ByteBufferHelper.getInt(buffer) + private val dictionary: Array[Any] = new Array[Any](elementNum) + private var intDictionary: Array[Int] = null + private var longDictionary: Array[Long] = null + + columnType.dataType match { + case _: IntegerType => + intDictionary = new Array[Int](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Int] + intDictionary(i) = v + dictionary(i) = v + } + case _: LongType => + longDictionary = new Array[Long](elementNum) + for (i <- 0 until 
elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Long] + longDictionary(i) = v + dictionary(i) = v + } + case _: StringType => + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Any] + dictionary(i) = v + } } override def next(row: InternalRow, ordinal: Int): Unit = { @@ -278,6 +484,46 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + columnType.dataType match { + case _: IntegerType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(intDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + columnVector.putNull(pos) + } + pos += 1 + } + case _: LongType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(longDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + case _ => throw new IllegalStateException("Not supported type in DictionaryEncoding.") + } + } } } @@ -368,6 +614,38 @@ private[columnar] case object BooleanBitSet extends CompressionScheme { } override def hasNext: Boolean = visited < count + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val countLocal = count + var currentWordLocal: Long = 0 + var visitedLocal: Int = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (visitedLocal < countLocal) { + if (pos != nextNullIndex) { + val bit = visitedLocal % BITS_PER_LONG + + visitedLocal += 1 + if (bit == 0) { + currentWordLocal = ByteBufferHelper.getLong(buffer) + } + + columnVector.putBoolean(pos, ((currentWordLocal >> bit) & 1) != 0) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -448,6 +726,32 @@ private[columnar] case object IntDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Int = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get + prevLocal = if (delta > Byte.MinValue) { 
prevLocal + delta } else + { ByteBufferHelper.getInt(buffer) } + columnVector.putInt(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -528,5 +832,31 @@ private[columnar] case object LongDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Long = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get() + prevLocal = if (delta > Byte.MinValue) { prevLocal + delta } else + { ByteBufferHelper.getLong(buffer) } + columnVector.putLong(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index d01bf911e3a77..2d71a42628dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.BooleanType class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ @@ -85,6 +87,36 @@ class BooleanBitSetSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(count: Int) { + val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) + val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) + val values = rows.map(_.getBoolean(0)) + + rows.foreach(builder.appendFrom(_, 0)) + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val columnVector = new OnHeapColumnVector(values.length, BooleanType) + decoder.decompress(columnVector, values.length) + + if (values.nonEmpty) { + values.zipWithIndex.foreach { case (b: Boolean, index: Int) => + assertResult(b, s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + } + } + } + test(s"$BooleanBitSet: empty") { skeleton(0) } @@ -104,4 +136,24 @@ class BooleanBitSetSuite extends SparkFunSuite { test(s"$BooleanBitSet: multiple words and 1 more bit") { skeleton(BITS_PER_LONG * 2 + 1) } + + test(s"$BooleanBitSet: 
empty for decompression()") { + skeletonForDecompress(0) + } + + test(s"$BooleanBitSet: less than 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG - 1) + } + + test(s"$BooleanBitSet: exactly 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG) + } + + test(s"$BooleanBitSet: multiple whole words for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2) + } + + test(s"$BooleanBitSet: multiple words and 1 more bit for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2 + 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 67139b13d7882..28950b74cf1c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -23,16 +23,19 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { + val nullValue = -1 testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) - testDictionaryEncoding(new StringColumnStats, STRING) + testDictionaryEncoding(new StringColumnStats, STRING, false) def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -113,6 +116,58 @@ class DictionaryEncodingSuite extends SparkFunSuite { } } + def skeletonForDecompress(uniqueValueCount: Int, inputSeq: Seq[Int]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val dictValues = stableDistinct(inputSeq) + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = DictionaryEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { case (i: Any, index: Int) => + if (i == nullValue) { + assertResult(true, s"Wrong null ${index}-th position") { + columnVector.isNullAt(index) + } + } else { + columnType match { + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong ${index}-th decoded long value") { + 
columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + } + } + } + } + test(s"$DictionaryEncoding with $typeName: empty") { skeleton(0, Seq.empty) } @@ -124,5 +179,18 @@ class DictionaryEncodingSuite extends SparkFunSuite { test(s"$DictionaryEncoding with $typeName: dictionary overflow") { skeleton(DictionaryEncoding.MAX_DICT_SIZE + 1, 0 to DictionaryEncoding.MAX_DICT_SIZE) } + + test(s"$DictionaryEncoding with $typeName: empty for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$DictionaryEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0, nullValue, 0, nullValue)) + } + + test(s"$DictionaryEncoding with $typeName: dictionary overflow for decompress()") { + skeletonForDecompress(DictionaryEncoding.MAX_DICT_SIZE + 2, + Seq(nullValue) ++ (0 to DictionaryEncoding.MAX_DICT_SIZE - 1) ++ Seq(nullValue)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 411d31fa0e29b..0d9f1fb0c02c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -21,9 +21,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { + val nullValue = -1 testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) @@ -109,6 +111,53 @@ class IntegralDeltaSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(input: Seq[I#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, scheme) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(scheme.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = scheme.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => + fail("Unsupported type") + } + } + } + test(s"$scheme: empty column") { skeleton(Seq.empty) } @@ -127,5 
+176,28 @@ class IntegralDeltaSuite extends SparkFunSuite { val input = Array.fill[Any](10000)(makeRandomValue(columnType)) skeleton(input.map(_.asInstanceOf[I#InternalType])) } + + + test(s"$scheme: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$scheme: simple case for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) + } + + skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } + + test(s"$scheme: simple case with null for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + } + + skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala new file mode 100644 index 0000000000000..b6f0b5e6277b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.columnar.compression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.AtomicType + +class PassThroughSuite extends SparkFunSuite { + val nullValue = -1 + testPassThrough(new ByteColumnStats, BYTE) + testPassThrough(new ShortColumnStats, SHORT) + testPassThrough(new IntColumnStats, INT) + testPassThrough(new LongColumnStats, LONG) + testPassThrough(new FloatColumnStats, FLOAT) + testPassThrough(new DoubleColumnStats, DOUBLE) + + def testPassThrough[T <: AtomicType]( + columnStats: ColumnStats, + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + def skeleton(input: Seq[T#InternalType]) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + + input.map { value => + val row = new GenericInternalRow(1) + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + + val buffer = builder.build() + // Column type ID + null count + null positions + val headerSize = CompressionScheme.columnHeaderSize(buffer) + + // Compression scheme ID + compressed contents + val compressedSize = 4 + input.size * columnType.defaultSize + + // 4 extra bytes for compression scheme type ID + assertResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + if (input.nonEmpty) { + input.foreach { value => + assertResult(value, "Wrong value")(columnType.extract(buffer)) + } + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = PassThrough.decoder(buffer, columnType) + val mutableRow = new GenericInternalRow(1) + + if (input.nonEmpty) { + input.foreach{ + assert(decoder.hasNext) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } + } + } + assert(!decoder.hasNext) + } + + def skeletonForDecompress(input: Seq[T#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = PassThrough.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + 
columnVector.isNullAt(index) + } + case (expected: Byte, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case (expected: Short, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case (expected: Float, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded float value") { + columnVector.getFloat(index) + } + case (expected: Double, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded double value") { + columnVector.getDouble(index) + } + case _ => fail("Unsupported type") + } + } + } + + test(s"$PassThrough with $typeName: empty column") { + skeleton(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeleton(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series for decompress()") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: simple case with null for decompress()") { + val input = columnType match { + case BYTE => Seq(2: Byte, 1: Byte, 2: Byte, nullValue.toByte: Byte, 5: Byte) + case SHORT => Seq(2: Short, 1: Short, 2: Short, nullValue.toShort: Short, 5: Short) + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + case FLOAT => Seq(2: Float, 1: Float, 2: Float, nullValue: Float, 5: Float) + case DOUBLE => Seq(2: Double, 1: Double, 2: Double, nullValue: Double, 5: Double) + } + + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index dffa9b364ebfe..eb1cdd9bbceff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -21,19 +21,22 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { + val nullValue = -1 testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) testRunLengthEncoding(new IntColumnStats, INT) testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new StringColumnStats, STRING, false) def testRunLengthEncoding[T <: AtomicType]( columnStats: 
ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -95,6 +98,72 @@ class RunLengthEncodingSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val inputSeq = inputRuns.flatMap { case (index, run) => + Seq.fill(run)(index) + } + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = RunLengthEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (i: Int, index: Int) => + columnType match { + case BOOLEAN => + assertResult(values(i), s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + case BYTE => + assertResult(values(i), s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case SHORT => + assertResult(values(i), s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + case _ => fail("Unsupported type") + } + } + } + test(s"$RunLengthEncoding with $typeName: empty column") { skeleton(0, Seq.empty) } @@ -110,5 +179,21 @@ class RunLengthEncodingSuite extends SparkFunSuite { test(s"$RunLengthEncoding with $typeName: single long run") { skeleton(1, Seq(0 -> 1000)) } + + test(s"$RunLengthEncoding with $typeName: empty column for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$RunLengthEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, 1 -> 2)) + } + + test(s"$RunLengthEncoding with $typeName: single long run for decompress()") { + skeletonForDecompress(1, Seq(0 -> 1000)) + } + + test(s"$RunLengthEncoding with $typeName: single case with null for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, nullValue -> 2)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5e078f251375a..310cb0be5f5a2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.types.AtomicType +import org.apache.spark.sql.types.{AtomicType, DataType} class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, @@ -42,3 +42,10 @@ object TestCompressibleColumnBuilder { builder } } + +object ColumnBuilderHelper { + def apply( + dataType: DataType, batchSize: Int, name: String, useCompression: Boolean): ColumnBuilder = { + ColumnBuilder(dataType, batchSize, name, useCompression) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 85da8270d4cba..c5c8ae3a17c6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -20,7 +20,10 @@ package org.apache.spark.sql.execution.vectorized import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.columnar.ColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -31,14 +34,21 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { try block(vector) finally vector.close() } + private def withVectors( + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) + } + private def testVectors( name: String, size: Int, dt: DataType)( block: WritableColumnVector => Unit): Unit = { test(name) { - withVector(new OnHeapColumnVector(size, dt))(block) - withVector(new OffHeapColumnVector(size, dt))(block) + withVectors(size, dt)(block) } } @@ -218,4 +228,173 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) } } + + test("CachedBatch boolean Apis") { + val dataType = BooleanType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setBoolean(0, i % 2 == 0) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getBoolean(i) == (i % 2 == 0)) + } + } + } + + test("CachedBatch byte Apis") { + val dataType = ByteType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setByte(0, i.toByte) + 
columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getByte(i) == i) + } + } + } + + test("CachedBatch short Apis") { + val dataType = ShortType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setShort(0, i.toShort) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getShort(i) == i) + } + } + } + + test("CachedBatch int Apis") { + val dataType = IntegerType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setInt(0, i) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getInt(i) == i) + } + } + } + + test("CachedBatch long Apis") { + val dataType = LongType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setLong(0, i.toLong) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getLong(i) == i.toLong) + } + } + } + + test("CachedBatch float Apis") { + val dataType = FloatType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setFloat(0, i.toFloat) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getFloat(i) == i.toFloat) + } + } + } + + test("CachedBatch double Apis") { + val dataType = DoubleType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setDouble(0, i.toDouble) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + 
ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getDouble(i) == i.toDouble) + } + } + } } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 983eb103682c1..0b179aa97c479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -413,7 +413,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.getLong(v._2), "idx=" + v._2 + - " Seed = " + seed + " MemMode=" + memMode) + " Seed = " + seed + " MemMode=" + memMode) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) @@ -1120,7 +1120,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } batch.close() } - }} + }} /** * This test generates a random schema data, serializes it to column batches and verifies the From 4a779bdac3e75c17b7d36c5a009ba6c948fa9fb6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Oct 2017 10:08:24 -0700 Subject: [PATCH 015/712] [SPARK-21871][SQL] Check actual bytecode size when compiling generated code ## What changes were proposed in this pull request? This pr added code to check actual bytecode size when compiling generated code. In #18810, we added code to give up code compilation and use interpreter execution in `SparkPlan` if the line number of generated functions goes over `maxLinesPerFunction`. But, we already have code to collect metrics for compiled bytecode size in `CodeGenerator` object. So,we could easily reuse the code for this purpose. ## How was this patch tested? Added tests in `WholeStageCodegenSuite`. Author: Takeshi Yamamuro Closes #19083 from maropu/SPARK-21871. 
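For reviewers who want to see the new threshold in action, here is a minimal sketch (illustrative only, not part of this patch): the config key `spark.sql.codegen.hugeMethodLimit` is the one introduced below, but the session setup and the wide CASE WHEN query are made up for this example.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: assumes a build that includes this change.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("hugeMethodLimit-sketch")
  // Lower the new limit so even modest generated methods exceed it.
  .config("spark.sql.codegen.hugeMethodLimit", "1500")
  .getOrCreate()

// Many CASE WHEN branches inflate the generated aggregate method; with the
// lowered limit the plan should fall back to non-codegen execution and log
// the new warning from WholeStageCodegenExec instead of compiling huge bytecode.
val wideCase = (1 to 30).map(i => s"case when id > $i then 1 else 0 end as v$i")
spark.range(1000L).selectExpr(("id" +: wideCase): _*).groupBy().sum().collect()
```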
--- .../expressions/codegen/CodeFormatter.scala | 8 --- .../expressions/codegen/CodeGenerator.scala | 59 +++++++++---------- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateOrdering.scala | 3 +- .../codegen/GeneratePredicate.scala | 3 +- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 15 ++--- .../codegen/CodeFormatterSuite.scala | 32 ---------- .../sql/execution/WholeStageCodegenExec.scala | 25 ++++---- .../columnar/GenerateColumnAccessor.scala | 3 +- .../execution/WholeStageCodegenSuite.scala | 43 ++++---------- .../benchmark/AggregateBenchmark.scala | 36 +++++------ 14 files changed, 94 insertions(+), 149 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 7b398f424cead..60e600d8dbd8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -89,14 +89,6 @@ object CodeFormatter { } new CodeAndComment(code.result().trim(), map) } - - def stripExtraNewLinesAndComments(input: String): String = { - val commentReg = - ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ - """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment - val codeWithoutComment = commentReg.replaceAllIn(input, "") - codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines - } } private class CodeFormatter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f3b45799c5688..f9c5ef8439085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -373,20 +373,6 @@ class CodegenContext { */ private val placeHolderToComments = new mutable.HashMap[String, String] - /** - * It will count the lines of every Java function generated by whole-stage codegen, - * if there is a function of length greater than spark.sql.codegen.maxLinesPerFunction, - * it will return true. - */ - def isTooLongGeneratedFunction: Boolean = { - classFunctions.values.exists { _.values.exists { - code => - val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code) - codeWithoutComments.count(_ == '\n') > SQLConf.get.maxLinesPerFunction - } - } - } - /** * Returns a term name that is unique within this instance of a `CodegenContext`. */ @@ -1020,10 +1006,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } object CodeGenerator extends Logging { + + // This is the value of HugeMethodLimit in the OpenJDK JVM settings + val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + /** * Compile the Java source code into a Java class, using Janino. + * + * @return a pair of a generated class and the max bytecode size of generated functions. */ - def compile(code: CodeAndComment): GeneratedClass = try { + def compile(code: CodeAndComment): (GeneratedClass, Int) = try { cache.get(code) } catch { // Cache.get() may wrap the original exception. 
See the following URL @@ -1036,7 +1028,7 @@ object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. */ - private[this] def doCompile(code: CodeAndComment): GeneratedClass = { + private[this] def doCompile(code: CodeAndComment): (GeneratedClass, Int) = { val evaluator = new ClassBodyEvaluator() // A special classloader used to wrap the actual parent classloader of @@ -1075,9 +1067,9 @@ object CodeGenerator extends Logging { s"\n${CodeFormatter.format(code)}" }) - try { + val maxCodeSize = try { evaluator.cook("generated.java", code.body) - recordCompilationStats(evaluator) + updateAndGetCompilationStats(evaluator) } catch { case e: JaninoRuntimeException => val msg = s"failed to compile: $e" @@ -1092,13 +1084,15 @@ object CodeGenerator extends Logging { logInfo(s"\n${CodeFormatter.format(code, maxLines)}") throw new CompileException(msg, e.getLocation) } - evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] + + (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) } /** - * Records the generated class and method bytecode sizes by inspecting janino private fields. + * Returns the max bytecode size of the generated functions by inspecting janino private fields. + * Also, this method updates the metrics information. */ - private def recordCompilationStats(evaluator: ClassBodyEvaluator): Unit = { + private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): Int = { // First retrieve the generated classes. val classes = { val resultField = classOf[SimpleCompiler].getDeclaredField("result") @@ -1113,23 +1107,26 @@ object CodeGenerator extends Logging { val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute") val codeAttrField = codeAttr.getDeclaredField("code") codeAttrField.setAccessible(true) - classes.foreach { case (_, classBytes) => + val codeSizes = classes.flatMap { case (_, classBytes) => CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - cf.methodInfos.asScala.foreach { method => - method.getAttributes().foreach { a => - if (a.getClass.getName == codeAttr.getName) { - CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( - codeAttrField.get(a).asInstanceOf[Array[Byte]].length) - } + val stats = cf.methodInfos.asScala.flatMap { method => + method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + byteCodeSize } } + Some(stats) } catch { case NonFatal(e) => logWarning("Error calculating stats of compiled class.", e) + None } - } + }.flatten + + codeSizes.max } /** @@ -1144,8 +1141,8 @@ object CodeGenerator extends Logging { private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[CodeAndComment, GeneratedClass]() { - override def load(code: CodeAndComment): GeneratedClass = { + new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { + override def load(code: CodeAndComment): (GeneratedClass, Int) = { val startTime = System.nanoTime() val result = doCompile(code) val endTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 
3768dcde00a4e..b5429fade53cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -142,7 +142,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4e47895985209..1639d1b9dda1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -185,7 +185,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index e35b9dda6c017..e0fabad6d089a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -78,6 +78,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 192701a829686..1e4ac3f2afd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -189,8 +189,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) + val (clazz, _) = CodeGenerator.compile(code) val resultRow = new SpecificInternalRow(expressions.map(_.dataType)) - c.generate(ctx.references.toArray :+ 
resultRow).asInstanceOf[Projection] + clazz.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index f2a66efc98e71..4bd50aee05514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -409,7 +409,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 4aa5ec82471ec..6bc72a0d75c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val code = CodeFormatter.stripOverlappingComments(new CodeAndComment(codeBody, Map.empty)) logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1a73d168b9b6e..58323740b80cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -575,15 +576,15 @@ object SQLConf { "disable logging or -1 to apply no limit.") .createWithDefault(1000) - val WHOLESTAGE_MAX_LINES_PER_FUNCTION = buildConf("spark.sql.codegen.maxLinesPerFunction") + val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() - .doc("The maximum lines of a single Java function generated by whole-stage codegen. " + - "When the generated function exceeds this threshold, " + + .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + + "codegen. 
When the compiled function exceeds this threshold, " + "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - "The default value 4000 is the max length of byte code JIT supported " + - "for a single function(8000) divided by 2.") + s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + + "this is a limit in the OpenJDK JVM implementation.") .intConf - .createWithDefault(4000) + .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") @@ -1058,7 +1059,7 @@ class SQLConf extends Serializable with Logging { def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) - def maxLinesPerFunction: Int = getConf(WHOLESTAGE_MAX_LINES_PER_FUNCTION) + def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index a0f1a64b0ab08..9d0a41661beaa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -53,38 +53,6 @@ class CodeFormatterSuite extends SparkFunSuite { assert(reducedCode.body === "/*project_c4*/") } - test("removing extra new lines and comments") { - val code = - """ - |/* - | * multi - | * line - | * comments - | */ - | - |public function() { - |/*comment*/ - | /*comment_with_space*/ - |code_body - |//comment - |code_body - | //comment_with_space - | - |code_body - |} - """.stripMargin - - val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code) - assert(reducedCode === - """ - |public function() { - |code_body - |code_body - |code_body - |} - """.stripMargin) - } - testCase("basic example") { """ |class A { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 268ccfa4edfa0..9073d599ac43d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -380,16 +380,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def doExecute(): RDD[InternalRow] = { val (ctx, cleanedSource) = doCodeGen() - if (ctx.isTooLongGeneratedFunction) { - logWarning("Found too long generated codes and JIT optimization might not work, " + - "Whole-stage codegen disabled for this plan, " + - "You can change the config spark.sql.codegen.MaxFunctionLength " + - "to adjust the function length limit:\n " - + s"$treeString") - return child.execute() - } // try to compile and fallback if it failed - try { + val (_, maxCodeSize) = try { CodeGenerator.compile(cleanedSource) } catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => @@ -397,6 +389,17 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") return child.execute() } + + // Check if compiled code has a too 
large function + if (maxCodeSize > sqlContext.conf.hugeMethodLimit) { + logWarning(s"Found too long generated codes and JIT optimization might not work: " + + s"the bytecode size was $maxCodeSize, this value went over the limit " + + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + + s"for this plan. To avoid this, you can raise the limit " + + s"${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}:\n$treeString") + return child.execute() + } + val references = ctx.references.toArray val durationMs = longMetric("pipelineTime") @@ -405,7 +408,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co assert(rdds.size <= 2, "Up to two input RDDs can be supported") if (rdds.length == 1) { rdds.head.mapPartitionsWithIndex { (index, iter) => - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(iter)) new Iterator[InternalRow] { @@ -424,7 +427,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co // a small hack to obtain the correct partition index }.mapPartitionsWithIndex { (index, zippedIter) => val (leftIter, rightIter) = zippedIter.next() - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(leftIter, rightIter)) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index da34643281911..ae600c1ffae8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -227,6 +227,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[ColumnarIterator] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index beeee6a97c8dd..aaa77b3ee6201 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{Column, Dataset, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -151,7 +149,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext 
{ } } - def genGroupByCodeGenContext(caseNum: Int): CodegenContext = { + def genGroupByCode(caseNum: Int): CodeAndComment = { val caseExp = (1 to caseNum).map { i => s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i" }.toList @@ -176,34 +174,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { }) assert(wholeStageCodeGenExec.isDefined) - wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._1 + wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21603 check there is a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === true) - } - } - - test("SPARK-21603 check there is not a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is a too long generated function when threshold is 0") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "0") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === true) - } + test("SPARK-21871 check if we can get large code size when compiling too long functions") { + val codeWithShortFunctions = genGroupByCode(3) + val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) + assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val codeWithLongFunctions = genGroupByCode(20) + val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) + assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 691fa9ac5e1e7..aca1be01fa3da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -24,6 +24,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.execution.vectorized.AggregateHashMap +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 @@ -301,10 +302,10 @@ class AggregateBenchmark extends BenchmarkBase { */ } - ignore("max function length of wholestagecodegen") { + ignore("max function bytecode size of wholestagecodegen") { val N = 20 << 15 - val benchmark = new Benchmark("max function length of wholestagecodegen", N) + val benchmark = new Benchmark("max function bytecode size", N) def f(): Unit = sparkSession.range(N) .selectExpr( "id", @@ -333,33 +334,34 @@ class AggregateBenchmark extends BenchmarkBase { .sum() .collect() - benchmark.addCase(s"codegen = F") { iter => - 
sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + benchmark.addCase("codegen = F") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 10000") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000") + benchmark.addCase("codegen = T hugeMethodLimit = 10000") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "10000") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500") + benchmark.addCase("codegen = T hugeMethodLimit = 1500") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "1500") f() } benchmark.run() /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1 - Intel64 Family 6 Model 58 Stepping 9, GenuineIntel - max function length of wholestagecodegen: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ---------------------------------------------------------------------------------------------- - codegen = F 462 / 533 1.4 704.4 1.0X - codegen = T maxLinesPerFunction = 10000 3444 / 3447 0.2 5255.3 0.1X - codegen = T maxLinesPerFunction = 1500 447 / 478 1.5 682.1 1.0X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + + max function bytecode size: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + codegen = F 709 / 803 0.9 1082.1 1.0X + codegen = T hugeMethodLimit = 10000 3485 / 3548 0.2 5317.7 0.2X + codegen = T hugeMethodLimit = 1500 636 / 701 1.0 969.9 1.1X */ } From bb035f1ee5cdf88e476b7ed83d59140d669fbe12 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Oct 2017 13:13:51 -0700 Subject: [PATCH 016/712] [SPARK-22169][SQL] support byte length literal as identifier ## What changes were proposed in this pull request? By definition the table name in Spark can be something like `123x`, `25a`, etc., with exceptions for literals like `12L`, `23BD`, etc. However, Spark SQL has a special byte length literal, which stops users to use digits followed by `b`, `k`, `m`, `g` as identifiers. byte length literal is not a standard sql literal and is only used in the `tableSample` parser rule. This PR move the parsing of byte length literal from lexer to parser, so that users can use it as identifiers. ## How was this patch tested? regression test Author: Wenchen Fan Closes #19392 from cloud-fan/parser-bug. 
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 25 +++++++----------- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 26 +++++++++++++------ .../sql/catalyst/parser/PlanParserSuite.scala | 1 + .../execution/command/DDLParserSuite.scala | 19 ++++++++++++++ 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d0a54288780ea..17c8404f8a79c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -25,7 +25,7 @@ grammar SqlBase; * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. - * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is folllowed + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' * which is not a digit or letter or underscore. */ @@ -40,10 +40,6 @@ grammar SqlBase; } } -tokens { - DELIMITER -} - singleStatement : statement EOF ; @@ -447,12 +443,15 @@ joinCriteria ; sample - : TABLESAMPLE '(' - ( (negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) - | (expression sampleType=ROWS) - | sampleType=BYTELENGTH_LITERAL - | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier | qualifiedName '(' ')'))?)) - ')' + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket + | bytes=expression #sampleByBytes ; identifierList @@ -1004,10 +1003,6 @@ TINYINT_LITERAL : DIGIT+ 'Y' ; -BYTELENGTH_LITERAL - : DIGIT+ ('B' | 'K' | 'M' | 'G') - ; - INTEGER_VALUE : DIGIT+ ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6ba9ee5446a01..95bc3d674b4f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -99,7 +99,7 @@ class SessionCatalog( protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) /** - * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), + * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. 
* * This method is intended to have the same behavior of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 85b492e83446e..ce367145bc637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -699,20 +699,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) } - ctx.sampleType.getType match { - case SqlBaseParser.ROWS => + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => Limit(expression(ctx.expression), query) - case SqlBaseParser.PERCENTLIT => + case ctx: SampleByPercentileContext => val fraction = ctx.percentage.getText.toDouble val sign = if (ctx.negativeSign == null) 1 else -1 sample(sign * fraction / 100.0d) - case SqlBaseParser.BYTELENGTH_LITERAL => - throw new ParseException( - "TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException( + bytesStr + " is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } - case SqlBaseParser.BUCKET if ctx.ON != null => + case ctx: SampleByBucketContext if ctx.ON() != null => if (ctx.identifier != null) { throw new ParseException( "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) @@ -721,7 +731,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) } - case SqlBaseParser.BUCKET => + case ctx: SampleByBucketContext => sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 306e6f2cfbd37..d34a83c42c67e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -110,6 +110,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) assertEqual("select from tbl", OneRowRelation().select('from.as("tbl"))) + assertEqual("select a from 1k.2m", table("1k", "2m").select('a)) } test("reverse select query") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index fa5172ca8a3e7..eb7c33590b602 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -525,6 +525,25 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(e.message.contains("you can only specify one of them.")) } + test("create 
table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("2g", Some("1m")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType), + provider = Some("parquet")) + + parser.parsePlan(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + test("insert overwrite directory") { val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" parser.parsePlan(v1) match { From 969ffd631746125eb2b83722baf6f6e7ddd2092c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 4 Oct 2017 19:25:22 -0700 Subject: [PATCH 017/712] [SPARK-22187][SS] Update unsaferow format for saved state such that we can set timeouts when state is null ## What changes were proposed in this pull request? Currently, the group state of user-defined-type is encoded as top-level columns in the UnsafeRows stores in the state store. The timeout timestamp is also saved as (when needed) as the last top-level column. Since the group state is serialized to top-level columns, you cannot save "null" as a value of state (setting null in all the top-level columns is not equivalent). So we don't let the user set the timeout without initializing the state for a key. Based on user experience, this leads to confusion. This PR is to change the row format such that the state is saved as nested columns. This would allow the state to be set to null, and avoid these confusing corner cases. ## How was this patch tested? Refactored tests. Author: Tathagata Das Closes #19416 from tdas/SPARK-22187. 
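As a user-facing illustration of what the new row format permits, here is a minimal, hypothetical `mapGroupsWithState` function; the names and timeout duration are invented for the sketch, and it is not code from this patch:

```scala
import org.apache.spark.sql.streaming.GroupState

// Count events per key with a processing-time timeout. With the state stored as a
// nullable nested struct, registering a timeout is legal even before state.update(...)
// has ever been called for the key; the old flat format rejected that with
// IllegalStateException ("Cannot set timeout when state is not defined ...").
def countWithTimeout(key: String, events: Iterator[String], state: GroupState[Long]): Long = {
  if (state.hasTimedOut) {
    state.remove()
    0L
  } else {
    state.setTimeoutDuration("10 seconds")           // allowed with or without existing state
    val count = state.getOption.getOrElse(0L) + events.size
    if (count > 0) state.update(count)
    count
  }
}

// Usage sketch on a streaming Dataset[(String, String)] keyed by its first field:
// ds.groupByKey(_._1).mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout) {
//   (key, pairs, state) => countWithTimeout(key, pairs.map(_._2), state)
// }
```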
--- .../FlatMapGroupsWithStateExec.scala | 133 +++------------ .../FlatMapGroupsWithState_StateManager.scala | 153 ++++++++++++++++++ .../FlatMapGroupsWithStateSuite.scala | 130 ++++++++------- 3 files changed, 246 insertions(+), 170 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index ab690fd5fbbca..aab06d611a5ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -62,26 +60,7 @@ case class FlatMapGroupsWithStateExec( import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - + val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -109,11 +88,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -128,7 +107,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. 
val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -143,7 +122,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -152,14 +131,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -168,20 +139,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -190,12 +160,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutKeys = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutKeys.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -205,72 +174,43 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. 
* - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: FlatMapGroupsWithState_StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, hasTimedOut) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. 
- if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -279,28 +219,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala new file mode 100644 index 0000000000000..d077836da847c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow} +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types.{IntegerType, LongType, StructType} + + +/** + * Class to serialize/write/read/deserialize state for + * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]]. + */ +class FlatMapGroupsWithState_StateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends Serializable { + + /** Schema of the state rows saved in the state store */ + val stateSchema = { + val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema + } + + /** Get deserialized state and corresponding timeout timestamp for a key */ + def getState(store: StateStore, keyRow: UnsafeRow): FlatMapGroupsWithState_StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew( + keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(keyRow, stateRow) + } + + /** Removed all information related to a key */ + def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + /** Get all the keys and corresponding state rows in the state store */ + def getAllState(store: StateStore): Iterator[FlatMapGroupsWithState_StateData] = { + val stateDataForGetAllState = FlatMapGroupsWithState_StateData() + store.getRange(None, None).map { pair => + stateDataForGetAllState.withNew( + pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) + } + } + + // Ordinals of the information stored in the state row + private lazy val nestedStateOrdinal = 0 + private lazy val timeoutTimestampOrdinal = 1 + + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val nestedStateExpr = CreateNamedStruct( + stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + if (shouldStoreTimestamp) { + Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nestedStateExpr) + } + } + + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. 
+ private val stateDeserializer = { + val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true) + val deser = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen() + } + + // Converters for translating state between rows and Java objects + private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateSchema.toAttributes) + private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Reusable instance for returning state information + private lazy val stateDataForGets = FlatMapGroupsWithState_StateData() + + /** Returns the state as Java object if defined */ + private def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow == null) null + else getStateObjFromRow(stateRow) + } + + /** Returns the row for an updated state */ + private def getStateRow(obj: Any): UnsafeRow = { + val row = getStateRowFromObj(obj) + if (obj == null) { + row.setNullAt(nestedStateOrdinal) + } + row + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinal) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) + } +} + +/** + * Class to capture deserialized state and timestamp return by the state manager. + * This is intended for reuse. 
+ */ +case class FlatMapGroupsWithState_StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 9d74a5c701ef1..d2e8beb2f5290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -289,13 +289,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -322,7 +322,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -365,6 +365,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -375,10 +387,36 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = 
None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + (state: GroupState[Int]) => { + state.update(5); state.setTimeoutTimestamp(5000) + }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -397,50 +435,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -924,7 +935,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with 
data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -946,7 +957,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -973,21 +984,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -998,15 +1008,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } From c8affec21c91d638009524955515fc143ad86f20 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 4 Oct 2017 20:58:48 -0700 Subject: [PATCH 018/712] [SPARK-22203][SQL] Add job description for file listing Spark jobs ## What changes were proposed in this pull request? The user may be confused about some 10000-tasks jobs. We can add a job description for these jobs so that the user can figure it out. ## How was this patch tested? The new unit test. Before: screen shot 2017-10-04 at 3 22 09 pm After: screen shot 2017-10-04 at 3 13 51 pm Author: Shixiong Zhu Closes #19432 from zsxwing/SPARK-22203. 
--- .../datasources/InMemoryFileIndex.scala | 85 +++++++++++-------- .../sql/test/DataFrameReaderWriterSuite.scala | 31 +++++++ 2 files changed, 81 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 203d449717512..318ada0ceefc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession @@ -187,42 +188,56 @@ object InMemoryFileIndex extends Logging { // in case of large #defaultParallelism. val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) - val statusMap = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { pathStrings => - val hadoopConf = serializableConfiguration.value - pathStrings.map(new Path(_)).toSeq.map { path => - (path, listLeafFiles(path, hadoopConf, filter, None)) - }.iterator - }.map { case (path, statuses) => - val serializableStatuses = statuses.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) + val previousJobDescription = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + val statusMap = try { + val description = paths.size match { + case 0 => + s"Listing leaf files and directories 0 paths" + case 1 => + s"Listing leaf files and directories for 1 path:
${paths(0)}" + case s => + s"Listing leaf files and directories for $s paths:
${paths(0)}, ..." } - (path.toString, serializableStatuses) - }.collect() + sparkContext.setJobDescription(description) + sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = serializableConfiguration.value + pathStrings.map(new Path(_)).toSeq.map { path => + (path, listLeafFiles(path, hadoopConf, filter, None)) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + } finally { + sparkContext.setJobDescription(previousJobDescription) + } // turn SerializableFileStatus back to Status statusMap.map { case (path, serializableStatuses) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 569bac156b531..a5d7e6257a6df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -21,10 +21,14 @@ import java.io.File import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.internal.SQLConf @@ -775,4 +779,31 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } } + + test("use Spark jobs to list files") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { + withTempDir { dir => + val jobDescriptions = new ConcurrentLinkedQueue[String]() + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescriptions.add(jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + } + sparkContext.addSparkListener(jobListener) + try { + spark.range(0, 3).map(i => (i, i)) + .write.partitionBy("_1").mode("overwrite").parquet(dir.getCanonicalPath) + // normal file paths + checkDatasetUnorderly( + spark.read.parquet(dir.getCanonicalPath).as[(Long, Long)], + 0L -> 0L, 1L -> 1L, 2L -> 2L) + sparkContext.listenerBus.waitUntilEmpty(10000) + assert(jobDescriptions.asScala.toList.exists( + _.contains("Listing leaf files and directories for 3 paths"))) + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + } + } } From ae61f187aa0471242c046fdeac6ed55b9b98a3f6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Oct 2017 23:36:18 +0900 Subject: [PATCH 019/712] [SPARK-22206][SQL][SPARKR] gapply in R can't work on empty grouping columns ## What 
changes were proposed in this pull request? Looks like `FlatMapGroupsInRExec.requiredChildDistribution` didn't consider empty grouping attributes. It should be a problem when running `EnsureRequirements` and `gapply` in R can't work on empty grouping columns. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19436 from viirya/fix-flatmapinr-distribution. --- R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ .../main/scala/org/apache/spark/sql/execution/objects.scala | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7f781f2f66a7f..bbea25bc4da5c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3075,6 +3075,11 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { df1Collect <- gapplyCollect(df, list("a"), function(key, x) { x }) expect_identical(df1Collect, expected) + # gapply on empty grouping columns. + df1 <- gapply(df, c(), function(key, x) { x }, schema(df)) + actual <- collect(df1) + expect_identical(actual, expected) + # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 5a3fcad38888e..c68975bea490f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -394,7 +394,11 @@ case class FlatMapGroupsInRExec( override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) From 83488cc3180ca18f829516f550766efb3095881e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 5 Oct 2017 23:33:49 -0700 Subject: [PATCH 020/712] [SPARK-21871][SQL] Fix infinite loop when bytecode size is larger than spark.sql.codegen.hugeMethodLimit ## What changes were proposed in this pull request? When exceeding `spark.sql.codegen.hugeMethodLimit`, the runtime fallbacks to the Volcano iterator solution. This could cause an infinite loop when `FileSourceScanExec` can use the columnar batch to read the data. This PR is to fix the issue. ## How was this patch tested? Added a test Author: gatorsmile Closes #19440 from gatorsmile/testt. 
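For context, a hedged sketch of how the limit can be exercised from user code, loosely mirroring the regression test added below; the path and the configuration values are examples only:

```scala
import org.apache.spark.sql.functions.col

// Allow codegen for a wide schema and lower the huge-method limit so the scan exceeds it.
spark.conf.set("spark.sql.codegen.maxFields", "202")
spark.conf.set("spark.sql.codegen.hugeMethodLimit", "2000")

// A wide Parquet table whose generated code produces a very large method.
val wide = spark.range(10).select((0 until 201).map(i => (col("id") + i).as(s"c$i")): _*)
wide.write.mode("overwrite").parquet("/tmp/wide-table")   // hypothetical path

// Before the fix this read could loop forever: the whole-stage plan fell back to
// child.execute(), but a batch FileSourceScanExec implements execute() by wrapping
// itself in WholeStageCodegenExec again, which fell back again, and so on. After
// the fix, a batch file scan keeps the whole-stage codegen path even over the limit.
spark.read.parquet("/tmp/wide-table").count()
```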
--- .../sql/execution/WholeStageCodegenExec.scala | 12 ++++++---- .../execution/WholeStageCodegenSuite.scala | 23 +++++++++++++++++-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9073d599ac43d..1aaaf896692d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -392,12 +392,16 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co // Check if compiled code has a too large function if (maxCodeSize > sqlContext.conf.hugeMethodLimit) { - logWarning(s"Found too long generated codes and JIT optimization might not work: " + - s"the bytecode size was $maxCodeSize, this value went over the limit " + + logInfo(s"Found too long generated codes and JIT optimization might not work: " + + s"the bytecode size ($maxCodeSize) is above the limit " + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + s"for this plan. To avoid this, you can raise the limit " + - s"${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}:\n$treeString") - return child.execute() + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") + child match { + // The fallback solution of batch file source scan still uses WholeStageCodegenExec + case f: FileSourceScanExec if f.supportsBatch => // do nothing + case _ => return child.execute() + } } val references = ctx.references.toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index aaa77b3ee6201..098e4cfeb15b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { +class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") @@ -185,4 +185,23 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } + + test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) + df.write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", + SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "2000") { + // wide table batch 
scan causes the byte code of codegen exceeds the limit of + // WHOLESTAGE_HUGE_METHOD_LIMIT + val df2 = spark.read.parquet(path) + val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + assert(fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch) + checkAnswer(df2, df) + } + } + } } From 0c03297bf0e87944f9fe0535fdae5518228e3e29 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 6 Oct 2017 15:08:28 +0100 Subject: [PATCH 021/712] [SPARK-22142][BUILD][STREAMING] Move Flume support behind a profile, take 2 ## What changes were proposed in this pull request? Move flume behind a profile, take 2. See https://github.com/apache/spark/pull/19365 for most of the back-story. This change should fix the problem by removing the examples module dependency and moving Flume examples to the module itself. It also adds deprecation messages, per a discussion on dev about deprecating for 2.3.0. ## How was this patch tested? Existing tests, which still enable flume integration. Author: Sean Owen Closes #19412 from srowen/SPARK-22142.2. --- dev/create-release/release-build.sh | 4 ++-- dev/mima | 2 +- dev/scalastyle | 1 + dev/sparktestsupport/modules.py | 20 ++++++++++++++++++- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 7 +++++++ docs/streaming-flume-integration.md | 13 +++++------- examples/pom.xml | 7 ------- .../spark/examples}/JavaFlumeEventCount.java | 2 -- .../spark/examples}/FlumeEventCount.scala | 2 -- .../examples}/FlumePollingEventCount.scala | 2 -- .../spark/streaming/flume/FlumeUtils.scala | 1 + pom.xml | 13 +++++++++--- project/SparkBuild.scala | 17 ++++++++-------- python/pyspark/streaming/flume.py | 4 ++++ python/pyspark/streaming/tests.py | 16 ++++++++++++--- 16 files changed, 73 insertions(+), 40 deletions(-) rename {examples/src/main/java/org/apache/spark/examples/streaming => external/flume/src/main/java/org/apache/spark/examples}/JavaFlumeEventCount.java (98%) rename {examples/src/main/scala/org/apache/spark/examples/streaming => external/flume/src/main/scala/org/apache/spark/examples}/FlumeEventCount.scala (98%) rename {examples/src/main/scala/org/apache/spark/examples/streaming => external/flume/src/main/scala/org/apache/spark/examples}/FlumePollingEventCount.scala (98%) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5390f5916fc0d..7e8d5c7075195 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/mima b/dev/mima index fdb21f5007cf2..1e3ca9700bc07 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt 
-DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index e5aa589869535..89ecc8abd6f8c 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,6 +25,7 @@ ERRORS=$(echo -e "q\n" \ -Pmesos \ -Pkafka-0-8 \ -Pyarn \ + -Pflume \ -Phive \ -Phive-thriftserver \ scalastyle test:scalastyle \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 50e14b60545af..91d5667ed1f07 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -279,6 +279,12 @@ def __hash__(self): source_file_regexes=[ "external/flume-sink", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume-sink/test", ] @@ -291,6 +297,12 @@ def __hash__(self): source_file_regexes=[ "external/flume", ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + }, sbt_test_goals=[ "streaming-flume/test", ] @@ -302,7 +314,13 @@ def __hash__(self): dependencies=[streaming_flume, streaming_flume_sink], source_file_regexes=[ "external/flume-assembly", - ] + ], + build_profile_flags=[ + "-Pflume", + ], + environ={ + "ENABLE_FLUME_TESTS": "1" + } ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index c7714578bd005..58b295d4f6e00 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index 57baa503259c1..98f7df155456f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,6 +100,13 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. +## Building with Flume support + +Apache Flume support must be explicitly enabled with the `flume` profile. +Note: Flume support is deprecated as of Spark 2.3.0. + + ./build/mvn -Pflume -DskipTests clean package + ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index a5d36da5b6de9..257a4f7d4f3ca 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,6 +5,8 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. +**Note: Flume support is deprecated as of Spark 2.3.0.** + ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -44,8 +46,7 @@ configuring Flume agents. 
val flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) - See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala). + See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$).
import org.apache.spark.streaming.flume.*; @@ -53,8 +54,7 @@ configuring Flume agents. JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]); - See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). + See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html).
from pyspark.streaming.flume import FlumeUtils @@ -62,8 +62,7 @@ configuring Flume agents. flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. - See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/flume_wordcount.py). + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils).
@@ -162,8 +161,6 @@ configuring Flume agents. - See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). - Note that each input DStream can be configured to receive data from multiple sinks. 3. **Deploying:** This is same as the first approach. diff --git a/examples/pom.xml b/examples/pom.xml index 52a6764ae26a5..1791dbaad775e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,7 +34,6 @@ examples none package - provided provided provided provided @@ -78,12 +77,6 @@ ${project.version} provided - - org.apache.spark - spark-streaming-flume_${scala.binary.version} - ${project.version} - provided - org.apache.spark spark-streaming-kafka-0-10_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java similarity index 98% rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java rename to external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java index 0c651049d0ffa..4e3420d9c3b06 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java @@ -48,8 +48,6 @@ public static void main(String[] args) throws Exception { System.exit(1); } - StreamingExamples.setStreamingLogLevels(); - String host = args[0]; int port = Integer.parseInt(args[1]); diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala index 91e52e4eff5a7..f877f79391b37 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala @@ -47,8 +47,6 @@ object FlumeEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala index dd725d72c23ef..79a4027ca5bde 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala @@ -44,8 +44,6 @@ object FlumePollingEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 3e3ed712f0dbf..707193a957700 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ 
b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -30,6 +30,7 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream +@deprecated("Deprecated without replacement", "2.3.0") object FlumeUtils { private val DEFAULT_POLLING_PARALLELISM = 5 private val DEFAULT_POLLING_BATCH_SIZE = 1000 diff --git a/pom.xml b/pom.xml index 87a468c3a6f55..9fac8b1e53788 100644 --- a/pom.xml +++ b/pom.xml @@ -98,15 +98,13 @@ sql/core sql/hive assembly - external/flume - external/flume-sink - external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + @@ -2583,6 +2581,15 @@ + + flume + + external/flume + external/flume-sink + external/flume-assembly + + + spark-ganglia-lgpl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a568d264cb2db..9501eed1e906b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,11 +43,8 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingKafka010 - ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" - ).map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq(streaming, streamingKafka010) = + Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -56,9 +53,13 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, + streamingFlumeSink, streamingFlume, + streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, + dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", + "streaming-flume-sink", "streaming-flume", + "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index cd30483fc636a..2fed5940b31ea 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -53,6 +53,8 @@ def createStream(ssc, hostname, port, :param enableDecompression: Should netty server decompress input stream :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. note:: Deprecated in 2.3.0 """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) helper = FlumeUtils._get_helper(ssc._sc) @@ -79,6 +81,8 @@ def createPollingStream(ssc, addresses, will result in this stream using more threads :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. 
note:: Deprecated in 2.3.0 """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) hosts = [] diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 229cf53e47359..5b86c1cb2c390 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,6 +1516,9 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in modules.py +flume_test_environ_var = "ENABLE_FLUME_TESTS" +are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1538,9 +1541,16 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] + if are_flume_tests_enabled: + testcases.append(FlumeStreamTests) + testcases.append(FlumePollingStreamTests) + else: + sys.stderr.write( + "Skipped test_flume_stream (enable by setting environment variable %s=1" + % flume_test_environ_var) + if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: From c7b46d4d8aa8da24131d79d2bfa36e8db19662e4 Mon Sep 17 00:00:00 2001 From: minixalpha Date: Fri, 6 Oct 2017 23:38:47 +0900 Subject: [PATCH 022/712] [SPARK-21877][DEPLOY, WINDOWS] Handle quotes in Windows command scripts ## What changes were proposed in this pull request? All the windows command scripts can not handle quotes in parameter. Run a windows command shell with parameter which has quotes can reproduce the bug: ``` C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7> bin\spark-shell --driver-java-options " -Dfile.encoding=utf-8 " 'C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7\bin\spark-shell2.cmd" --driver-java-options "' is not recognized as an internal or external command, operable program or batch file. ``` Windows recognize "--driver-java-options" as part of the command. All the Windows command script has the following code have the bug. ``` cmd /V /E /C "" %* ``` We should quote command and parameters like ``` cmd /V /E /C """ %*" ``` ## How was this patch tested? 
Test manually on Windows 10 and Windows 7 We can verify it by the following demo: ``` C:\Users\meng\program\demo>cat a.cmd echo off cmd /V /E /C "b.cmd" %* C:\Users\meng\program\demo>cat b.cmd echo off echo %* C:\Users\meng\program\demo>cat c.cmd echo off cmd /V /E /C ""b.cmd" %*" C:\Users\meng\program\demo>a.cmd "123" 'b.cmd" "123' is not recognized as an internal or external command, operable program or batch file. C:\Users\meng\program\demo>c.cmd "123" "123" ``` With the spark-shell.cmd example, change it to the following code will make the command execute succeed. ``` cmd /V /E /C ""%~dp0spark-shell2.cmd" %*" ``` ``` C:\Users\meng\software\spark-2.2.0-bin-hadoop2.7> bin\spark-shell --driver-java-options " -Dfile.encoding=utf-8 " Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). ... ``` Author: minixalpha Closes #19090 from minixalpha/master. --- bin/beeline.cmd | 4 +++- bin/pyspark.cmd | 4 +++- bin/run-example.cmd | 5 ++++- bin/spark-class.cmd | 4 +++- bin/spark-shell.cmd | 4 +++- bin/spark-submit.cmd | 4 +++- bin/sparkR.cmd | 4 +++- 7 files changed, 22 insertions(+), 7 deletions(-) diff --git a/bin/beeline.cmd b/bin/beeline.cmd index 02464bd088792..288059a28cd74 100644 --- a/bin/beeline.cmd +++ b/bin/beeline.cmd @@ -17,4 +17,6 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -cmd /V /E /C "%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %*" diff --git a/bin/pyspark.cmd b/bin/pyspark.cmd index 72d046a4ba2cf..3dcf1d45a8189 100644 --- a/bin/pyspark.cmd +++ b/bin/pyspark.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running PySpark. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0pyspark2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0pyspark2.cmd" %*" diff --git a/bin/run-example.cmd b/bin/run-example.cmd index f9b786e92b823..efa5f81d08f7f 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -19,4 +19,7 @@ rem set SPARK_HOME=%~dp0.. set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] -cmd /V /E /C "%~dp0spark-submit.cmd" run-example %* + +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit.cmd" run-example %*" diff --git a/bin/spark-class.cmd b/bin/spark-class.cmd index 3bf3d20cb57b5..b22536ab6f458 100644 --- a/bin/spark-class.cmd +++ b/bin/spark-class.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running a Spark class. To avoid polluting rem the environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-class2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. 
+cmd /V /E /C ""%~dp0spark-class2.cmd" %*" diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 991423da6ab99..e734f13097d61 100644 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark shell. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-shell2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-shell2.cmd" %*" diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f301606933a95..da62a8777524d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-submit2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit2.cmd" %*" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd index 1e5ea6a623219..fcd172b083e1e 100644 --- a/bin/sparkR.cmd +++ b/bin/sparkR.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running SparkR. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0sparkR2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0sparkR2.cmd" %*" From 08b204fd2c731e87d3bc2cc0bccb6339ef7e3a6e Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 6 Oct 2017 12:53:35 -0700 Subject: [PATCH 023/712] [SPARK-22214][SQL] Refactor the list hive partitions code ## What changes were proposed in this pull request? In this PR we make a few changes to the list hive partitions code, to make the code more extensible. The following changes are made: 1. In `HiveClientImpl.getPartitions()`, call `client.getPartitions` instead of `shim.getAllPartitions` when `spec` is empty; 2. In `HiveTableScanExec`, previously we always call `listPartitionsByFilter` if the config `metastorePartitionPruning` is enabled, but actually, we'd better call `listPartitions` if `partitionPruningPred` is empty; 3. We should use sessionCatalog instead of SharedState.externalCatalog in `HiveTableScanExec`. ## How was this patch tested? Tested by existing test cases since this is code refactor, no regression or behavior change is expected. Author: Xingbo Jiang Closes #19444 from jiangxb1987/hivePartitions. --- .../sql/catalyst/catalog/interface.scala | 5 ++++ .../sql/hive/client/HiveClientImpl.scala | 7 +++-- .../hive/execution/HiveTableScanExec.scala | 28 +++++++++---------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index fe2af910a0ae5..975b084aa6188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -405,6 +405,11 @@ object CatalogTypes { * Specifications of a table partition. Mapping column name to column value. */ type TablePartitionSpec = Map[String, String] + + /** + * Initialize an empty spec. 
+ */ + lazy val emptyTablePartitionSpec: TablePartitionSpec = Map.empty[String, String] } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 66165c7228bca..a01c312d5e497 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -638,12 +638,13 @@ private[hive] class HiveClientImpl( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { val hiveTable = toHiveTable(table, Some(userName)) - val parts = spec match { - case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) + val partSpec = spec match { + case None => CatalogTypes.emptyTablePartitionSpec case Some(s) => assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid") - client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition) + s } + val parts = client.getPartitions(hiveTable, partSpec.asJava).asScala.map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 48d0b4a63e54a..4f8dab9cd6172 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -162,21 +162,19 @@ case class HiveTableScanExec( // exposed for tests @transient lazy val rawPartitions = { - val prunedPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - // Retrieve the original attributes based on expression ID so that capitalization matches. - val normalizedFilters = partitionPruningPred.map(_.transform { - case a: AttributeReference => originalAttributes(a) - }) - sparkSession.sharedState.externalCatalog.listPartitionsByFilter( - relation.tableMeta.database, - relation.tableMeta.identifier.table, - normalizedFilters, - sparkSession.sessionState.conf.sessionLocalTimeZone) - } else { - sparkSession.sharedState.externalCatalog.listPartitions( - relation.tableMeta.database, - relation.tableMeta.identifier.table) - } + val prunedPartitions = + if (sparkSession.sessionState.conf.metastorePartitionPruning && + partitionPruningPred.size > 0) { + // Retrieve the original attributes based on expression ID so that capitalization matches. + val normalizedFilters = partitionPruningPred.map(_.transform { + case a: AttributeReference => originalAttributes(a) + }) + sparkSession.sessionState.catalog.listPartitionsByFilter( + relation.tableMeta.identifier, + normalizedFilters) + } else { + sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier) + } prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) } From debcbec7491d3a23b19ef149e50d2887590b6de0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 6 Oct 2017 13:10:04 -0700 Subject: [PATCH 024/712] [SPARK-21947][SS] Check and report error when monotonically_increasing_id is used in streaming query ## What changes were proposed in this pull request? `monotonically_increasing_id` doesn't work in Structured Streaming. We should throw an exception if a streaming query uses it. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19336 from viirya/SPARK-21947. 
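As an illustration of the user-facing behavior (not part of this patch's diff), here is a minimal sketch of a streaming query that the new check is expected to reject at analysis time. The `rate` source, console sink, and object name are placeholder choices made only for this example:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.monotonically_increasing_id

object MonotonicIdOnStreamSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()

    // Any streaming source works for the illustration; "rate" just emits (timestamp, value) rows.
    val stream = spark.readStream.format("rate").load()

    // With the check in UnsupportedOperationChecker in place, start() should fail with an
    // AnalysisException naming monotonically_increasing_id, instead of running a query whose
    // ids are not meaningful across micro-batches.
    val query = stream
      .withColumn("id", monotonically_increasing_id())
      .writeStream
      .format("console")
      .start()

    query.awaitTermination()
  }
}
```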
--- .../analysis/UnsupportedOperationChecker.scala | 15 ++++++++++++++- .../analysis/UnsupportedOperationsSuite.scala | 10 +++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index dee6fbe9d1514..04502d04d9509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ @@ -129,6 +129,16 @@ object UnsupportedOperationChecker { !subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete) } + def checkUnsupportedExpressions(implicit operator: LogicalPlan): Unit = { + val unsupportedExprs = operator.expressions.flatMap(_.collect { + case m: MonotonicallyIncreasingID => m + }).distinct + if (unsupportedExprs.nonEmpty) { + throwError("Expression(s): " + unsupportedExprs.map(_.sql).mkString(", ") + + " is not supported with streaming DataFrames/Datasets") + } + } + plan.foreachUp { implicit subPlan => // Operations that cannot exists anywhere in a streaming plan @@ -323,6 +333,9 @@ object UnsupportedOperationChecker { case _ => } + + // Check if there are unsupported expressions in streaming query plan. 
+ checkUnsupportedExpressions(subPlan) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index e5057c451d5b8..60d1351fda264 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, MonotonicallyIncreasingID, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} @@ -614,6 +614,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testOutputMode(Update, shouldSupportAggregation = true, shouldSupportNonAggregation = true) testOutputMode(Complete, shouldSupportAggregation = true, shouldSupportNonAggregation = false) + // Unsupported expressions in streaming plan + assertNotSupportedInStreamingPlan( + "MonotonicallyIncreasingID", + streamRelation.select(MonotonicallyIncreasingID()), + outputMode = Append, + expectedMsgs = Seq("monotonically_increasing_id")) + + /* ======================================================================================= TESTING FUNCTIONS From 2030f19511f656e9534f3fd692e622e45f9a074e Mon Sep 17 00:00:00 2001 From: Sergey Zhemzhitsky Date: Fri, 6 Oct 2017 20:43:53 -0700 Subject: [PATCH 025/712] [SPARK-21549][CORE] Respect OutputFormats with no output directory provided ## What changes were proposed in this pull request? Fix for https://issues.apache.org/jira/browse/SPARK-21549 JIRA issue. Since version 2.2 Spark does not respect OutputFormat with no output paths provided. The examples of such formats are [Cassandra OutputFormat](https://github.com/finn-no/cassandra-hadoop/blob/08dfa3a7ac727bb87269f27a1c82ece54e3f67e6/src/main/java/org/apache/cassandra/hadoop2/AbstractColumnFamilyOutputFormat.java), [Aerospike OutputFormat](https://github.com/aerospike/aerospike-hadoop/blob/master/mapreduce/src/main/java/com/aerospike/hadoop/mapreduce/AerospikeOutputFormat.java), etc. which do not have an ability to rollback the results written to an external systems on job failure. Provided output directory is required by Spark to allows files to be committed to an absolute output location, that is not the case for output formats which write data to external systems. This pull request prevents accessing `absPathStagingDir` method that causes the error described in SPARK-21549 unless there are files to rename in `addedAbsPathFiles`. ## How was this patch tested? Unit tests Author: Sergey Zhemzhitsky Closes #19294 from szhem/SPARK-21549-abs-output-commits. 
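To make the failure scenario concrete, here is a self-contained sketch (not part of this patch's diff) of writing a pair RDD through a do-nothing OutputFormat that targets an external system and therefore never configures an output directory. `ExternalSystemOutputFormat` and the object name are invented for illustration; the driver-side call mirrors the new `saveAsNewAPIHadoopDataset` test added to `PairRDDFunctionsSuite` in this patch:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{Job, JobContext, OutputCommitter, OutputFormat, RecordWriter, TaskAttemptContext}
import org.apache.spark.{SparkConf, SparkContext}

// Stand-in for a Cassandra/Aerospike-style OutputFormat: records go to an external store,
// so there is no output directory to create, validate, or commit.
class ExternalSystemOutputFormat extends OutputFormat[Integer, Integer] {
  override def getRecordWriter(ctx: TaskAttemptContext): RecordWriter[Integer, Integer] =
    new RecordWriter[Integer, Integer] {
      override def write(key: Integer, value: Integer): Unit = ()  // pretend to send to the external store
      override def close(ctx: TaskAttemptContext): Unit = ()
    }
  override def checkOutputSpecs(ctx: JobContext): Unit = ()        // nothing path-related to validate
  override def getOutputCommitter(ctx: TaskAttemptContext): OutputCommitter = new OutputCommitter {
    override def setupJob(jobContext: JobContext): Unit = ()
    override def setupTask(taskContext: TaskAttemptContext): Unit = ()
    override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = false
    override def commitTask(taskContext: TaskAttemptContext): Unit = ()
    override def abortTask(taskContext: TaskAttemptContext): Unit = ()
  }
}

object SaveWithoutOutputDirSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))

    val job = Job.getInstance(new Configuration(sc.hadoopConfiguration))
    job.setOutputKeyClass(classOf[Integer])
    job.setOutputValueClass(classOf[Integer])
    job.setOutputFormatClass(classOf[ExternalSystemOutputFormat])
    // Note: no output directory is ever set on the job.

    // On Spark 2.2 committing this job failed on the driver with
    // "Can not create a Path from a null string"; with this patch the absolute-path
    // staging work is skipped because no output path was provided.
    sc.parallelize(Seq((new Integer(1), new Integer(2))), 1)
      .saveAsNewAPIHadoopDataset(job.getConfiguration)

    sc.stop()
  }
}
```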
--- .../io/HadoopMapReduceCommitProtocol.scala | 28 ++++++++++++---- .../spark/rdd/PairRDDFunctionsSuite.scala | 33 ++++++++++++++++++- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index b1d07ab2c9199..a7e6859ef6b64 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -35,6 +35,9 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * (from the newer mapreduce API, not the old mapred API). * * Unlike Hadoop's OutputCommitter, this implementation is serializable. + * + * @param jobId the job's or stage's id + * @param path the job's output path, or null if committer acts as a noop */ class HadoopMapReduceCommitProtocol(jobId: String, path: String) extends FileCommitProtocol with Serializable with Logging { @@ -57,6 +60,15 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) */ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + /** + * Checks whether there are files to be committed to an absolute output location. + * + * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, + * it is necessary to check whether the output path is specified. Output path may not be required + * for committers not writing to distributed file systems. + */ + private def hasAbsPathFiles: Boolean = path != null + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. 
@@ -130,17 +142,21 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) .foldLeft(Map[String, String]())(_ ++ _) logDebug(s"Committing files staged for absolute locations $filesToMove") - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - for ((src, dst) <- filesToMove) { - fs.rename(new Path(src), new Path(dst)) + if (hasAbsPathFiles) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } - fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + if (hasAbsPathFiles) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } } override def setupTask(taskContext: TaskAttemptContext): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 44dd955ce8690..07579c5098014 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -26,7 +26,7 @@ import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistr import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +import org.apache.hadoop.mapreduce.{Job => NewJob, JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.hadoop.util.Progressable @@ -568,6 +568,37 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") } + test("saveAsNewAPIHadoopDataset should respect empty output directory when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) + job.setOutputKeyClass(classOf[Integer]) + job.setOutputValueClass(classOf[Integer]) + job.setOutputFormatClass(classOf[NewFakeFormat]) + val jobConfiguration = job.getConfiguration + + // just test that the job does not fail with + // java.lang.IllegalArgumentException: Can not create a Path from a null string + pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + } + + test("saveAsHadoopDataset should respect empty output directory when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + val conf = new JobConf() + conf.setOutputKeyClass(classOf[Integer]) + conf.setOutputValueClass(classOf[Integer]) + conf.setOutputFormat(classOf[FakeOutputFormat]) + conf.setOutputCommitter(classOf[FakeOutputCommitter]) + + FakeOutputCommitter.ran = false + pairs.saveAsHadoopDataset(conf) + + assert(FakeOutputCommitter.ran, "OutputCommitter was never called") + } + test("lookup") { val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) 
From 5eacc3bfa9b9c1435ce04222ac7f943b5f930cf4 Mon Sep 17 00:00:00 2001 From: Kento NOZAWA Date: Sat, 7 Oct 2017 08:30:48 +0100 Subject: [PATCH 026/712] [SPARK-22156][MLLIB] Fix update equation of learning rate in Word2Vec.scala ## What changes were proposed in this pull request? Current equation of learning rate is incorrect when `numIterations` > `1`. This PR is based on [original C code](https://github.com/tmikolov/word2vec/blob/master/word2vec.c#L393). cc: mengxr ## How was this patch tested? manual tests I modified [this example code](https://spark.apache.org/docs/2.1.1/mllib-feature-extraction.html#example). ### `numIteration=1` #### Code ```scala import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq) val word2vec = new Word2Vec() val model = word2vec.fit(input) val synonyms = model.findSynonyms("1", 5) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } ``` #### Result ``` 2 0.175856813788414 0 0.10971353203058243 4 0.09818313270807266 3 0.012947646901011467 9 -0.09881238639354706 ``` ### `numIteration=5` #### Code ```scala import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq) val word2vec = new Word2Vec() word2vec.setNumIterations(5) val model = word2vec.fit(input) val synonyms = model.findSynonyms("1", 5) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } ``` #### Result ``` 0 0.9898583889007568 2 0.9808019399642944 4 0.9794934391975403 3 0.9506527781486511 9 -0.9065656661987305 ``` Author: Kento NOZAWA Closes #19372 from nzw0301/master. --- .../org/apache/spark/mllib/feature/Word2Vec.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 6f96813497b62..b8c306d86bace 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -353,11 +353,14 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) + val totalWordsCounts = numIterations * trainWordsCount + 1 var alpha = learningRate for (k <- 1 to numIterations) { val bcSyn0Global = sc.broadcast(syn0Global) val bcSyn1Global = sc.broadcast(syn1Global) + val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount + val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) @@ -368,11 +371,12 @@ class Word2Vec extends Serializable with Logging { var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - // TODO: discount by iteration? 
- alpha = - learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + alpha = learningRate * + (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) / + totalWordsCounts) if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 - logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " + + s"alpha = $alpha") } wc += sentence.length var pos = 0 From c998a2ae0ea019dfb9b39cef6ddfac07c496e083 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Sun, 8 Oct 2017 12:58:39 +0100 Subject: [PATCH 027/712] [SPARK-22147][CORE] Removed redundant allocations from BlockId ## What changes were proposed in this pull request? Prior to this commit BlockId.hashCode and BlockId.equals were defined in terms of BlockId.name. This allowed the subclasses to be concise and enforced BlockId.name as a single unique identifier for a block. All subclasses override BlockId.name with an expression involving an allocation of StringBuilder and ultimatelly String. This is suboptimal since it induced unnecessary GC pressure on the dirver, see BlockManagerMasterEndpoint. The commit removes the definition of hashCode and equals from the base class. No other change is necessary since all subclasses are in fact case classes and therefore have auto-generated hashCode and equals. No change of behaviour is expected. Sidenote: you might be wondering, why did the subclasses use the base implementation and the auto-generated one? Apparently, this behaviour is documented in the spec. See this SO answer for details https://stackoverflow.com/a/44990210/262432. ## How was this patch tested? BlockIdSuite Author: Sergei Lebedev Closes #19369 from superbobry/blockid-equals-hashcode. --- .../netty/NettyBlockTransferService.scala | 2 +- .../org/apache/spark/storage/BlockId.scala | 5 -- .../org/apache/spark/storage/DiskStore.scala | 8 +-- .../BlockManagerReplicationSuite.scala | 49 ------------------- 4 files changed, 5 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ac4d85004bad1..6a29e18bf3cbb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -151,7 +151,7 @@ private[spark] class NettyBlockTransferService( // Convert or copy nio buffer into array in order to serialize it. 
val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer, + client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, new RpcResponseCallback { override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 524f6970992a5..a441baed2800e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -41,11 +41,6 @@ sealed abstract class BlockId { def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] override def toString: String = name - override def hashCode: Int = name.hashCode - override def equals(other: Any): Boolean = other match { - case o: BlockId => getClass == o.getClass && name.equals(o.name) - case _ => false - } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 3579acf8d83d9..97abd92d4b70f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -47,9 +47,9 @@ private[spark] class DiskStore( private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString) - private val blockSizes = new ConcurrentHashMap[String, Long]() + private val blockSizes = new ConcurrentHashMap[BlockId, Long]() - def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) + def getSize(blockId: BlockId): Long = blockSizes.get(blockId) /** * Invokes the provided callback function to write the specific block. 
@@ -67,7 +67,7 @@ private[spark] class DiskStore( var threwException: Boolean = true try { writeFunc(out) - blockSizes.put(blockId.name, out.getCount) + blockSizes.put(blockId, out.getCount) threwException = false } finally { try { @@ -113,7 +113,7 @@ private[spark] class DiskStore( } def remove(blockId: BlockId): Boolean = { - blockSizes.remove(blockId.name) + blockSizes.remove(blockId) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index dd61dcd11bcda..c2101ba828553 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -198,55 +198,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } } - test("block replication - deterministic node selection") { - val blockSize = 1000 - val storeSize = 10000 - val stores = (1 to 5).map { - i => makeBlockManager(storeSize, s"store$i") - } - val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2 - val storageLevel3x = StorageLevel(true, true, false, true, 3) - val storageLevel4x = StorageLevel(true, true, false, true, 4) - - def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { - stores.head.putSingle(blockId, new Array[Byte](blockSize), level) - val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet - stores.foreach { _.removeBlock(blockId) } - master.removeBlock(blockId) - locations - } - - // Test if two attempts to 2x replication returns same set of locations - val a1Locs = putBlockAndGetLocations("a1", storageLevel2x) - assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs, - "Inserting a 2x replicated block second time gave different locations from the first") - - // Test if two attempts to 3x replication returns same set of locations - val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x) - assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x, - "Inserting a 3x replicated block second time gave different locations from the first") - - // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication - val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x) - assert( - a2Locs2x.subsetOf(a2Locs3x), - "Inserting a with 2x replication gave locations that are not a subset of locations" + - s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}" - ) - - // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication - val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x) - assert( - a2Locs3x.subsetOf(a2Locs4x), - "Inserting a with 4x replication gave locations that are not a superset of locations " + - s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}" - ) - - // Test if 3x replication of two different blocks gives two different sets of locations - val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x) - assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication") - } - test("block replication - replication failures") { /* Create a system of three block managers / stores. 
One of them (say, failableStore) From fe7b219ae3e8a045655a836cbb77219036ec5740 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 9 Oct 2017 14:16:25 +0800 Subject: [PATCH 028/712] [SPARK-22074][CORE] Task killed by other attempt task should not be resubmitted ## What changes were proposed in this pull request? As the detail scenario described in [SPARK-22074](https://issues.apache.org/jira/browse/SPARK-22074), unnecessary resubmitted may cause stage hanging in currently release versions. This patch add a new var in TaskInfo to mark this task killed by other attempt or not. ## How was this patch tested? Add a new UT `[SPARK-22074] Task killed by other attempt task should not be resubmitted` in TaskSetManagerSuite, this UT recreate the scenario in JIRA description, it failed without the changes in this PR and passed conversely. Author: Yuanjian Li Closes #19287 from xuanyuanking/SPARK-22074. --- .../spark/scheduler/TaskSetManager.scala | 8 +- .../org/apache/spark/scheduler/FakeTask.scala | 20 +++- .../spark/scheduler/TaskSetManagerSuite.scala | 107 ++++++++++++++++++ 3 files changed, 132 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3bdede6743d1b..de4711f461df2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -83,6 +83,11 @@ private[spark] class TaskSetManager( val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) + // Set the coresponding index of Boolean var when the task killed by other attempt tasks, + // this happened while we set the `spark.speculation` to true. The task killed by others + // should not resubmit while executor lost. 
+ private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -729,6 +734,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") + killedByOtherAttempt(index) = true sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -915,7 +921,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index)) { + if (successful(index) && !killedByOtherAttempt(index)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index fe6de2bd98850..109d4a0a870b8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,8 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.SparkEnv -import org.apache.spark.TaskContext +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.executor.TaskMetrics class FakeTask( @@ -58,4 +57,21 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createShuffleMapTaskSet( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new ShuffleMapTask(stageId, stageAttemptId, null, new Partition { + override def index: Int = i + }, prefLocs(i), new Properties, + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) + } + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 5c712bd6a545b..2ce81ae27daf6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -744,6 +744,113 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(resubmittedTasks === 0) } + + test("[SPARK-22074] Task killed by other attempt task should not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + var killTaskCalled = false + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Check the only one killTask event in this case, which triggered by + // task 2.1 completed. 
+ assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + killTaskCalled = true + } + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. + var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host3", "exec3")), + Seq(TaskLocation("host2", "exec2"))) + + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((exec, host) <- Seq( + "exec1" -> "host1", + "exec1" -> "host1", + "exec3" -> "host3", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(exec, host, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === exec) + // Add an extra assert to make sure task 2.0 is running on exec3 + if (task.index == 2) { + assert(task.attemptNumber === 0) + assert(task.executorId === "exec3") + } + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete the 2 tasks and leave 2 task in running + for (id <- Set(0, 1)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(2, 3)) + + // Offer resource to start the speculative attempt for the running task 2.0 + val taskOption = manager.resourceOffer("exec2", "host2", ANY) + assert(taskOption.isDefined) + val task4 = taskOption.get + assert(task4.index === 2) + assert(task4.taskId === 4) + assert(task4.executorId === "exec2") + assert(task4.attemptNumber === 1) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) + // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called + assert(killTaskCalled) + // Host 3 Losts, there's only task 2.0 on it, which killed by task 2.1 + manager.executorLost("exec3", "host3", SlaveLost()) + // Check the resubmittedTasks + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( From 98057583dd2787c0e396c2658c7dd76412f86936 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Mon, 9 Oct 2017 10:42:33 +0200 Subject: [PATCH 029/712] [SPARK-20679][ML] Support recommending for a subset of users/items in ALSModel This PR adds methods `recommendForUserSubset` and `recommendForItemSubset` to `ALSModel`. 
These allow recommending for a specified set of user / item ids rather than for every user / item (as in the `recommendForAllX` methods). The subset methods take a `DataFrame` as input, containing ids in the column specified by the param `userCol` or `itemCol`. The model will generate recommendations for each _unique_ id in this input dataframe. ## How was this patch tested? New unit tests in `ALSSuite` and Python doctests in `ALS`. Ran updated examples locally. Author: Nick Pentreath Closes #18748 from MLnick/als-recommend-df. --- .../spark/examples/ml/JavaALSExample.java | 9 ++ examples/src/main/python/ml/als_example.py | 9 ++ .../apache/spark/examples/ml/ALSExample.scala | 9 ++ .../apache/spark/ml/recommendation/ALS.scala | 48 +++++++++ .../spark/ml/recommendation/ALSSuite.scala | 100 ++++++++++++++++-- python/pyspark/ml/recommendation.py | 38 +++++++ 6 files changed, 205 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index fe4d6bc83f04a..27052be87b82e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -118,9 +118,18 @@ public static void main(String[] args) { Dataset userRecs = model.recommendForAllUsers(10); // Generate top 10 user recommendations for each movie Dataset movieRecs = model.recommendForAllItems(10); + + // Generate top 10 movie recommendations for a specified set of users + Dataset users = ratings.select(als.getUserCol()).distinct().limit(3); + Dataset userSubsetRecs = model.recommendForUserSubset(users, 10); + // Generate top 10 user recommendations for a specified set of movies + Dataset movies = ratings.select(als.getItemCol()).distinct().limit(3); + Dataset movieSubSetRecs = model.recommendForItemSubset(movies, 10); // $example off$ userRecs.show(); movieRecs.show(); + userSubsetRecs.show(); + movieSubSetRecs.show(); spark.stop(); } diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 1672d552eb1d5..8b7ec9c439f9f 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -60,8 +60,17 @@ userRecs = model.recommendForAllUsers(10) # Generate top 10 user recommendations for each movie movieRecs = model.recommendForAllItems(10) + + # Generate top 10 movie recommendations for a specified set of users + users = ratings.select(als.getUserCol()).distinct().limit(3) + userSubsetRecs = model.recommendForUserSubset(users, 10) + # Generate top 10 user recommendations for a specified set of movies + movies = ratings.select(als.getItemCol()).distinct().limit(3) + movieSubSetRecs = model.recommendForItemSubset(movies, 10) # $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index 07b15dfa178f7..8091838a2301e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -80,9 +80,18 @@ object ALSExample { val userRecs = model.recommendForAllUsers(10) // Generate top 10 user recommendations for each movie val movieRecs = model.recommendForAllItems(10) + + // Generate top 10 movie recommendations for a specified set of users 
+ val users = ratings.select(als.getUserCol).distinct().limit(3) + val userSubsetRecs = model.recommendForUserSubset(users, 10) + // Generate top 10 user recommendations for a specified set of movies + val movies = ratings.select(als.getItemCol).distinct().limit(3) + val movieSubSetRecs = model.recommendForItemSubset(movies, 10) // $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 3d5fd1794de23..a8843661c873b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -344,6 +344,21 @@ class ALSModel private[ml] ( recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) } + /** + * Returns top `numItems` items recommended for each user id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of user ids. The column name must match `userCol`. + * @param numItems max number of recommendations for each user. + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, userFactors, $(userCol)) + recommendForAll(srcFactorSubset, itemFactors, $(userCol), $(itemCol), numItems) + } + /** * Returns top `numUsers` users recommended for each item, for all items. * @param numUsers max number of recommendations for each item @@ -355,6 +370,39 @@ class ALSModel private[ml] ( recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) } + /** + * Returns top `numUsers` users recommended for each item id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of item ids. The column name must match `itemCol`. + * @param numUsers max number of recommendations for each item. + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, $(itemCol)) + recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Returns a subset of a factor DataFrame limited to only those unique ids contained + * in the input dataset. + * @param dataset input Dataset containing id column to user to filter factors. + * @param factors factor DataFrame to filter. + * @param column column name containing the ids in the input dataset. + * @return DataFrame containing factors only for those ids present in both the input dataset and + * the factor DataFrame. + */ + private def getSourceFactorSubset( + dataset: Dataset[_], + factors: DataFrame, + column: String): DataFrame = { + factors + .join(dataset.select(column), factors("id") === dataset(column), joinType = "left_semi") + .select(factors("id"), factors("features")) + } + /** * Makes recommendations for all users (or items). 
* diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ac7319110159b..addcd21d50aac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -723,9 +723,9 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), - 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), - 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Seq((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) Seq(2, 4, 6).foreach { k => @@ -743,10 +743,10 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 3 -> Array((0, 54f), (2, 51f), (1, 39f)), - 4 -> Array((0, 44f), (2, 30f), (1, 26f)), - 5 -> Array((2, 45f), (0, 42f), (1, 33f)), - 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 4 -> Seq((0, 44f), (2, 30f), (1, 26f)), + 5 -> Seq((2, 45f), (0, 42f), (1, 33f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) ) Seq(2, 3, 4).foreach { k => @@ -759,9 +759,93 @@ class ALSSuite } } + test("recommendForUserSubset with k <, = and > num_items") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numItems = model.itemFactors.count + val expected = Map( + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + val userSubset = expected.keys.toSeq.toDF("user") + val numUsersSubset = userSubset.count + + Seq(2, 4, 6).foreach { k => + val n = math.min(k, numItems).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topItems = model.recommendForUserSubset(userSubset, k) + assert(topItems.count() == numUsersSubset) + assert(topItems.columns.contains("user")) + checkRecommendations(topItems, expectedUpToN, "item") + } + } + + test("recommendForItemSubset with k <, = and > num_users") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numUsers = model.userFactors.count + val expected = Map( + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) + ) + val itemSubset = expected.keys.toSeq.toDF("item") + val numItemsSubset = itemSubset.count + + Seq(2, 3, 4).foreach { k => + val n = math.min(k, numUsers).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topUsers = model.recommendForItemSubset(itemSubset, k) + assert(topUsers.count() == numItemsSubset) + assert(topUsers.columns.contains("item")) + checkRecommendations(topUsers, expectedUpToN, "user") + } + } + + test("subset recommendations eliminate duplicate ids, returns same results as unique ids") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val users = Seq(0, 1).toDF("user") + val dupUsers = Seq(0, 1, 0, 1).toDF("user") + val singleUserRecs = model.recommendForUserSubset(users, k) + val dupUserRecs = model.recommendForUserSubset(dupUsers, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleUserRecs.count == dupUserRecs.size) + checkRecommendations(singleUserRecs, dupUserRecs, "item") + + val items = Seq(3, 4, 5).toDF("item") + val dupItems = Seq(3, 4, 5, 4, 5).toDF("item") + val singleItemRecs = model.recommendForItemSubset(items, 
k) + val dupItemRecs = model.recommendForItemSubset(dupItems, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleItemRecs.count == dupItemRecs.size) + checkRecommendations(singleItemRecs, dupItemRecs, "user") + } + + test("subset recommendations on full input dataset equivalent to recommendForAll") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val userSubset = model.userFactors.withColumnRenamed("id", "user").drop("features") + val userSubsetRecs = model.recommendForUserSubset(userSubset, k) + val allUserRecs = model.recommendForAllUsers(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(userSubsetRecs, allUserRecs, "item") + + val itemSubset = model.itemFactors.withColumnRenamed("id", "item").drop("features") + val itemSubsetRecs = model.recommendForItemSubset(itemSubset, k) + val allItemRecs = model.recommendForAllItems(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(itemSubsetRecs, allItemRecs, "user") + } + private def checkRecommendations( topK: DataFrame, - expected: Map[Int, Array[(Int, Float)]], + expected: Map[Int, Seq[(Int, Float)]], dstColName: String): Unit = { val spark = this.spark import spark.implicits._ diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index bcfb36880eb02..e8bcbe4cd34cb 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -90,6 +90,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> item_recs.where(item_recs.item == 2)\ .select("recommendations.user", "recommendations.rating").collect() [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] + >>> user_subset = df.where(df.user == 2) + >>> user_subset_recs = model.recommendForUserSubset(user_subset, 3) + >>> user_subset_recs.select("recommendations.item", "recommendations.rating").first() + Row(item=[2, 1, 0], rating=[4.901..., 1.056..., -1.501...]) + >>> item_subset = df.where(df.item == 0) + >>> item_subset_recs = model.recommendForItemSubset(item_subset, 3) + >>> item_subset_recs.select("recommendations.user", "recommendations.rating").first() + Row(user=[0, 1, 2], rating=[3.910..., 2.625..., -1.501...]) >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -414,6 +422,36 @@ def recommendForAllItems(self, numUsers): """ return self._call_java("recommendForAllItems", numUsers) + @since("2.3.0") + def recommendForUserSubset(self, dataset, numItems): + """ + Returns top `numItems` items recommended for each user id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of user ids. The column name must match + `userCol`. + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForUserSubset", dataset, numItems) + + @since("2.3.0") + def recommendForItemSubset(self, dataset, numUsers): + """ + Returns top `numUsers` users recommended for each item id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of item ids. The column name must match + `itemCol`. 
+ :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForItemSubset", dataset, numUsers) + if __name__ == "__main__": import doctest From f31e11404d6d5ee28b574c242ecbee94f35e9370 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 9 Oct 2017 12:53:10 -0700 Subject: [PATCH 030/712] [SPARK-21568][CORE] ConsoleProgressBar should only be enabled in shells ## What changes were proposed in this pull request? This PR disables console progress bar feature in non-shell environment by overriding the configuration. ## How was this patch tested? Manual. Run the following examples with and without `spark.ui.showConsoleProgress` in order to see progress bar on master branch and this PR. **Scala Shell** ```scala spark.range(1000000000).map(_ + 1).count ``` **PySpark** ```python spark.range(10000000).rdd.map(lambda x: len(x)).count() ``` **Spark Submit** ```python from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession.builder.getOrCreate() spark.range(2000000).rdd.map(lambda row: len(row)).count() spark.stop() ``` Author: Dongjoon Hyun Closes #19061 from dongjoon-hyun/SPARK-21568. --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++++ .../org/apache/spark/internal/config/package.scala | 5 +++++ .../org/apache/spark/deploy/SparkSubmitSuite.scala | 12 ++++++++++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cec61d85ccf38..b3cd03c0cfbe1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -434,7 +434,7 @@ class SparkContext(config: SparkConf) extends Logging { _statusTracker = new SparkStatusTracker(this) _progressBar = - if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + if (_conf.get(UI_SHOW_CONSOLE_PROGRESS) && !log.isInfoEnabled) { Some(new ConsoleProgressBar(this)) } else { None diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 286a4379d2040..135bbe93bf28e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -598,6 +598,11 @@ object SparkSubmit extends CommandLineUtils with Logging { } } + // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. 
+ if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { + sysProps(UI_SHOW_CONSOLE_PROGRESS.key) = "true" + } + // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python and R files, the primary resource is already distributed as a regular file diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d85b6a0200b8d..5278e5e0fb270 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -203,6 +203,11 @@ package object config { private[spark] val HISTORY_UI_MAX_APPS = ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE) + private[spark] val UI_SHOW_CONSOLE_PROGRESS = ConfigBuilder("spark.ui.showConsoleProgress") + .doc("When true, show the progress bar in the console.") + .booleanConf + .createWithDefault(false) + private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") .booleanConf .createWithDefault(false) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index ad801bf8519a6..b06f2e26a4a7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -399,6 +399,18 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.yarn.Client") } + test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { + val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + sysProps1(UI_SHOW_CONSOLE_PROGRESS.key) should be ("true") + + val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") + val appArgs2 = new SparkSubmitArguments(clArgs2) + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + sysProps2.keys should not contain UI_SHOW_CONSOLE_PROGRESS.key + } + test("launch simple application with spark-submit") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( From a74ec6d7bbfe185ba995dcb02d69e90a089c293e Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Mon, 9 Oct 2017 12:56:37 -0700 Subject: [PATCH 031/712] [SPARK-22218] spark shuffle service fails to update secret on app re-attempts This patch fixes application re-attempts when running Spark on YARN using the external shuffle service with security on. Currently executors will fail to launch on any application re-attempt when launched on a nodemanager that had an executor from the first attempt. The reason for this is that we aren't updating the secret key after the first application attempt. The fix here is to just remove the containsKey check to see if it already exists. In this way, we always add it and make sure it's the most recent secret. Similarly, remove the containsKey check on the remove path, since it's just an extra check that isn't really needed. Note this worked before Spark 2.2 because the check used to be contains (which was looking for the value) rather than containsKey, so that never matched and it was just always adding the new secret. The patch was tested on a 10-node cluster, and a unit test was added.
The test run was a wordcount job where the output directory already existed. With the bug present, the application attempt failed with the max number of executor failures, which were all SaslExceptions. With the fix present, the application re-attempts fail because the directory already exists, or, if you remove the directory between attempts, the re-attempts succeed. Author: Thomas Graves Closes #19450 from tgravescs/SPARK-22218. --- .../network/sasl/ShuffleSecretManager.java | 19 +++---- .../sasl/ShuffleSecretManagerSuite.java | 55 +++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index d2d008f8a3d35..7253101f41df6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -47,12 +47,11 @@ public ShuffleSecretManager() { * fetching shuffle files written by other executors in this application. */ public void registerApp(String appId, String shuffleSecret) { - if (!shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.put(appId, shuffleSecret); - logger.info("Registered shuffle secret for application {}", appId); - } else { - logger.debug("Application {} already registered", appId); - } + // Always put the new secret information to make sure it's the most up to date. + // Otherwise we have to specifically look at the application attempt in addition + // to the applicationId since the secrets change between application attempts on yarn. + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); } /** @@ -67,12 +66,8 @@ public void registerApp(String appId, ByteBuffer shuffleSecret) { * This is called when the application terminates. */ public void unregisterApp(String appId) { - if (shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.remove(appId); - logger.info("Unregistered shuffle secret for application {}", appId); - } else { - logger.warn("Attempted to unregister application {} when it is not registered", appId); - } + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); } /** diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java new file mode 100644 index 0000000000000..46c4c33865eea --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.nio.ByteBuffer; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class ShuffleSecretManagerSuite { + static String app1 = "app1"; + static String app2 = "app2"; + static String pw1 = "password1"; + static String pw2 = "password2"; + static String pw1update = "password1update"; + static String pw2update = "password2update"; + + @Test + public void testMultipleRegisters() { + ShuffleSecretManager secretManager = new ShuffleSecretManager(); + secretManager.registerApp(app1, pw1); + assertEquals(pw1, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2.getBytes())); + assertEquals(pw2, secretManager.getSecretKey(app2)); + + // now update the password for the apps and make sure it takes affect + secretManager.registerApp(app1, pw1update); + assertEquals(pw1update, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2update.getBytes())); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app1); + assertNull(secretManager.getSecretKey(app1)); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app2); + assertNull(secretManager.getSecretKey(app2)); + assertNull(secretManager.getSecretKey(app1)); + } +} From b650ee0265477ada68220cbf286fa79906608ef5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Oct 2017 13:55:55 -0700 Subject: [PATCH 032/712] [INFRA] Close stale PRs. Closes #19423 Closes #19455 From dadd13f365aad0d9228cd8b8e6d57ad32175b155 Mon Sep 17 00:00:00 2001 From: Pavel Sakun Date: Mon, 9 Oct 2017 23:00:04 +0100 Subject: [PATCH 033/712] [SPARK] Misleading error message for missing --proxy-user value Fix misleading error message when argument is expected. ## What changes were proposed in this pull request? Change message to be accurate. ## How was this patch tested? Messaging change, was tested manually. Author: Pavel Sakun Closes #19457 from pavel-sakun/patch-1. --- .../src/main/java/org/apache/spark/launcher/SparkLauncher.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 718a368a8e731..75b8ef5ca5ef4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -625,7 +625,7 @@ private static class ArgumentValidator extends SparkSubmitOptionParser { @Override protected boolean handle(String opt, String value) { if (value == null && hasValue) { - throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + throw new IllegalArgumentException(String.format("'%s' expects a value.", opt)); } return true; } From 155ab6347ec7be06c937372a51e8013fdd371d93 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 9 Oct 2017 15:22:41 -0700 Subject: [PATCH 034/712] [SPARK-22170][SQL] Reduce memory consumption in broadcast joins. 
## What changes were proposed in this pull request? This updates the broadcast join code path to lazily decompress pages and iterate through UnsafeRows to prevent all rows from being held in memory while the broadcast table is being built. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #19394 from rdblue/broadcast-driver-memory. --- .../plans/physical/broadcastMode.scala | 6 ++++ .../spark/sql/execution/SparkPlan.scala | 19 ++++++++---- .../exchange/BroadcastExchangeExec.scala | 29 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 13 ++++++++- .../spark/sql/ConfigBehaviorSuite.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 3 +- 6 files changed, 54 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 2ab46dc8330aa..9fac95aed8f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any + def transform(rows: Iterator[InternalRow], sizeHint: Option[Long]): Any + def canonicalized: BroadcastMode } @@ -36,5 +38,9 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): Array[InternalRow] = rows.toArray + override def canonicalized: BroadcastMode = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b263f100e6068..2ffd948f984bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -223,7 +223,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also * compressed. 
*/ - private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + private def getByteArrayRdd(n: Int = -1): RDD[(Long, Array[Byte])] = { execute().mapPartitionsInternal { iter => var count = 0 val buffer = new Array[Byte](4 << 10) // 4K @@ -239,7 +239,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ out.writeInt(-1) out.flush() out.close() - Iterator(bos.toByteArray) + Iterator((count, bos.toByteArray)) } } @@ -274,19 +274,26 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val byteArrayRdd = getByteArrayRdd() val results = ArrayBuffer[InternalRow]() - byteArrayRdd.collect().foreach { bytes => - decodeUnsafeRows(bytes).foreach(results.+=) + byteArrayRdd.collect().foreach { countAndBytes => + decodeUnsafeRows(countAndBytes._2).foreach(results.+=) } results.toArray } + private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = { + val countsAndBytes = getByteArrayRdd().collect() + val total = countsAndBytes.map(_._1).sum + val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2)) + (total, rows) + } + /** * Runs this query returning the result as an iterator of InternalRow. * * @note Triggers multiple jobs (one for each partition). */ def executeToIterator(): Iterator[InternalRow] = { - getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows) + getByteArrayRdd().map(_._2).toLocalIterator.flatMap(decodeUnsafeRows) } /** @@ -307,7 +314,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = getByteArrayRdd(n) + val childRDD = getByteArrayRdd(n).map(_._2) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9c859e41f8762..880e18c6808b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.ThreadUtils @@ -72,26 +72,39 @@ case class BroadcastExchangeExec( SQLExecution.withExecutionId(sparkContext, executionId) { try { val beforeCollect = System.nanoTime() - // Note that we use .executeCollect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = child.executeCollect() - if (input.length >= 512000000) { + // Use executeCollect/executeCollectIterator to avoid conversion to Scala types + val (numRows, input) = child.executeCollectIterator() + if (numRows >= 512000000) { throw new SparkException( - s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows") + s"Cannot broadcast the table with more than 512 millions rows: $numRows rows") } + val beforeBuild = System.nanoTime() longMetric("collectTime") 
+= (beforeBuild - beforeCollect) / 1000000 - val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + + // Construct the relation. + val relation = mode.transform(input, Some(numRows)) + + val dataSize = relation match { + case map: HashedRelation => + map.estimatedSize + case arr: Array[InternalRow] => + arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + case _ => + throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " + + relation.getClass.getName) + } + longMetric("dataSize") += dataSize if (dataSize >= (8L << 30)) { throw new SparkException( s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") } - // Construct and broadcast the relation. - val relation = mode.transform(input) val beforeBroadcast = System.nanoTime() longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 + // Broadcast the relation val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index f8058b2f7813b..b2dcbe5aa9877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -866,7 +866,18 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalized.key, rows.length) + transform(rows.iterator, Some(rows.length)) + } + + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): HashedRelation = { + sizeHint match { + case Some(numRows) => + HashedRelation(rows, canonicalized.key, numRows.toInt) + case None => + HashedRelation(rows, canonicalized.key) + } } override lazy val canonicalized: HashedRelationBroadcastMode = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 2c1e5db5fd9bb..cee85ec8af04d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -58,7 +58,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") { // If we only sample one point, the range boundaries will be pretty bad and the // chi-sq value would be very high. 
- assert(computeChiSquareTest() > 1000) + assert(computeChiSquareTest() > 300) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 0dc612ef735fa..58a194b8af62b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -227,8 +227,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L, - "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) + "number of output rows" -> 2L)))) ) } From 71c2b81aa0e0db70013821f5512df1fbd8e59445 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 9 Oct 2017 16:34:39 -0700 Subject: [PATCH 035/712] [SPARK-22230] Swap per-row order in state store restore. ## What changes were proposed in this pull request? In state store restore, for each row, put the saved state before the row in the iterator instead of after. This fixes an issue where agg(last('attr)) will forever return the last value of 'attr from the first microbatch. ## How was this patch tested? new unit test Author: Jose Torres Closes #19461 from joseph-torres/SPARK-22230. --- .../execution/streaming/statefulOperators.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index fb960fbdde8b3..0d85542928ee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -225,7 +225,7 @@ case class StateStoreRestoreExec( val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 - row +: Option(savedState).toSeq + Option(savedState).toSeq :+ row } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 995cea3b37d4f..fe7efa69f7e31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -520,6 +520,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } + test("SPARK-22230: last should change with new batches") { + val input = MemoryStream[Int] + + val aggregated = input.toDF().agg(last('value)) + testStream(aggregated, OutputMode.Complete())( + AddData(input, 1, 2, 3), + CheckLastBatch(3), + AddData(input, 4, 5, 6), + CheckLastBatch(6), + AddData(input), + CheckLastBatch(6), + AddData(input, 0), + CheckLastBatch(0) + ) + } + /** Add blocks of data to the `BlockRDDBackedSource`. 
*/ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From bebd2e1ce10a460555f75cda75df33f39a783469 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 9 Oct 2017 21:34:37 -0700 Subject: [PATCH 036/712] [SPARK-22222][CORE] Fix the ARRAY_MAX in BufferHolder and add a test ## What changes were proposed in this pull request? We should not break the assumption that the length of the allocated byte array is word rounded: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java#L170 So we want to use `Integer.MAX_VALUE - 15` instead of `Integer.MAX_VALUE - 8` as the upper bound of an allocated byte array. cc: srowen gatorsmile ## How was this patch tested? Since the Spark unit test JVM has less than 1GB of heap, here we run the test code as a submit job, so it can run on a JVM that has 4GB of memory. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #19460 from liufengdb/fix_array_max. --- .../spark/unsafe/array/ByteArrayMethods.java | 7 ++ .../unsafe/map/HashMapGrowthStrategy.java | 6 +- .../collection/PartitionedPairBuffer.scala | 6 +- .../spark/deploy/SparkSubmitSuite.scala | 52 +++++++------ .../expressions/codegen/BufferHolder.java | 7 +- .../BufferHolderSparkSubmitSutie.scala | 78 +++++++++++++++++++ .../vectorized/WritableColumnVector.java | 3 +- 7 files changed, 124 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 9c551ab19e9aa..f121b1cd745b8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -40,6 +40,13 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } } + // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat smaller. + // Be conservative and lower the cap a little. + // Refer to "http://hg.openjdk.java.net/jdk8/jdk8/jdk/file/tip/src/share/classes/java/util/ArrayList.java#l229" + // This value is word rounded. Use this value if the allocated byte arrays are used to store other + // types rather than bytes. + public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; + private static final boolean unaligned = Platform.unaligned(); /** * Optimized byte array equality check for byte arrays. diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index b8c2294c7b7ab..ee6d9f75ac5aa 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -17,6 +17,8 @@ package org.apache.spark.unsafe.map; +import org.apache.spark.unsafe.array.ByteArrayMethods; + /** * Interface that defines how we can grow the size of a hash map when it is over a threshold. */ @@ -31,9 +33,7 @@ public interface HashMapGrowthStrategy { class Doubling implements HashMapGrowthStrategy { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller.
Be conservative and lower the cap a little. - private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; @Override public int nextCapacity(int currentCapacity) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index b755e5da51684..e17a9de97e335 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,6 +19,8 @@ package org.apache.spark.util.collection import java.util.Comparator +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** @@ -96,7 +98,5 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) } private object PartitionedPairBuffer { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - val MAXIMUM_CAPACITY: Int = (Int.MaxValue - 8) / 2 + val MAXIMUM_CAPACITY: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 2 } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b06f2e26a4a7a..b52da4c0c8bc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -100,6 +100,8 @@ class SparkSubmitSuite with TimeLimits with TestPrematureExit { + import SparkSubmitSuite._ + override def beforeEach() { super.beforeEach() System.setProperty("spark.testing", "true") @@ -974,30 +976,6 @@ class SparkSubmitSuite } } - // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - private def runSparkSubmit(args: Seq[String]): Unit = { - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val sparkSubmitFile = if (Utils.isWindows) { - new File("..\\bin\\spark-submit.cmd") - } else { - new File("../bin/spark-submit") - } - val process = Utils.executeCommand( - Seq(sparkSubmitFile.getCanonicalPath) ++ args, - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - - try { - val exitCode = failAfter(60 seconds) { process.waitFor() } - if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") - } - } finally { - // Ensure we still kill the process in case it timed out - process.destroy() - } - } - private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() @@ -1020,6 +998,32 @@ class SparkSubmitSuite } } +object SparkSubmitSuite extends SparkFunSuite with TimeLimits { + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
+ def runSparkSubmit(args: Seq[String], root: String = ".."): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val sparkSubmitFile = if (Utils.isWindows) { + new File(s"$root\\bin\\spark-submit.cmd") + } else { + new File(s"$root/bin/spark-submit") + } + val process = Utils.executeCommand( + Seq(sparkSubmitFile.getCanonicalPath) ++ args, + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + + try { + val exitCode = failAfter(60 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + object JarCreationTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 971d19973f067..259976118c12f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -19,6 +19,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; /** * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and @@ -36,9 +37,7 @@ */ public class BufferHolder { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; @@ -51,7 +50,7 @@ public BufferHolder(UnsafeRow row) { public BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); - if (row.numFields() > (Integer.MAX_VALUE - initialSize - bitsetWidthInBytes) / 8) { + if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( "Cannot create BufferHolder for input UnsafeRow because there are " + "too many fields (number of fields: " + row.numFields() + ")"); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala new file mode 100644 index 0000000000000..1167d2f3f3891 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.concurrent.Timeouts + +import org.apache.spark.{SparkFunSuite, TestUtils} +import org.apache.spark.deploy.SparkSubmitSuite +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.ResetSystemProperties + +// A test for growing the buffer holder to nearly 2GB. Due to the heap size limitation of the Spark +// unit tests JVM, the actually test code is running as a submit job. +class BufferHolderSparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts { + + test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + + val argsForSparkSubmit = Seq( + "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), + "--name", "SPARK-22222", + "--master", "local-cluster[2,1,1024]", + "--driver-memory", "4g", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.driver.extraJavaOptions=-ea", + unusedJar.toString) + SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..") + } +} + +object BufferHolderSparkSubmitSuite { + + def main(args: Array[String]): Unit = { + + val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val holder = new BufferHolder(new UnsafeRow(1000)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2 + 8)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE / 2)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE)) + } + + private def roundToWord(len: Int): Int = { + ByteArrayMethods.roundNumberOfBytesToNearestWord(len) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index da72954ddc448..d3a14b9d8bd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; /** @@ -595,7 +596,7 @@ public final int appendStruct(boolean isNull) { * Upper limit for the maximum capacity for this column. */ @VisibleForTesting - protected int MAX_CAPACITY = Integer.MAX_VALUE - 8; + protected int MAX_CAPACITY = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. 
From af8a34c787dc3d68f5148a7d9975b52650bb7729 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 9 Oct 2017 22:35:34 -0700 Subject: [PATCH 037/712] [SPARK-22159][SQL][FOLLOW-UP] Make config names consistently end with "enabled". ## What changes were proposed in this pull request? This is a follow-up of #19384. In the previous pr, only definitions of the config names were modified, but we also need to modify the names in runtime or tests specified as string literal. ## How was this patch tested? Existing tests but modified the config names. Author: Takuya UESHIN Closes #19462 from ueshin/issues/SPARK-22159/fup1. --- python/pyspark/sql/dataframe.py | 4 ++-- python/pyspark/sql/tests.py | 6 +++--- .../aggregate/HashAggregateExec.scala | 2 +- .../spark/sql/AggregateHashMapSuite.scala | 12 +++++------ .../benchmark/AggregateBenchmark.scala | 20 +++++++++---------- .../execution/AggregationQuerySuite.scala | 2 +- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b7ce9a83a616d..fe69e588fe098 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1878,7 +1878,7 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: import pyarrow tables = self._collectAsArrow() @@ -1889,7 +1889,7 @@ def toPandas(self): return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enable=true" + "if using spark.sql.execution.arrow.enabled=true" raise ImportError("%s\n%s" % (e.message, msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1b3af42c47ad2..a59378b5e848a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3088,7 +3088,7 @@ class ArrowTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) - cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3120,9 +3120,9 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f424096b330e3..8b573fdcf25e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -539,7 +539,7 @@ case class HashAggregateExec( private def enableTwoLevelHashMap(ctx: CodegenContext) = { if 
(!checkIfFastHashMapSupported(ctx)) { if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { - logInfo("spark.sql.codegen.aggregate.map.twolevel.enable is set to true, but" + logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 7e61a68025158..938d76c9f0837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -24,14 +24,14 @@ import org.apache.spark.SparkConf class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "false", "configuration parameter changed in test body") } } @@ -39,14 +39,14 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") } } @@ -57,7 +57,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed @@ -65,7 +65,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true", "configuration parameter changed in test body") diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index aca1be01fa3da..a834b7cd2c69f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,14 +107,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -149,14 +149,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -189,14 +189,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -228,14 +228,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - 
sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -277,14 +277,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index f245a79f805a2..ae675149df5e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1015,7 +1015,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { Seq("true", "false").foreach { enableTwoLevelMaps => - withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enable" -> + withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enabled" -> enableTwoLevelMaps) { (1 to 3).foreach { fallbackStartsAt => withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> From 3b5c2a84bfa311a94c1c0a57f2cb3e421fb05650 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 10 Oct 2017 08:27:45 +0100 Subject: [PATCH 038/712] [SPARK-21770][ML] ProbabilisticClassificationModel fix corner case: normalization of all-zero raw predictions ## What changes were proposed in this pull request? Fix probabilisticClassificationModel corner case: normalization of all-zero raw predictions, throw IllegalArgumentException with description. ## How was this patch tested? Test case added. Author: WeichenXu Closes #19106 from WeichenXu123/SPARK-21770. --- .../ProbabilisticClassifier.scala | 20 ++++++++++--------- .../ProbabilisticClassifierSuite.scala | 18 +++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index ef08134809915..730fcab333e11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -230,21 +230,23 @@ private[ml] object ProbabilisticClassificationModel { * Normalize a vector of raw predictions to be a multinomial probability vector, in place. * * The input raw predictions should be nonnegative. - * The output vector sums to 1, unless the input vector is all-0 (in which case the output is - * all-0 too). + * The output vector sums to 1. 
* * NOTE: This is NOT applicable to all models, only ones which effectively use class * instance counts for raw predictions. + * + * @throws IllegalArgumentException if the input vector is all-0 or including negative values */ def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = { + v.values.foreach(value => require(value >= 0, + "The input raw predictions should be nonnegative.")) val sum = v.values.sum - if (sum != 0) { - var i = 0 - val size = v.size - while (i < size) { - v.values(i) /= sum - i += 1 - } + require(sum > 0, "Can't normalize the 0-vector.") + var i = 0 + val size = v.size + while (i < size) { + v.values(i) /= sum + i += 1 } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 4ecd5a05365eb..d649ceac949c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -80,6 +80,24 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1)) } } + + test("normalizeToProbabilitiesInPlace") { + val vec1 = Vectors.dense(1.0, 2.0, 3.0).toDense + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec1) + assert(vec1 ~== Vectors.dense(1.0 / 6, 2.0 / 6, 3.0 / 6) relTol 1e-3) + + // all-0 input test + val vec2 = Vectors.dense(0.0, 0.0, 0.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec2) + } + + // negative input test + val vec3 = Vectors.dense(1.0, -1.0, 2.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec3) + } + } } object ProbabilisticClassifierSuite { From b8a08f25cc64ed3034f3c90790931c30e5b0f236 Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 10 Oct 2017 20:44:33 +0800 Subject: [PATCH 039/712] [SPARK-21506][DOC] The description of "spark.executor.cores" may be not correct ## What changes were proposed in this pull request? The number of cores assigned to each executor is configurable. When this is not explicitly set, multiple executors from the same application may be launched on the same worker too. ## How was this patch tested? N/A Author: liuxian Closes #18711 from 10110346/executorcores. 
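For illustration only (not part of the patch below), here is a minimal sketch of the behaviour this change documents, using a hypothetical standalone master URL and resource sizes: when `spark.executor.cores` is set explicitly, several executors of the same application can share one worker; when it is left unset, a single executor grabs all of a worker's free cores.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical standalone cluster whose workers have 16 cores each.
// With spark.executor.cores = 4, up to four 4-core executors from this
// application may be scheduled on the same worker; leaving the setting
// unset would instead give one executor all of that worker's free cores.
val spark = SparkSession.builder()
  .master("spark://master:7077") // hypothetical master URL
  .appName("executor-cores-illustration")
  .config("spark.executor.cores", "4")
  .config("spark.executor.memory", "2g")
  .getOrCreate()

spark.range(1000000L).count() // trivial job to exercise the executors
spark.stop()
```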
--- .../spark/deploy/client/StandaloneAppClient.scala | 2 +- .../scala/org/apache/spark/deploy/master/Master.scala | 8 +++++++- .../cluster/StandaloneSchedulerBackend.scala | 2 +- docs/configuration.md | 11 ++++------- docs/spark-standalone.md | 8 ++++++++ 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 757c930b84eb2..34ade4ce6f39b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -170,7 +170,7 @@ private[spark] class StandaloneAppClient( case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, + logInfo("Executor added: %s on %s (%s) with %d core(s)".format(fullId, workerId, hostPort, cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e030cac60a8e4..2c78c15773af2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,7 +581,13 @@ private[deploy] class Master( * The number of cores assigned to each executor is configurable. When this is explicitly set, * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the - * worker by default, in which case only one executor may be launched on each worker. + * worker by default, in which case only one executor per application may be launched on each + * worker during one single schedule iteration. + * Note that when `spark.executor.cores` is not set, we may still launch multiple executors from + * the same application on the same worker. Consider appA and appB both have one executor running + * on worker1, and appA.coresLeft > 0, then appB is finished and release all its cores on worker1, + * thus for the next schedule iteration, appA launches a new executor that grabs all the free + * cores on worker1, therefore we get multiple executors from appA running on worker1. * * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core * at a time). Consider the following example: cluster has 4 workers with 16 cores each. 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index a4e2a74341283..505c342a889ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -153,7 +153,7 @@ private[spark] class StandaloneSchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + logInfo("Granted executor ID %s on hostPort %s with %d core(s), %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } diff --git a/docs/configuration.md b/docs/configuration.md index 6e9fe591b70a3..7a777d3c6fa3d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1015,7 +1015,7 @@ Apart from these, the following properties are also available, and may be useful 0.5 Amount of storage memory immune to eviction, expressed as a fraction of the size of the - region set aside by s​park.memory.fraction. The higher this is, the less + region set aside by spark.memory.fraction. The higher this is, the less working memory may be available to execution and tasks may spill to disk more often. Leaving this at the default value is recommended. For more detail, see this description. @@ -1041,7 +1041,7 @@ Apart from these, the following properties are also available, and may be useful spark.memory.useLegacyMode false - ​Whether to enable the legacy memory management mode used in Spark 1.5 and before. + Whether to enable the legacy memory management mode used in Spark 1.5 and before. The legacy mode rigidly partitions the heap space into fixed-size regions, potentially leading to excessive spilling if the application was not tuned. The following deprecated memory fraction configurations are not read unless this is enabled: @@ -1115,11 +1115,8 @@ Apart from these, the following properties are also available, and may be useful The number of cores to use on each executor. - In standalone and Mesos coarse-grained modes, setting this - parameter allows an application to run multiple executors on the - same worker, provided that there are enough cores on that - worker. Otherwise, only one executor per application will run on - each worker. + In standalone and Mesos coarse-grained modes, for more detail, see + this description. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1095386c31ab8..f51c5cc38f4de 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -328,6 +328,14 @@ export SPARK_MASTER_OPTS="-Dspark.deploy.defaultCores=" This is useful on shared clusters where users might not have configured a maximum number of cores individually. +# Executors Scheduling + +The number of cores assigned to each executor is configurable. When `spark.executor.cores` is +explicitly set, multiple executors from the same application may be launched on the same worker +if the worker has enough cores and memory. Otherwise, each executor grabs all the cores available +on the worker by default, in which case only one executor per application may be launched on each +worker during one single schedule iteration. + # Monitoring and Logging Spark's standalone mode offers a web-based user interface to monitor the cluster. 
The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. From 23af2d79ad9a3c83936485ee57513b39193a446b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 10 Oct 2017 20:48:42 +0800 Subject: [PATCH 040/712] [SPARK-20025][CORE] Ignore SPARK_LOCAL* env, while deploying via cluster mode. ## What changes were proposed in this pull request? In a bare metal system with No DNS setup, spark may be configured with SPARK_LOCAL* for IP and host properties. During a driver failover, in cluster deployment mode. SPARK_LOCAL* should be ignored while restarting on another node and should be picked up from target system's local environment. ## How was this patch tested? Distributed deployment against a spark standalone cluster of 6 Workers. Tested by killing JVM's running driver and verified the restarted JVMs have right configurations on them. Author: Prashant Sharma Author: Prashant Sharma Closes #17357 from ScrapCodes/driver-failover-fix. --- core/src/main/scala/org/apache/spark/deploy/Client.scala | 6 +++--- .../apache/spark/deploy/rest/StandaloneRestServer.scala | 4 +++- .../org/apache/spark/deploy/worker/DriverWrapper.scala | 9 ++++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index bf6093236d92b..7acb5c55bb252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -93,19 +93,19 @@ private class ClientEndpoint( driverArgs.cores, driverArgs.supervise, command) - ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + asyncSendToMasterAndForwardReply[SubmitDriverResponse]( RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) } } /** * Send the message to master and forward the reply to self asynchronously. */ - private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { for (masterEndpoint <- masterEndpoints) { masterEndpoint.ask[T](message).onComplete { case Success(v) => self.send(v) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 0164084ab129e..22b65abce611a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -139,7 +139,9 @@ private[rest] class StandaloneSubmitRequestServlet( val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") val superviseDriver = sparkProperties.get("spark.driver.supervise") val appArgs = request.appArgs - val environmentVariables = request.environmentVariables + // Filter SPARK_LOCAL_(IP|HOSTNAME) environment variables from being set on the remote system. 
+ val environmentVariables = + request.environmentVariables.filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) // Construct driver description val conf = new SparkConf(false) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index c1671192e0c64..b19c9904d5982 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -23,6 +23,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} +import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -30,7 +31,7 @@ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, U * Utility object for launching driver programs such that they share fate with the Worker process. * This is used in standalone cluster mode only. */ -object DriverWrapper { +object DriverWrapper extends Logging { def main(args: Array[String]) { args.toList match { /* @@ -41,8 +42,10 @@ object DriverWrapper { */ case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val rpcEnv = RpcEnv.create("Driver", - Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val host: String = Utils.localHostName() + val port: Int = sys.props.getOrElse("spark.driver.port", "0").toInt + val rpcEnv = RpcEnv.create("Driver", host, port, conf, new SecurityManager(conf)) + logInfo(s"Driver address: ${rpcEnv.address}") rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader From 633ffd816d285480bab1f346471135b10ec092bb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 10 Oct 2017 11:01:02 -0700 Subject: [PATCH 041/712] rename the file. --- ...rSparkSubmitSutie.scala => BufferHolderSparkSubmitSuite.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/{BufferHolderSparkSubmitSutie.scala => BufferHolderSparkSubmitSuite.scala} (100%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala similarity index 100% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSutie.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala From 2028e5a82bc3e9a79f9b84f376bdf606b8c9bb0f Mon Sep 17 00:00:00 2001 From: Eyal Farago Date: Tue, 10 Oct 2017 22:49:47 +0200 Subject: [PATCH 042/712] [SPARK-21907][CORE] oom during spill ## What changes were proposed in this pull request? 1. a test reproducing [SPARK-21907](https://issues.apache.org/jira/browse/SPARK-21907) 2. a fix for the root cause of the issue. 
`org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill` calls `org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset` which may trigger another spill, when this happens the `array` member is already de-allocated but still referenced by the code, this causes the nested spill to fail with an NPE in `org.apache.spark.memory.TaskMemoryManager.getPage`. This patch introduces a reproduction in a test case and a fix, the fix simply sets the in-mem sorter's array member to an empty array before actually performing the allocation. This prevents the spilling code from 'touching' the de-allocated array. ## How was this patch tested? introduced a new test case: `org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorterSuite#testOOMDuringSpill`. Author: Eyal Farago Closes #19181 from eyalfa/SPARK-21907__oom_during_spill. --- .../unsafe/sort/UnsafeExternalSorter.java | 4 ++ .../unsafe/sort/UnsafeInMemorySorter.java | 12 ++++- .../sort/UnsafeExternalSorterSuite.java | 33 +++++++++++++ .../sort/UnsafeInMemorySorterSuite.java | 46 +++++++++++++++++++ .../spark/memory/TestMemoryManager.scala | 12 +++-- 5 files changed, 102 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 39eda00dd7efb..e749f7ba87c6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -480,6 +480,10 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } } + @VisibleForTesting boolean hasSpaceForAnotherRecord() { + return inMemSorter.hasSpaceForAnotherRecord(); + } + private static void spillIterator(UnsafeSorterIterator inMemIterator, UnsafeSorterSpillWriter spillWriter) throws IOException { while (inMemIterator.hasNext()) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c14c12664f5ab..869ec908be1fb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -162,7 +162,9 @@ private int getUsableCapacity() { */ public void free() { if (consumer != null) { - consumer.freeArray(array); + if (array != null) { + consumer.freeArray(array); + } array = null; } } @@ -170,6 +172,14 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); + // the call to consumer.allocateArray may trigger a spill + // which in turn access this instance and eventually re-enter this method and try to free the array again. + // by setting the array to null and its length to 0 we effectively make the spill code-path a no-op. + // setting the array to null also indicates that it has already been de-allocated which prevents a double de-allocation in free(). 
+ array = null; + usableCapacity = 0; + pos = 0; + nullBoundaryPos = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 5330a688e63e3..6c5451d0fd2a5 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -23,6 +23,7 @@ import java.util.LinkedList; import java.util.UUID; +import org.hamcrest.Matchers; import scala.Tuple2$; import org.junit.After; @@ -503,6 +504,38 @@ public void testGetIterator() throws Exception { verifyIntIterator(sorter.getIterator(279), 279, 300); } + @Test + public void testOOMDuringSpill() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + // we assume that given default configuration, + // the size of the data we insert to the sorter (ints) + // and assuming we shouldn't spill before pointers array is exhausted + // (memory manager is not configured to throw at this point) + // - so this loop runs a reasonable number of iterations (<2000). + // test indeed completed within <30ms (on a quad i7 laptop). + for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) { + insertNumber(sorter, i); + } + // we expect the next insert to attempt growing the pointerssArray + // first allocation is expected to fail, then a spill is triggered which attempts another allocation + // which also fails and we expect to see this OOM here. + // the original code messed with a released array within the spill code + // and ended up with a failed assertion. 
+ // we also expect the location of the OOM to be org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset + memoryManager.markconsequentOOM(2); + try { + insertNumber(sorter, 1024); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } + // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) + catch (OutOfMemoryError oom){ + String oomStackTrace = Utils.exceptionString(oom); + assertThat("expected OutOfMemoryError in org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", + oomStackTrace, + Matchers.containsString("org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); + } + } + private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end) throws IOException { for (int i = start; i < end; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index bd89085aa9a14..1a3e11efe9787 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -35,6 +35,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.isIn; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; public class UnsafeInMemorySorterSuite { @@ -139,4 +140,49 @@ public int compare( } assertEquals(dataToSort.length, iterLength); } + + @Test + public void freeAfterOOM() { + final SparkConf sparkConf = new SparkConf(); + sparkConf.set("spark.memory.offHeap.enabled", "false"); + + final TestMemoryManager testMemoryManager = + new TestMemoryManager(sparkConf); + final TaskMemoryManager memoryManager = new TaskMemoryManager( + testMemoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = PrefixComparators.LONG; + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, + recordComparator, prefixComparator, 100, shouldUseRadixSort()); + + testMemoryManager.markExecutionAsOutOfMemoryOnce(); + try { + sorter.reset(); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } catch (OutOfMemoryError oom) { + // as expected + } + // [SPARK-21907] this failed on NPE at org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) + sorter.free(); + // simulate a 'back to back' free. 
+ sorter.free(); + } + } diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 5f699df8211de..c26945fa5fa31 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -27,8 +27,8 @@ class TestMemoryManager(conf: SparkConf) numBytes: Long, taskAttemptId: Long, memoryMode: MemoryMode): Long = { - if (oomOnce) { - oomOnce = false + if (consequentOOM > 0) { + consequentOOM -= 1 0 } else if (available >= numBytes) { available -= numBytes @@ -58,11 +58,15 @@ class TestMemoryManager(conf: SparkConf) override def maxOffHeapStorageMemory: Long = 0L - private var oomOnce = false + private var consequentOOM = 0 private var available = Long.MaxValue def markExecutionAsOutOfMemoryOnce(): Unit = { - oomOnce = true + markconsequentOOM(1) + } + + def markconsequentOOM(n : Int) : Unit = { + consequentOOM += n } def limit(avail: Long): Unit = { From bfc7e1fe1ad5f9777126f2941e29bbe51ea5da7c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 11 Oct 2017 07:32:01 +0900 Subject: [PATCH 043/712] [SPARK-20396][SQL][PYSPARK] groupby().apply() with pandas udf ## What changes were proposed in this pull request? This PR adds an apply() function on df.groupby(). apply() takes a pandas udf that is a transformation on `pandas.DataFrame` -> `pandas.DataFrame`. Static schema ------------------- ``` schema = df.schema pandas_udf(schema) def normalize(df): df = df.assign(v1 = (df.v1 - df.v1.mean()) / df.v1.std() return df df.groupBy('id').apply(normalize) ``` Dynamic schema ----------------------- **This use case is removed from the PR and we will discuss this as a follow up. See discussion https://github.com/apache/spark/pull/18732#pullrequestreview-66583248** Another example to use pd.DataFrame dtypes as output schema of the udf: ``` sample_df = df.filter(df.id == 1).toPandas() def foo(df): ret = # Some transformation on the input pd.DataFrame return ret foo_udf = pandas_udf(foo, foo(sample_df).dtypes) df.groupBy('id').apply(foo_udf) ``` In interactive use case, user usually have a sample pd.DataFrame to test function `foo` in their notebook. Having been able to use `foo(sample_df).dtypes` frees user from specifying the output schema of `foo`. Design doc: https://github.com/icexelloss/spark/blob/pandas-udf-doc/docs/pyspark-pandas-udf.md ## How was this patch tested? * Added GroupbyApplyTest Author: Li Jin Author: Takuya UESHIN Author: Bryan Cutler Closes #18732 from icexelloss/groupby-apply-SPARK-20396. 
--- python/pyspark/sql/dataframe.py | 6 +- python/pyspark/sql/functions.py | 98 ++++++++--- python/pyspark/sql/group.py | 88 +++++++++- python/pyspark/sql/tests.py | 157 +++++++++++++++++- python/pyspark/sql/types.py | 2 +- python/pyspark/worker.py | 35 ++-- .../sql/catalyst/optimizer/Optimizer.scala | 2 + .../logical/pythonLogicalOperators.scala | 39 +++++ .../spark/sql/RelationalGroupedDataset.scala | 36 +++- .../spark/sql/execution/SparkStrategies.scala | 2 + .../python/ArrowEvalPythonExec.scala | 39 ++++- .../execution/python/ArrowPythonRunner.scala | 15 +- .../execution/python/ExtractPythonUDFs.scala | 8 +- .../python/FlatMapGroupsInPandasExec.scala | 103 ++++++++++++ 14 files changed, 561 insertions(+), 69 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fe69e588fe098..2d596229ced7e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1227,7 +1227,7 @@ def groupBy(self, *cols): """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def rollup(self, *cols): @@ -1248,7 +1248,7 @@ def rollup(self, *cols): """ jgd = self._jdf.rollup(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def cube(self, *cols): @@ -1271,7 +1271,7 @@ def cube(self, *cols): """ jgd = self._jdf.cube(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.3) def agg(self, *exprs): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b45a59db93679..9bc12c3b7a162 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2058,7 +2058,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self._vectorized = vectorized + self.vectorized = vectorized @property def returnType(self): @@ -2090,7 +2090,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self._vectorized) + self._name, wrapped_func, jdt, self.vectorized) return judf def __call__(self, *cols): @@ -2118,8 +2118,10 @@ def wrapper(*args): wrapper.__name__ = self._name wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') else self.func.__class__.__module__) + wrapper.func = self.func wrapper.returnType = self.returnType + wrapper.vectorized = self.vectorized return wrapper @@ -2129,8 +2131,12 @@ def _create_udf(f, returnType, vectorized): def _udf(f, returnType=StringType(), vectorized=vectorized): if vectorized: import inspect - if len(inspect.getargspec(f).args) == 0: - raise NotImplementedError("0-parameter pandas_udfs are not currently supported") + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "0-arg pandas_udfs are not supported. 
" + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) return udf_obj._wrapped() @@ -2146,7 +2152,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): @since(1.3) def udf(f=None, returnType=StringType()): - """Creates a :class:`Column` expression representing a user defined function (UDF). + """Creates a user defined function (UDF). .. note:: The user-defined functions must be deterministic. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than @@ -2181,30 +2187,70 @@ def udf(f=None, returnType=StringType()): @since(2.3) def pandas_udf(f=None, returnType=StringType()): """ - Creates a :class:`Column` expression representing a user defined function (UDF) that accepts - `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length. + Creates a vectorized user defined function (UDF). - :param f: python function if used as a standalone function + :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf(returnType="integer") - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ + The user-defined function can define one of the following transformations: + + 1. One or more `pandas.Series` -> A `pandas.Series` + + This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + The returnType should be a primitive data type, e.g., `DoubleType()`. + The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + + >>> from pyspark.sql.types import IntegerType, StringType + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) + >>> @pandas_udf(returnType=StringType()) + ... def to_upper(s): + ... return s.str.upper() + ... + >>> @pandas_udf(returnType="integer") + ... def add_one(x): + ... return x + 1 + ... + >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ + ... .show() # doctest: +SKIP + +----------+--------------+------------+ + |slen(name)|to_upper(name)|add_one(age)| + +----------+--------------+------------+ + | 8| JOHN DOE| 22| + +----------+--------------+------------+ + + 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + + This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + The returnType should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... 
return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` + because it defines a `DataFrame` transformation rather than a `Column` + transformation. + + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + + .. note:: The user-defined function must be deterministic. """ return _create_udf(f, returnType=returnType, vectorized=True) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c63054..817d0bc83bb77 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -54,9 +54,10 @@ class GroupedData(object): .. versionadded:: 1.3 """ - def __init__(self, jgd, sql_ctx): + def __init__(self, jgd, df): self._jgd = jgd - self.sql_ctx = sql_ctx + self._df = df + self.sql_ctx = df.sql_ctx @ignore_unicode_prefix @since(1.3) @@ -170,7 +171,7 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): """ - Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + Pivots a column of the current :class:`DataFrame` and perform the specified aggregation. There are two versions of pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally. @@ -192,7 +193,85 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col) else: jgd = self._jgd.pivot(pivot_col, values) - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self._df) + + @since(2.3) + def apply(self, udf): + """ + Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result + as a `DataFrame`. + + The user-defined function should take a `pandas.DataFrame` and return another + `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` + to the user-function and the returned `pandas.DataFrame`s are combined as a + :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the + returnType of the pandas udf. + + This function does not support partial aggregation, and requires shuffling all the data in + the :class:`DataFrame`. + + :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + + >>> from pyspark.sql.functions import pandas_udf + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. 
seealso:: :meth:`pyspark.sql.functions.pandas_udf` + + """ + from pyspark.sql.functions import pandas_udf + + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: + raise ValueError("The argument to apply must be a pandas_udf") + if not isinstance(udf.returnType, StructType): + raise ValueError("The returnType of the pandas_udf must be a StructType") + + df = self._df + func = udf.func + returnType = udf.returnType + + # The python executors expects the function to use pd.Series as input and output + # So we to create a wrapper function that turns that to a pd.DataFrame before passing + # down to the user function, then turn the result pd.DataFrame back into pd.Series + columns = df.columns + + def wrapped(*cols): + from pyspark.sql.types import to_arrow_type + import pandas as pd + result = func(pd.concat(cols, axis=1, keys=columns)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "Pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(returnType): + raise RuntimeError( + "Number of columns of the returned Pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + wrapped_udf_obj = pandas_udf(wrapped, returnType) + udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) def _test(): @@ -206,6 +285,7 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a59378b5e848a..bac2ef84ae7a7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3256,17 +3256,17 @@ def test_vectorized_udf_null_string(self): def test_vectorized_udf_zero_parameter(self): from pyspark.sql.functions import pandas_udf - error_str = '0-parameter pandas_udfs.*not.*supported' + error_str = '0-arg pandas_udfs.*not.*supported' with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): pandas_udf(lambda: 1, LongType()) - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf def zero_no_type(): return 1 - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf(LongType()) def zero_with_type(): return 1 @@ -3348,7 +3348,7 @@ def test_vectorized_udf_wrong_return_type(self): df = self.spark.range(10) f = pandas_udf(lambda x: x * 1.0, StringType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.select(f(col('id'))).collect() def test_vectorized_udf_return_scalar(self): @@ -3356,7 +3356,7 @@ def test_vectorized_udf_return_scalar(self): df = self.spark.range(10) f = pandas_udf(lambda x: 1.0, DoubleType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 
'Return.*type.*pandas_udf.*Series'): + with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): df.select(f(col('id'))).collect() def test_vectorized_udf_decorator(self): @@ -3376,6 +3376,151 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_varargs(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + f = pandas_udf(lambda *v: v[0], LongType()) + res = df.select(f(col('id'))) + self.assertEquals(df.collect(), res.collect()) + + +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyApplyTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + def assertFramesEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + + ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) + self.assertTrue(expected.equals(result), msg=msg) + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))).drop('vs') + + def test_simple(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_decorator(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + def foo(pdf): + return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_coerce(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + expected = expected.assign(v=expected.v.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_complex_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) + expected = expected.sort_values(['id', 
'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_empty_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby().apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = normalize.func(pdf) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', StringType())])) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + df.groupby('id').apply(foo).sort('id').toPandas() + + def test_wrong_args(self): + from pyspark.sql.functions import udf, pandas_udf, sum + df = self.data + + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(lambda x: x) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(sum(df.v)) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'returnType'): + df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ebdc11c3b744a..f65273d5f0b6c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1597,7 +1597,7 @@ def convert(self, obj, gateway_client): register_input_converter(DateConverter()) -def toArrowType(dt): +def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ import pyarrow as pa diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e24789cf010d..eb6d48688dc0a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import toArrowType +from pyspark.sql.types import to_arrow_type, StructType from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,17 +74,28 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - arrow_return_type = toArrowType(return_type) - - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of pandas_udf should be a Pandas.Series") - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + # If the return_type is a StructType, it indicates this is a groupby apply udf, + # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. 
+ # We can distinguish these two by return type because in groupby apply, we always specify + # returnType as a StructType, and in vectorized column udf, StructType is not supported. + # + # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs + if isinstance(return_type, StructType): + return lambda *a: f(*a) + else: + arrow_return_type = to_arrow_type(return_type) + + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result + + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bc2d4a824cb49..d829e01441dcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -452,6 +452,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) + case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => + f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala new file mode 100644 index 0000000000000..8abab24bc9b44 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} + +/** + * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame. + * This is used by DataFrame.groupby().apply(). 
+ */ +case class FlatMapGroupsInPandas( + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** + * This is needed because output attributes are considered `references` when + * passed through the constructor. + * + * Without this, catalyst will complain that output attributes are missing + * from the input. + */ + override val producedAttributes = AttributeSet(output) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 147b549964913..cd0ac1feffa51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -27,12 +27,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.NumericType -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NumericType, StructField, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -435,6 +435,36 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan.output, df.logicalPlan)) } + + /** + * Applies a vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. + * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results + * for all groups are combined into a new [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. 
+ */ + private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { + require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.dataType.isInstanceOf[StructType], + "The returnType of the vectorized python udf must be a StructType") + + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) + val child = df.logicalPlan + val project = Project(groupingNamedExpressions ++ child.output, child) + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) + + Dataset.ofRows(df.sparkSession, plan) + } } private[sql] object RelationalGroupedDataset { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 92eaab5cd8f81..4cdcc73faacd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -392,6 +392,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInPandas(grouping, func, output, child) => + execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index f7e8cbe416121..81896187ecc46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -26,6 +26,35 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType +/** + * Grouped a iterator into batches. + * This is similar to iter.grouped but returns Iterator[T] instead of Seq[T]. + * This is necessary because sometimes we cannot hold reference of input rows + * because the some input rows are mutable and can be reused. + */ +private class BatchIterator[T](iter: Iterator[T], batchSize: Int) + extends Iterator[Iterator[T]] { + + override def hasNext: Boolean = iter.hasNext + + override def next(): Iterator[T] = { + new Iterator[T] { + var count = 0 + + override def hasNext: Boolean = iter.hasNext && count < batchSize + + override def next(): T = { + if (!hasNext) { + Iterator.empty.next() + } else { + count += 1 + iter.next() + } + } + } + } +} + /** * A physical plan that evaluates a [[PythonUDF]], */ @@ -44,14 +73,18 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) + val batchSize = conf.arrowMaxRecordsPerBatch + // DO NOT use iter.grouped(). See BatchIterator. 
+ val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) + val columnarBatchIter = new ArrowPythonRunner( - funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) - .compute(iter, context.partitionId(), context) + .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { - var currentIter = if (columnarBatchIter.hasNext) { + private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() assert(schemaOut.equals(batch.schema), s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index bbad9d6b631fd..f6c03c415dc66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,19 +39,18 @@ import org.apache.spark.util.Utils */ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], - batchSize: Int, bufferSize: Int, reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], schema: StructType) - extends BasePythonRunner[InternalRow, ColumnarBatch]( + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[InternalRow], + inputIterator: Iterator[Iterator[InternalRow]], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { @@ -82,12 +81,12 @@ class ArrowPythonRunner( Utils.tryWithSafeFinally { while (inputIterator.hasNext) { - var rowCount = 0 - while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) { - val row = inputIterator.next() - arrowWriter.write(row) - rowCount += 1 + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) } + arrowWriter.finish() writer.writeBatch() arrowWriter.reset() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index fec456d86dbe2..e3f952e221d53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.{FilterExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -111,6 +110,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { + // FlatMapGroupsInPandas can be evaluated directly in python worker + // Therefore we don't need to extract the UDFs + case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) } @@ -169,7 +171,7 @@ object ExtractPythonUDFs extends 
Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - execution.ProjectExec(plan.output, newPlan) + ProjectExec(plan.output, newPlan) } else { newPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala new file mode 100644 index 0000000000000..b996b5bb38ba5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.StructType + +/** + * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] + * + * Rows in each group are passed to the Python worker as an Arrow record batch. + * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the + * user-defined function, and passes the resulting `pandas.DataFrame` + * as an Arrow record batch. Finally, each record batch is turned to + * Iterator[InternalRow] using ColumnarBatch. + * + * Note on memory usage: + * Both the Python worker and the Java executor need to have enough memory to + * hold the largest group. The memory on the Java side is used to construct the + * record batch (off heap memory). The memory on the Python side is used for + * holding the `pandas.DataFrame`. It's possible to further split one group into + * multiple record batches to reduce the memory footprint on the Java side, this + * is left as future work. 
+ */ +case class FlatMapGroupsInPandasExec( + groupingAttributes: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + child: SparkPlan) + extends UnaryExecNode { + + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) + val schema = StructType(child.schema.drop(groupingAttributes.length)) + + inputRDD.mapPartitionsInternal { iter => + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + } + } + + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(grouped, context.partitionId(), context) + + columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) + } + } +} From bd4eb9ce57da7bacff69d9ed958c94f349b7e6fb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 10 Oct 2017 15:50:37 -0700 Subject: [PATCH 044/712] [SPARK-19558][SQL] Add config key to register QueryExecutionListeners automatically. This change adds a new SQL config key that is equivalent to SparkContext's "spark.extraListeners", allowing users to register QueryExecutionListener instances through the Spark configuration system instead of having to explicitly do it in code. The code used by SparkContext to implement the feature was refactored into a helper method in the Utils class, and SQL's ExecutionListenerManager was modified to use it to initialize listener declared in the configuration. Unit tests were added to verify all the new functionality. Author: Marcelo Vanzin Closes #19309 from vanzin/SPARK-19558. 
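As a usage sketch (not part of the patch), a listener picked up through the new configuration might look like the following; the `AuditListener` class name and the `local[*]` session are made up for illustration, while the `spark.sql.queryExecutionListeners` key and the `QueryExecutionListener` callbacks are the ones this change wires together:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Hypothetical listener: needs a no-arg constructor (or one taking SparkConf).
class AuditListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    println(s"$funcName succeeded in ${durationNs / 1000000} ms")
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    println(s"$funcName failed: ${exception.getMessage}")
  }
}

object QueryListenerConfigExample {
  def main(args: Array[String]): Unit = {
    // The listener is instantiated from configuration; no explicit register() call.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("query-listener-config")
      .config("spark.sql.queryExecutionListeners", classOf[AuditListener].getName)
      .getOrCreate()

    spark.range(0, 100).count()  // a Dataset action should trigger onSuccess

    spark.stop()
  }
}
```

Classes listed in the key need either a no-arg constructor or a single-argument `SparkConf` constructor, mirroring the existing `spark.extraListeners` behavior.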
--- .../scala/org/apache/spark/SparkContext.scala | 38 ++-------- .../spark/internal/config/package.scala | 7 ++ .../scala/org/apache/spark/util/Utils.scala | 57 ++++++++++++++- .../spark/scheduler/SparkListenerSuite.scala | 6 +- .../org/apache/spark/util/UtilsSuite.scala | 56 ++++++++++++++- .../spark/sql/internal/StaticSQLConf.scala | 8 +++ .../internal/BaseSessionStateBuilder.scala | 3 +- .../sql/util/QueryExecutionListener.scala | 12 +++- .../util/ExecutionListenerManagerSuite.scala | 69 +++++++++++++++++++ 9 files changed, 216 insertions(+), 40 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3cd03c0cfbe1..6f25d346e6e54 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2344,41 +2344,13 @@ class SparkContext(config: SparkConf) extends Logging { * (e.g. after the web UI and event logging listeners have been registered). */ private def setupAndStartListenerBus(): Unit = { - // Use reflection to instantiate listeners specified via `spark.extraListeners` try { - val listenerClassNames: Seq[String] = - conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") - for (className <- listenerClassNames) { - // Use reflection to find the right constructor - val constructors = { - val listenerClass = Utils.classForName(className) - listenerClass - .getConstructors - .asInstanceOf[Array[Constructor[_ <: SparkListenerInterface]]] + conf.get(EXTRA_LISTENERS).foreach { classNames => + val listeners = Utils.loadExtensions(classOf[SparkListenerInterface], classNames, conf) + listeners.foreach { listener => + listenerBus.addToSharedQueue(listener) + logInfo(s"Registered listener ${listener.getClass().getName()}") } - val constructorTakingSparkConf = constructors.find { c => - c.getParameterTypes.sameElements(Array(classOf[SparkConf])) - } - lazy val zeroArgumentConstructor = constructors.find { c => - c.getParameterTypes.isEmpty - } - val listener: SparkListenerInterface = { - if (constructorTakingSparkConf.isDefined) { - constructorTakingSparkConf.get.newInstance(conf) - } else if (zeroArgumentConstructor.isDefined) { - zeroArgumentConstructor.get.newInstance() - } else { - throw new SparkException( - s"$className did not have a zero-argument constructor or a" + - " single-argument constructor that accepts SparkConf. 
Note: if the class is" + - " defined inside of another Scala class, then its constructors may accept an" + - " implicit parameter that references the enclosing class; in this case, you must" + - " define the listener as a top-level class in order to prevent this extra" + - " parameter from breaking Spark's ability to find a valid constructor.") - } - } - listenerBus.addToSharedQueue(listener) - logInfo(s"Registered listener $className") } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 5278e5e0fb270..19336f854145f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -419,4 +419,11 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") + .doc("Class names of listeners to add to SparkContext during initialization.") + .stringConf + .toSequence + .createOptional + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 836e33c36d9a1..930e09d90c2f5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer @@ -37,7 +38,7 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} import scala.util.control.{ControlThrowable, NonFatal} import scala.util.matching.Regex @@ -2687,6 +2688,60 @@ private[spark] object Utils extends Logging { def stringToSeq(str: String): Seq[String] = { str.split(",").map(_.trim()).filter(_.nonEmpty) } + + /** + * Create instances of extension classes. + * + * The classes in the given list must: + * - Be sub-classes of the given base class. + * - Provide either a no-arg constructor, or a 1-arg constructor that takes a SparkConf. + * + * The constructors are allowed to throw "UnsupportedOperationException" if the extension does not + * want to be registered; this allows the implementations to check the Spark configuration (or + * other state) and decide they do not need to be added. A log message is printed in that case. + * Other exceptions are bubbled up. + */ + def loadExtensions[T](extClass: Class[T], classes: Seq[String], conf: SparkConf): Seq[T] = { + classes.flatMap { name => + try { + val klass = classForName(name) + require(extClass.isAssignableFrom(klass), + s"$name is not a subclass of ${extClass.getName()}.") + + val ext = Try(klass.getConstructor(classOf[SparkConf])) match { + case Success(ctor) => + ctor.newInstance(conf) + + case Failure(_) => + klass.getConstructor().newInstance() + } + + Some(ext.asInstanceOf[T]) + } catch { + case _: NoSuchMethodException => + throw new SparkException( + s"$name did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. 
Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the class as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + + case e: InvocationTargetException => + e.getCause() match { + case uoe: UnsupportedOperationException => + logDebug(s"Extension $name not being initialized.", uoe) + logInfo(s"Extension $name not being initialized.") + None + + case null => throw e + + case cause => throw cause + } + } + } + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d061c7845f4a6..1beb36afa95f0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY +import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{ResetSystemProperties, RpcUtils} @@ -446,13 +446,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match classOf[FirehoseListenerThatAcceptsSparkConf], classOf[BasicJobCounter]) val conf = new SparkConf().setMaster("local").setAppName("test") - .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) + .set(EXTRA_LISTENERS, listeners.map(_.getName)) sc = new SparkContext(conf) sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) sc.listenerBus.listeners.asScala .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) sc.listenerBus.listeners.asScala - .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) + .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } test("add and remove listeners to/from LiveListenerBus queues") { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 2b16cc4852ba8..4d3adeb968e84 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -38,9 +38,10 @@ import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.scheduler.SparkListener class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -1110,4 +1111,57 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.tryWithSafeFinallyAndFailureCallbacks {}(catchBlock = {}, finallyBlock = {}) TaskContext.unset } + + test("load extensions") { + val extensions = Seq( + classOf[SimpleExtension], + classOf[ExtensionWithConf], + classOf[UnregisterableExtension]).map(_.getName()) + + val conf = new SparkConf(false) + val instances = Utils.loadExtensions(classOf[Object], extensions, conf) + 
assert(instances.size === 2) + assert(instances.count(_.isInstanceOf[SimpleExtension]) === 1) + + val extWithConf = instances.find(_.isInstanceOf[ExtensionWithConf]) + .map(_.asInstanceOf[ExtensionWithConf]) + .get + assert(extWithConf.conf eq conf) + + class NestedExtension { } + + val invalid = Seq(classOf[NestedExtension].getName()) + intercept[SparkException] { + Utils.loadExtensions(classOf[Object], invalid, conf) + } + + val error = Seq(classOf[ExtensionWithError].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Object], error, conf) + } + + val wrongType = Seq(classOf[ListenerImpl].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Seq[_]], wrongType, conf) + } + } + +} + +private class SimpleExtension + +private class ExtensionWithConf(val conf: SparkConf) + +private class UnregisterableExtension { + + throw new UnsupportedOperationException() + +} + +private class ExtensionWithError { + + throw new IllegalArgumentException() + } + +private class ListenerImpl extends SparkListener diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index c6c0a605d89ff..c018fc8a332fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -87,4 +87,12 @@ object StaticSQLConf { "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") .stringConf .createOptional + + val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners") + .doc("List of class names implementing QueryExecutionListener that will be automatically " + + "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4e756084bbdbb..2867b4cd7da5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -266,7 +266,8 @@ abstract class BaseSessionStateBuilder( * This gets cloned from parent if available, otherwise is a new instance is created. 
*/ protected def listenerManager: ExecutionListenerManager = { - parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + parentState.map(_.listenerManager.clone()).getOrElse( + new ExecutionListenerManager(session.sparkContext.conf)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index f6240d85fba6f..2b46233e1a5df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -22,9 +22,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal +import org.apache.spark.SparkConf import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -72,7 +75,14 @@ trait QueryExecutionListener { */ @Experimental @InterfaceStability.Evolving -class ExecutionListenerManager private[sql] () extends Logging { +class ExecutionListenerManager private extends Logging { + + private[sql] def this(conf: SparkConf) = { + this() + conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register) + } + } /** * Registers the specified [[QueryExecutionListener]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala new file mode 100644 index 0000000000000..4205e23ae240a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.util + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ + +class ExecutionListenerManagerSuite extends SparkFunSuite { + + import CountingQueryExecutionListener._ + + test("register query execution listeners using configuration") { + val conf = new SparkConf(false) + .set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName())) + + val mgr = new ExecutionListenerManager(conf) + assert(INSTANCE_COUNT.get() === 1) + mgr.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 1) + + val clone = mgr.clone() + assert(INSTANCE_COUNT.get() === 1) + + clone.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 2) + } + +} + +private class CountingQueryExecutionListener extends QueryExecutionListener { + + import CountingQueryExecutionListener._ + + INSTANCE_COUNT.incrementAndGet() + + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + +} + +private object CountingQueryExecutionListener { + + val CALLBACK_COUNT = new AtomicInteger() + val INSTANCE_COUNT = new AtomicInteger() + +} From 76fb173dd639baa9534486488155fc05a71f850e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 10 Oct 2017 20:29:02 -0700 Subject: [PATCH 045/712] [SPARK-21751][SQL] CodeGenerator.splitExpressions counts code size more precisely ## What changes were proposed in this pull request? Currently, `CodeGenerator.splitExpressions` splits statements into methods if the total length of statements is more than 1024 characters. The length may include comments or empty lines. This PR excludes comments and empty lines from the length to reduce the number of generated methods in a class, by using the `CodeFormatter.stripExtraNewLinesAndComments()` method. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #18966 from kiszk/SPARK-21751.
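To see what the new length metric ignores, here is a small standalone sketch (not part of the patch) that applies the same regular expressions as the `CodeFormatter.stripExtraNewLinesAndComments()` method added below to a made-up generated-code fragment:

```scala
object StripLengthDemo {
  // Same regular expressions as CodeFormatter.stripExtraNewLinesAndComments() below:
  // the first alternative drops /* block comments */, the second drops // line comments.
  private val commentReg =
    ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" +
      """([ |\t]*?\/\/[\s\S]*?\n)""").r

  // Length that splitExpressions would now count for a piece of generated code.
  def countedLength(code: String): Int = {
    val withoutComments = commentReg.replaceAllIn(code, "")
    withoutComments.replaceAll("""\n\s*\n""", "\n").length
  }

  def main(args: Array[String]): Unit = {
    // Made-up generated-code fragment, heavy on comments and blank lines.
    val snippet =
      """/* generated projection */
        |int value = in.getInt(0);
        |
        |// copy the value into the output row
        |out.setInt(0, value);
        |""".stripMargin

    println(s"raw length = ${snippet.length}")
    println(s"counted length = ${countedLength(snippet)}")
  }
}
```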
--- .../expressions/codegen/CodeFormatter.scala | 8 +++++ .../expressions/codegen/CodeGenerator.scala | 5 ++- .../codegen/CodeFormatterSuite.scala | 32 +++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 60e600d8dbd8f..7b398f424cead 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -89,6 +89,14 @@ object CodeFormatter { } new CodeAndComment(code.result().trim(), map) } + + def stripExtraNewLinesAndComments(input: String): String = { + val commentReg = + ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ + """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment + val codeWithoutComment = commentReg.replaceAllIn(input, "") + codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines + } } private class CodeFormatter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f9c5ef8439085..2cb66599076a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -772,16 +772,19 @@ class CodegenContext { foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() + var length = 0 for (code <- expressions) { // We can't know how many bytecode will be generated, so use the length of source code // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should // also not be too small, or it will have many function calls (for wide table), see the // results in BenchmarkWideTable. 
- if (blockBuilder.length > 1024) { + if (length > 1024) { blocks += blockBuilder.toString() blockBuilder.clear() + length = 0 } blockBuilder.append(code) + length += CodeFormatter.stripExtraNewLinesAndComments(code).length } blocks += blockBuilder.toString() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9d0a41661beaa..a0f1a64b0ab08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -53,6 +53,38 @@ class CodeFormatterSuite extends SparkFunSuite { assert(reducedCode.body === "/*project_c4*/") } + test("removing extra new lines and comments") { + val code = + """ + |/* + | * multi + | * line + | * comments + | */ + | + |public function() { + |/*comment*/ + | /*comment_with_space*/ + |code_body + |//comment + |code_body + | //comment_with_space + | + |code_body + |} + """.stripMargin + + val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code) + assert(reducedCode === + """ + |public function() { + |code_body + |code_body + |code_body + |} + """.stripMargin) + } + testCase("basic example") { """ |class A { From 655f6f86f84ff5241d1d20766e1ef83bb32ca5e0 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 11 Oct 2017 00:16:12 -0700 Subject: [PATCH 046/712] [SPARK-22208][SQL] Improve percentile_approx by not rounding up targetError and starting from index 0 ## What changes were proposed in this pull request? Currently percentile_approx never returns the first element when percentile is in (relativeError, 1/N], where relativeError default 1/10000, and N is the total number of elements. But ideally, percentiles in [0, 1/N] should all return the first element as the answer. For example, given input data 1 to 10, if a user queries 10% (or even less) percentile, it should return 1, because the first value 1 already reaches 10%. Currently it returns 2. Based on the paper, targetError is not rounded up, and searching index should start from 0 instead of 1. By following the paper, we should be able to fix the cases mentioned above. ## How was this patch tested? Added a new test case and fix existing test cases. Author: Zhenhua Wang Closes #19438 from wzhfy/improve_percentile_approx. 
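From the user-facing API, the change can be observed with `approxQuantile`, which is backed by the same `QuantileSummaries`. A minimal sketch, assuming a local session and mirroring the new SQL test added in this patch:

```scala
import org.apache.spark.sql.SparkSession

object SmallPercentileDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("small-percentile-demo")
      .getOrCreate()
    import spark.implicits._

    val df = (1 to 10).toDF("col")

    // With a zero relative error, percentages at or below 1/N should now map to
    // the first element: expected output is roughly 1.0, 1.0, 2.0.
    val quantiles = df.stat.approxQuantile("col", Array(0.01, 0.1, 0.11), 0.0)
    println(quantiles.mkString(", "))

    spark.stop()
  }
}
```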
--- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++++---- .../apache/spark/ml/feature/ImputerSuite.scala | 2 +- python/pyspark/sql/dataframe.py | 6 +++--- .../sql/catalyst/util/QuantileSummaries.scala | 4 ++-- .../catalyst/util/QuantileSummariesSuite.scala | 10 ++++++++-- .../sql/ApproximatePercentileQuerySuite.scala | 17 ++++++++++++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 36 insertions(+), 15 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bbea25bc4da5c..4382ef2ed4525 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2538,7 +2538,7 @@ test_that("describe() and summary() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[5, "summary"], "25%") - expect_equal(collect(stats2)[5, "age"], "30") + expect_equal(collect(stats2)[5, "age"], "19") stats3 <- summary(df, "min", "max", "55.1%") @@ -2738,7 +2738,7 @@ test_that("sampleBy() on a DataFrame", { }) test_that("approxQuantile() on a DataFrame", { - l <- lapply(c(0:99), function(i) { list(i, 99 - i) }) + l <- lapply(c(0:100), function(i) { list(i, 100 - i) }) df <- createDataFrame(l, list("a", "b")) quantiles <- approxQuantile(df, "a", c(0.5, 0.8), 0.0) expect_equal(quantiles, list(50, 80)) @@ -2749,8 +2749,8 @@ test_that("approxQuantile() on a DataFrame", { dfWithNA <- createDataFrame(data.frame(a = c(NA, 30, 19, 11, 28, 15), b = c(-30, -19, NA, -11, -28, -15))) quantiles3 <- approxQuantile(dfWithNA, c("a", "b"), c(0.5), 0.0) - expect_equal(quantiles3[[1]], list(28)) - expect_equal(quantiles3[[2]], list(-15)) + expect_equal(quantiles3[[1]], list(19)) + expect_equal(quantiles3[[2]], list(-19)) }) test_that("SQL error message is returned from JVM", { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index ee2ba73fa96d5..c08b35b419266 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -43,7 +43,7 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default (0, 1.0, 1.0, 1.0), (1, 3.0, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 3.0) + (3, -1.0, 2.0, 1.0) )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d596229ced7e..38b01f0011671 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1038,8 +1038,8 @@ def summary(self, *statistics): | mean| 3.5| null| | stddev|2.1213203435596424| null| | min| 2|Alice| - | 25%| 5| null| - | 50%| 5| null| + | 25%| 2| null| + | 50%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+------------------+-----+ @@ -1050,7 +1050,7 @@ def summary(self, *statistics): +-------+---+-----+ | count| 2| 2| | min| 2|Alice| - | 25%| 5| null| + | 25%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+---+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index af543b04ba780..eb7941cf9e6af 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -193,10 +193,10 @@ class QuantileSummaries( // Target rank val rank = math.ceil(quantile * count).toInt - val targetError = math.ceil(relativeError * count) + val targetError = relativeError * count // Minimum rank at current sample var minRank = 0 - var i = 1 + var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) minRank += curSample.g diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index df579d5ec1ddf..650813975d75c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -57,8 +57,14 @@ class QuantileSummariesSuite extends SparkFunSuite { private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { if (data.nonEmpty) { val approx = summary.query(quant).get - // The rank of the approximation. - val rank = data.count(_ < approx) // has to be <, not <= to be exact + // Get the rank of the approximation. + val rankOfValue = data.count(_ <= approx) + val rankOfPreValue = data.count(_ < approx) + // `rankOfValue` is the last position of the quantile value. If the input repeats the value + // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's + // improper to choose the last position as its rank. Instead, we get the rank by averaging + // `rankOfValue` and `rankOfPreValue`. + val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0) val lower = math.floor((quant - summary.relativeError) * data.size) val upper = math.ceil((quant + summary.relativeError) * data.size) val msg = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 1aea33766407f..137c5bea2abb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -53,6 +53,21 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { } } + test("percentile_approx, the first element satisfies small percentages") { + withTempView(table) { + (1 to 10).toDF("col").createOrReplaceTempView(table) + checkAnswer( + spark.sql( + s""" + |SELECT + | percentile_approx(col, array(0.01, 0.1, 0.11)) + |FROM $table + """.stripMargin), + Row(Seq(1, 1, 2)) + ) + } + } + test("percentile_approx, array of percentile value") { withTempView(table) { (1 to 1000).toDF("col").createOrReplaceTempView(table) @@ -130,7 +145,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { (1 to 1000).toDF("col").createOrReplaceTempView(table) checkAnswer( spark.sql(s"SELECT percentile_approx(col, array(0.25 + 0.25D), 200 + 800D) FROM $table"), - Row(Seq(500D)) + Row(Seq(499)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 247c30e2ee65b..46b21c3b64a2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -141,7 +141,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("approximate quantile") { val n = 1000 - val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + val df = Seq.tabulate(n + 1)(i => (i, 2.0 * i)).toDF("singles", "doubles") val q1 = 0.5 val q2 = 0.8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dd8f54b690f64..ad461fa6144b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -855,7 +855,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("mean", null, "33.0", "178.0"), Row("stddev", null, "19.148542155126762", "11.547005383792516"), Row("min", "Alice", "16", "164"), - Row("25%", null, "24", "176"), + Row("25%", null, "16", "164"), Row("50%", null, "24", "176"), Row("75%", null, "32", "180"), Row("max", "David", "60", "192")) From 645e108eeb6364e57f5d7213dbbd42dbcf1124d3 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Oct 2017 13:51:33 -0700 Subject: [PATCH 047/712] [SPARK-21988][SS] Implement StreamingRelation.computeStats to fix explain ## What changes were proposed in this pull request? Implement StreamingRelation.computeStats to fix explain ## How was this patch tested? - unit tests: `StreamingRelation.computeStats` and `StreamingExecutionRelation.computeStats`. - regression tests: `explain join with a normal source` and `explain join with MemoryStream`. Author: Shixiong Zhu Closes #19465 from zsxwing/SPARK-21988. --- .../streaming/StreamingRelation.scala | 8 +++ .../spark/sql/streaming/StreamSuite.scala | 65 ++++++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index ab716052c28ba..6b82c78ea653d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -44,6 +44,14 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName + + // There's no sensible value here. On the execution path, this relation will be + // swapped out with microbatches. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. 
+ override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) + ) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9c901062d570a..3d687d2214e90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -76,20 +76,65 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("StreamingRelation.computeStats") { + val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { + case s: StreamingRelation => s + } + assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert( + streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) + } - test("explain join") { - // Make a table and ensure it will be broadcast. - val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + test("StreamingExecutionRelation.computeStats") { + val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect { + case s: StreamingExecutionRelation => s + } + assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation") + assert(streamingExecutionRelation.head.computeStats.sizeInBytes + == spark.sessionState.conf.defaultSizeInBytes) + } - // Join the input stream with a table. - val inputData = MemoryStream[Int] - val joined = inputData.toDF().join(smallTable, smallTable("number") === $"value") + test("explain join with a normal source") { + // This test triggers CostBasedJoinReorder to call `computeStats` + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = spark.readStream.format("rate").load() + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) + } + } - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - joined.explain() + test("explain join with MemoryStream") { + // This test triggers CostBasedJoinReorder to call `computeStats` + // Because MemoryStream doesn't use DataSource code path, we need a separate test. + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. 
+ val df = MemoryStream[Int].toDF + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) } - assert(outputStream.toString.contains("StreamingRelation")) } test("SPARK-20432: union one stream with itself") { From ccdf21f56e4ff5497d7770dcbee2f7a60bb9e3a7 Mon Sep 17 00:00:00 2001 From: Jorge Machado Date: Wed, 11 Oct 2017 22:13:07 -0700 Subject: [PATCH 048/712] [SPARK-20055][DOCS] Added documentation for loading csv files into DataFrames ## What changes were proposed in this pull request? Added documentation for loading csv files into Dataframes ## How was this patch tested? /dev/run-tests Author: Jorge Machado Closes #19429 from jomach/master. --- docs/sql-programming-guide.md | 32 ++++++++++++++++--- .../sql/JavaSQLDataSourceExample.java | 7 ++++ examples/src/main/python/sql/datasource.py | 5 +++ examples/src/main/r/RSparkSQLExample.R | 6 ++++ examples/src/main/resources/people.csv | 3 ++ .../examples/sql/SQLDataSourceExample.scala | 8 +++++ 6 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 examples/src/main/resources/people.csv diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a095263bfa619..639a8ea7bb8ad 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -461,6 +461,8 @@ name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can al names (`json`, `parquet`, `jdbc`, `orc`, `libsvm`, `csv`, `text`). DataFrames loaded from any data source type can be converted into other types using this syntax. +To load a JSON file you can use: +
{% include_example manual_load_options scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -479,6 +481,26 @@ source type can be converted into other types using this syntax.
+To load a CSV file you can use: + +
+
+{% include_example manual_load_options_csv scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example manual_load_options_csv java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example manual_load_options_csv python/sql/datasource.py %} +
+ +
+{% include_example manual_load_options_csv r/RSparkSQLExample.R %} + +
+
### Run SQL on files directly Instead of using read API to load a file into DataFrame and query it, you can also query that @@ -573,7 +595,7 @@ Note that partition information is not gathered by default when creating externa ### Bucketing, Sorting and Partitioning -For file-based data source, it is also possible to bucket and sort or partition the output. +For file-based data source, it is also possible to bucket and sort or partition the output. Bucketing and sorting are applicable only to persistent tables:
@@ -598,7 +620,7 @@ CREATE TABLE users_bucketed_by_name( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet CLUSTERED BY(name) INTO 42 BUCKETS; {% endhighlight %} @@ -629,7 +651,7 @@ while partitioning can be used with both `save` and `saveAsTable` when using the {% highlight sql %} CREATE TABLE users_by_favorite_color( - name STRING, + name STRING, favorite_color STRING, favorite_numbers array ) USING csv PARTITIONED BY(favorite_color); @@ -664,7 +686,7 @@ CREATE TABLE users_bucketed_and_partitioned( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet PARTITIONED BY (favorite_color) CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS; @@ -675,7 +697,7 @@ CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS;
`partitionBy` creates a directory structure as described in the [Partition Discovery](#partition-discovery) section. -Thus, it has limited applicability to columns with high cardinality. In contrast +Thus, it has limited applicability to columns with high cardinality. In contrast `bucketBy` distributes data across a fixed number of buckets and can be used when a number of unique values is unbounded. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 95859c52c2aeb..ef3c904775697 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -116,6 +116,13 @@ private static void runBasicDataSourceExample(SparkSession spark) { spark.read().format("json").load("examples/src/main/resources/people.json"); peopleDF.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + Dataset peopleDFCsv = spark.read().format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv"); + // $example off:manual_load_options_csv$ // $example on:direct_sql$ Dataset sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index f86012ea382e8..b375fa775de39 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -53,6 +53,11 @@ def basic_datasource_example(spark): df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") # $example off:manual_load_options$ + # $example on:manual_load_options_csv$ + df = spark.read.load("examples/src/main/resources/people.csv", + format="csv", sep=":", inferSchema="true", header="true") + # $example off:manual_load_options_csv$ + # $example on:write_sorting_and_bucketing$ df.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed") # $example off:write_sorting_and_bucketing$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 3734568d872d0..a5ed723da47ca 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -113,6 +113,12 @@ write.df(namesAndAges, "namesAndAges.parquet", "parquet") # $example off:manual_load_options$ +# $example on:manual_load_options_csv$ +df <- read.df("examples/src/main/resources/people.csv", "csv") +namesAndAges <- select(df, "name", "age") +# $example off:manual_load_options_csv$ + + # $example on:direct_sql$ df <- sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") # $example off:direct_sql$ diff --git a/examples/src/main/resources/people.csv b/examples/src/main/resources/people.csv new file mode 100644 index 0000000000000..7fe5adba93d77 --- /dev/null +++ b/examples/src/main/resources/people.csv @@ -0,0 +1,3 @@ +name;age;job +Jorge;30;Developer +Bob;32;Developer diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 86b3dc4a84f58..f9477969a4bb5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -49,6 +49,14 @@ object SQLDataSourceExample { val peopleDF = spark.read.format("json").load("examples/src/main/resources/people.json") peopleDF.select("name", "age").write.format("parquet").save("namesAndAges.parquet") // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + val peopleDFCsv = spark.read.format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv") + // $example off:manual_load_options_csv$ + // $example on:direct_sql$ val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") // $example off:direct_sql$ From 274f0efefa0c063649bccddb787e8863910f4366 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Oct 2017 20:20:44 +0800 Subject: [PATCH 049/712] [SPARK-22252][SQL] FileFormatWriter should respect the input query schema ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/18064, we allowed `RunnableCommand` to have children in order to fix some UI issues. Then we made `InsertIntoXXX` commands take the input `query` as a child; when we do the actual writing, we just pass the physical plan to the writer (`FileFormatWriter.write`). However, this is problematic. In Spark SQL, the optimizer and planner are allowed to change the schema names a little bit. e.g. the `ColumnPruning` rule will remove no-op `Project`s, like `Project("A", Scan("a"))`, and thus change the output schema from `<A>` to `<a>`. When it comes to writing, especially for a self-describing data format like Parquet, we may write the wrong schema to the file and cause null values at the read path. Fortunately, in https://github.com/apache/spark/pull/18450, we decided to allow nested execution and one query can map to multiple executions in the UI. This releases the major restriction in #18604, and now we don't have to take the input `query` as a child of `InsertIntoXXX` commands. So the fix is simple: this PR partially reverts #18064 and makes `InsertIntoXXX` commands leaf nodes again. ## How was this patch tested? A new regression test. Author: Wenchen Fan Closes #19474 from cloud-fan/bug.
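To make the schema-name problem concrete, here is an illustrative sketch (not part of the patch); the column names are invented to mirror the `Project("A", Scan("a"))` shape described above, and the exact optimized-plan output may vary by version:

```scala
import org.apache.spark.sql.SparkSession

object SchemaCaseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("schema-case-sketch")
      .getOrCreate()
    import spark.implicits._

    // Lower-case columns selected with upper-case names builds the
    // Project("A", Scan("a")) shape from the description above.
    val df = Seq((1, "one"), (2, "two")).toDF("a", "b").select($"A", $"B")

    // The analyzed plan keeps the names the user wrote; after optimization the
    // no-op Project can be pruned away, so the reported names may revert to the
    // scan's casing. The writer therefore has to be handed the analyzed schema.
    println(df.queryExecution.analyzed.schema.simpleString)
    println(df.queryExecution.optimizedPlan.schema.simpleString)

    spark.stop()
  }
}
```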
--- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../sql/catalyst/plans/logical/Command.scala | 3 +- .../spark/sql/execution/QueryExecution.scala | 4 +-- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 3 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../command/DataWritingCommand.scala | 13 ++++++++ .../InsertIntoDataSourceDirCommand.scala | 6 ++-- .../spark/sql/execution/command/cache.scala | 2 +- .../sql/execution/command/commands.scala | 30 ++++++------------- .../command/createDataSourceTables.scala | 2 +- .../spark/sql/execution/command/views.scala | 4 +-- .../execution/datasources/DataSource.scala | 13 +++++++- .../datasources/FileFormatWriter.scala | 14 +++++---- .../InsertIntoDataSourceCommand.scala | 2 +- .../InsertIntoHadoopFsRelationCommand.scala | 9 ++---- .../SaveIntoDataSourceCommand.scala | 2 +- .../execution/streaming/FileStreamSink.scala | 2 +- .../datasources/FileFormatWriterSuite.scala | 16 +++++++++- .../execution/InsertIntoHiveDirCommand.scala | 10 ++----- .../hive/execution/InsertIntoHiveTable.scala | 14 ++------- .../sql/hive/execution/SaveAsHiveFile.scala | 6 ++-- 22 files changed, 85 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7addbaaa9afa5..c7952e3ff8280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -178,7 +178,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }) } - override def innerChildren: Seq[QueryPlan[_]] = subqueries + override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** * Returns a plan where a best effort attempt has been made to transform `this` in a way diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index ec5766e1f67f2..38f47081b6f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command extends LogicalPlan { +trait Command extends LeafNode { override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 4accf54a18232..f404621399cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -119,7 +119,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * `SparkSQLDriver` for CLI applications. */ def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand, _) => + case ExecutedCommandExec(desc: DescribeTableCommand) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. 
desc.run(sparkSession).map { @@ -130,7 +130,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended => + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4cdcc73faacd7..19b858faba6ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -364,7 +364,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil + case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index bc98d8d9d6d61..a1c62a729900e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -62,7 +62,8 @@ case class InMemoryRelation( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { - override def innerChildren: Seq[SparkPlan] = Seq(child) + + override protected def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index c7ddec55682e1..af3636a5a2ca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -34,7 +34,7 @@ case class InMemoryTableScanExec( @transient relation: InMemoryRelation) extends LeafExecNode { - override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 4e1c5e4846f36..2cf06982e25f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command 
import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.SerializableConfiguration @@ -30,6 +31,18 @@ import org.apache.spark.util.SerializableConfiguration */ trait DataWritingCommand extends RunnableCommand { + /** + * The input query plan that produces the data to be written. + */ + def query: LogicalPlan + + // We make the input `query` an inner child instead of a child in order to hide it from the + // optimizer. This is because optimizer may not preserve the output schema names' case, and we + // have to keep the original analyzed plan here so that we can pass the corrected schema to the + // writer. The schema of analyzed plan is what user expects(or specifies), so we should respect + // it when writing. + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override lazy val metrics: Map[String, SQLMetric] = { val sparkContext = SparkContext.getActive.get Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala index 633de4c37af94..9e3519073303c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ /** @@ -45,10 +44,9 @@ case class InsertIntoDataSourceDirCommand( query: LogicalPlan, overwrite: Boolean) extends RunnableCommand { - override def children: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty, "Directory path is required") assert(provider.nonEmpty, "Data source is required") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 792290bef0163..140f920eaafae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -30,7 +30,7 @@ case class CacheTableCommand( require(plan.isEmpty || tableIdent.database.isEmpty, "Database name is not allowed in CACHE TABLE AS SELECT") - override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq + override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 7cd4baef89e75..e28b5eb2e2a2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} @@ -37,19 +37,13 @@ import org.apache.spark.sql.types._ * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -trait RunnableCommand extends logical.Command { +trait RunnableCommand extends Command { // The map used to record the metrics of running the command. This will be passed to // `ExecutedCommand` during query planning. lazy val metrics: Map[String, SQLMetric] = Map.empty - def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - throw new NotImplementedError - } - - def run(sparkSession: SparkSession): Seq[Row] = { - throw new NotImplementedError - } + def run(sparkSession: SparkSession): Seq[Row] } /** @@ -57,9 +51,8 @@ trait RunnableCommand extends logical.Command { * saves the result to prevent multiple executions. * * @param cmd the `RunnableCommand` this operator will run. - * @param children the children physical plans ran by the `RunnableCommand`. 
*/ -case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode { override lazy val metrics: Map[String, SQLMetric] = cmd.metrics @@ -74,19 +67,14 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e */ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - val rows = if (children.isEmpty) { - cmd.run(sqlContext.sparkSession) - } else { - cmd.run(sqlContext.sparkSession, children) - } - rows.map(converter(_).asInstanceOf[InternalRow]) + cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) } - override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren + override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil override def output: Seq[Attribute] = cmd.output - override def nodeName: String = cmd.nodeName + override def nodeName: String = "Execute " + cmd.nodeName override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 04b2534ca5eb1..9e3907996995c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -120,7 +120,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override def innerChildren: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index ffdfd527fa701..5172f32ec7b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -98,7 +98,7 @@ case class CreateViewCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(child) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) if (viewType == PersistedView) { require(originalText.isDefined, "'originalText' must be provided to create permanent view") @@ -267,7 +267,7 @@ case class AlterViewAsCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(session: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b9502a95a7c08..b43d282bd434c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -453,6 +453,17 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) + + // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does + // not need to have the query as child, to avoid to analyze an optimized query, + // because InsertIntoHadoopFsRelationCommand will be optimized first. + val partitionAttributes = partitionColumns.map { name => + data.output.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location @@ -465,7 +476,7 @@ case class DataSource( outputPath = outputPath, staticPartitions = Map.empty, ifPartitionNotExists = false, - partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted), + partitionColumns = partitionAttributes, bucketSpec = bucketSpec, fileFormat = format, options = options, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 514969715091a..75b1695fbc275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -101,7 +101,7 @@ object FileFormatWriter extends Logging { */ def write( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, fileFormat: FileFormat, committer: FileCommitProtocol, outputSpec: OutputSpec, @@ -117,7 +117,9 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - val allColumns = plan.output + // Pick the attributes from analyzed plan, as optimizer may not preserve the output schema + // names' case. + val allColumns = queryExecution.analyzed.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = allColumns.filterNot(partitionSet.contains) @@ -158,7 +160,7 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. 
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter - val actualOrdering = plan.outputOrdering.map(_.child) + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { false } else { @@ -176,12 +178,12 @@ object FileFormatWriter extends Logging { try { val rdd = if (orderingMatched) { - plan.execute() + queryExecution.toRdd } else { SortExec( requiredOrdering.map(SortOrder(_, Ascending)), global = false, - child = plan).execute() + child = queryExecution.executedPlan).execute() } val ret = new Array[WriteTaskResult](rdd.partitions.length) sparkSession.sparkContext.runJob( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 08b2f4f31170f..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand( overwrite: Boolean) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 64e5a57adc37c..675bee85bf61e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.util.SchemaUtils @@ -57,11 +56,7 @@ case class InsertIntoHadoopFsRelationCommand( extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkSchemaColumnNameDuplication( query.schema, @@ -144,7 +139,7 @@ case class InsertIntoHadoopFsRelationCommand( val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 5eb6a8471be0d..96c84eab1c894 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -38,7 +38,7 @@ case class SaveIntoDataSourceCommand( options: Map[String, String], mode: SaveMode) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { dataSource.createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 72e5ac40bbfed..6bd0696622005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -121,7 +121,7 @@ class FileStreamSink( FileFormatWriter.write( sparkSession = sparkSession, - plan = data.queryExecution.executedPlan, + queryExecution = data.queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index a0c1ea63d3827..6f8767db176aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext class FileFormatWriterSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("empty file should be skipped while write to file") { withTempPath { path => @@ -30,4 +31,17 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext { assert(partFiles.length === 2) } } + + test("FileFormatWriter should respect the input query schema") { + withTable("t1", "t2", "t3", "t4") { + spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") + spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") + checkAnswer(spark.table("t2"), Row(0, 0)) + + // Test picking part of the columns when writing. 
+ spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3") + spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4") + checkAnswer(spark.table("t4"), Row(0, 0)) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 918c8be00d69d..1c6f8dd77fc2c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -27,11 +27,10 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred._ import org.apache.spark.SparkException -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.client.HiveClientImpl /** @@ -57,10 +56,7 @@ case class InsertIntoHiveDirCommand( query: LogicalPlan, overwrite: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty) val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( @@ -102,7 +98,7 @@ case class InsertIntoHiveDirCommand( try { saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpPath.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index e5b59ed7a1a6b..56e10bc457a00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,20 +17,16 @@ package org.apache.spark.sql.hive.execution -import scala.util.control.NonFatal - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils -import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveClientImpl @@ -72,16 +68,12 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifPartitionNotExists: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - /** * Inserts all the rows in the table into Hive. 
Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. */ - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { val externalCatalog = sparkSession.sharedState.externalCatalog val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -170,7 +162,7 @@ case class InsertIntoHiveTable( saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 2d74ef040ef5a..63657590e5e79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive.HiveExternalCatalog @@ -47,7 +47,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { protected def saveAsHiveFile( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, hadoopConf: Configuration, fileSinkConf: FileSinkDesc, outputLocation: String, @@ -75,7 +75,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.write( sparkSession = sparkSession, - plan = plan, + queryExecution = queryExecution, fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations), From b5c1ef7a8e4db4067bc361d10d554ee9a538423f Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Thu, 12 Oct 2017 20:26:51 +0800 Subject: [PATCH 050/712] [SPARK-22097][CORE] Request an accurate memory after we unrolled the block ## What changes were proposed in this pull request? We only need request `bbos.size - unrollMemoryUsedByThisBlock` after unrolled the block. ## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19316 from ConeyLiu/putIteratorAsBytes. --- .../org/apache/spark/storage/memory/MemoryStore.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 651e9c7b2ab61..17f7a69ad6ba1 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -388,7 +388,13 @@ private[spark] class MemoryStore( // perform one final call to attempt to allocate additional memory if necessary. 
if (keepUnrolling) { serializationStream.close() - reserveAdditionalMemoryIfNecessary() + if (bbos.size > unrollMemoryUsedByThisBlock) { + val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } } if (keepUnrolling) { From 73d80ec49713605d6a589e688020f0fc2d6feab2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Oct 2017 20:34:03 +0800 Subject: [PATCH 051/712] [SPARK-22197][SQL] push down operators to data source before planning ## What changes were proposed in this pull request? As we discussed in https://github.com/apache/spark/pull/19136#discussion_r137023744 , we should push down operators to data source before planning, so that data source can report statistics more accurate. This PR also includes some cleanup for the read path. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19424 from cloud-fan/follow. --- .../spark/sql/sources/v2/ReadSupport.java | 5 +- .../sql/sources/v2/ReadSupportWithSchema.java | 5 +- .../sql/sources/v2/reader/DataReader.java | 4 + .../sources/v2/reader/DataSourceV2Reader.java | 2 +- .../spark/sql/sources/v2/reader/ReadTask.java | 3 +- .../SupportsPushDownCatalystFilters.java | 8 + .../v2/reader/SupportsPushDownFilters.java | 8 + .../apache/spark/sql/DataFrameReader.scala | 5 +- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../v2/DataSourceReaderHolder.scala | 68 +++++++++ .../datasources/v2/DataSourceV2Relation.scala | 8 +- .../datasources/v2/DataSourceV2ScanExec.scala | 22 +-- .../datasources/v2/DataSourceV2Strategy.scala | 60 +------- .../v2/PushDownOperatorsToDataSource.scala | 140 ++++++++++++++++++ .../sources/v2/JavaAdvancedDataSourceV2.java | 5 + .../sql/sources/v2/DataSourceV2Suite.scala | 2 + 16 files changed, 262 insertions(+), 87 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index ab5254a688d5a..ee489ad0f608f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,9 +30,8 @@ public interface ReadSupport { /** * Creates a {@link DataSourceV2Reader} to scan the data from this data source. * - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. 
*/ DataSourceV2Reader createReader(DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index c13aeca2ef36f..74e81a2c84d68 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -39,9 +39,8 @@ public interface ReadSupportWithSchema { * physical schema of the underlying storage of this data source reader, e.g. * CSV files, JSON files, etc, while this reader may not read data with full * schema, as column pruning or other optimizations may happen. - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index cfafc1a576793..95e091569b614 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -24,6 +24,10 @@ /** * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data * for a RDD partition. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface DataReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index fb4d5c0d7ae41..5989a4ac8440b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -30,7 +30,7 @@ * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic should be delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 7885bfcdd49e4..01362df0978cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -27,7 +27,8 @@ * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. 
* * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. + * created on executors and do the actual reading. So {@link ReadTask} must be serializable and + * {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving public interface ReadTask extends Serializable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 19d706238ec8e..d6091774d75aa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -40,4 +40,12 @@ public interface SupportsPushDownCatalystFilters { * Pushes down filters, and returns unsupported filters. */ Expression[] pushCatalystFilters(Expression[] filters); + + /** + * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * It's possible that there is no filters in the query and + * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for + * this case. + */ + Expression[] pushedCatalystFilters(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d4b509e7080f2..d6f297c013375 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** @@ -35,4 +36,11 @@ public interface SupportsPushDownFilters { * Pushes down filters, and returns unsupported filters. */ Filter[] pushFilters(Filter[] filters); + + /** + * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} + * is never called, empty array should be returned for this case. 
+ */ + Filter[] pushedFilters(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 78b668c04fd5c..17966eecfc051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -184,7 +184,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val dataSource = cls.newInstance() val options = new DataSourceV2Options(extraOptions.asJava) val reader = (cls.newInstance(), userSpecifiedSchema) match { @@ -194,8 +193,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { case (ds: ReadSupport, None) => ds.createReader(options) - case (_: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $dataSource.") + case (ds: ReadSupportWithSchema, None) => + throw new AnalysisException(s"A schema needs to be specified when using $ds.") case (ds: ReadSupport, Some(schema)) => val reader = ds.createReader(options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310b..1c8e4050978dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -31,7 +32,8 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..6093df26630cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The full output of the data source reader, without column pruning. + */ + def fullOutput: Seq[AttributeReference] + + /** + * The held data source reader. + */ + def reader: DataSourceV2Reader + + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => + fullOutput.find(_.name == name).get + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3c9b598fd07c9..7eb99a645001a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode { + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7999c0ceb5749..addc12a3f0901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,20 +29,14 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.StructType +/** + * Physical plan node for scanning data from a data source. + */ case class DataSourceV2ScanExec( - fullOutput: Array[AttributeReference], - @transient reader: DataSourceV2Reader, - // TODO: these 3 parameters are only used to determine the equality of the scan node, however, - // the reader also have this information, and ideally we can just rely on the equality of the - // reader. The only concern is, the reader implementation is outside of Spark and we have no - // control. - readSchema: StructType, - @transient filters: ExpressionSet, - hashPartitionKeys: Seq[String]) extends LeafExecNode { - - def output: Seq[Attribute] = readSchema.map(_.name).map { name => - fullOutput.find(_.name == name).get - } + fullOutput: Seq[AttributeReference], + @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] override def references: AttributeSet = AttributeSet.empty @@ -74,7 +68,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) override def preferredLocations: Array[String] = rowReadTask.preferredLocations override def createReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema)) + new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b80f695b2a87f..f2cda002245e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -29,64 +29,8 @@ import org.apache.spark.sql.sources.v2.reader._ object DataSourceV2Strategy extends Strategy { // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, DataSourceV2Relation(output, reader)) => - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(filters.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. 
- val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => filters - } - - val attrMap = AttributeMap(output.zip(output)) - val projectSet = AttributeSet(projects.flatMap(_.references)) - val filterSet = AttributeSet(stayUpFilters.flatMap(_.references)) - - // Match original case of attributes. - // TODO: nested fields pruning - val requiredColumns = (projectSet ++ filterSet).toSeq.map(attrMap) - reader match { - case r: SupportsPushDownRequiredColumns => - r.pruneColumns(requiredColumns.toStructType) - case _ => - } - - val scan = DataSourceV2ScanExec( - output.toArray, - reader, - reader.readSchema(), - ExpressionSet(filters), - Nil) - - val filterCondition = stayUpFilters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - - val withProject = if (projects == withFilter.output) { - withFilter - } else { - ProjectExec(projects, withFilter) - } - - withProject :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala new file mode 100644 index 0000000000000..0c1708131ae46 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.v2.reader._ + +/** + * Pushes down various operators to the underlying data source for better performance. Operators are + * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you + * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the + * data source should execute FILTER before LIMIT. And required columns are calculated at the end, + * because when more operators are pushed down, we may need less columns at Spark side. 
+ */ +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may + // appear in many places for column pruning. + // TODO: Ideally column pruning should be implemented via a plan property that is propagated + // top-down, then we can simplify the logic here and only collect target operators. + val filterPushed = plan transformUp { + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + // Non-deterministic expressions are stateful and we must keep the input sequence unchanged + // to avoid changing the result. This means, we can't evaluate the filter conditions that + // are after the first non-deterministic condition ahead. Here we only try to push down + // deterministic conditions that are before the first non-deterministic condition. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val stayUpFilters: Seq[Expression] = reader match { + case r: SupportsPushDownCatalystFilters => + r.pushCatalystFilters(candidates.toArray) + + case r: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet + val unhandledPredicates = translatedMap.filter { case (_, f) => + unhandledFilters.contains(f) + }.keys + + nonConvertiblePredicates ++ unhandledPredicates + + case _ => candidates + } + + val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) + if (withFilter.output == fields) { + withFilter + } else { + Project(fields, withFilter) + } + } + + // TODO: add more push down rules. + + // TODO: nested fields pruning + def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.filter(requiredByParent.contains).flatMap(_.references) + pushDownRequiredColumns(child, required) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) + + case DataSourceV2Relation(fullOutput, reader) => reader match { + case r: SupportsPushDownRequiredColumns => + // Match original case of attributes. + val attrMap = AttributeMap(fullOutput.zip(fullOutput)) + val requiredColumns = requiredByParent.map(attrMap) + r.pruneColumns(requiredColumns.toStructType) + case _ => + } + + // TODO: there may be more operators can be used to calculate required columns, we can add + // more and more in the future. 
+ case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + } + } + + pushDownRequiredColumns(filterPushed, filterPushed.output) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + /** + * Finds a Filter node(with an optional Project child) above data source relation. + */ + object FilterAndProject { + // returns the project list, the filter condition and the data source relation. + def unapply(plan: LogicalPlan) + : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + + case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + + case Filter(condition, Project(fields, r: DataSourceV2Relation)) + if fields.forall(_.deterministic) => + val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) + val substituted = condition.transform { + case a: Attribute => attributeMap.getOrElse(a, a) + } + Some((fields, substituted, r)) + + case _ => None + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 7aacf0346d2fb..da2c13f70c52a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -54,6 +54,11 @@ public Filter[] pushFilters(Filter[] filters) { return new Filter[0]; } + @Override + public Filter[] pushedFilters() { + return filters; + } + @Override public List> createReadTasks() { List> res = new ArrayList<>(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 9ce93d7ae926c..f238e565dc2fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -129,6 +129,8 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { Array.empty } + override def pushedFilters(): Array[Filter] = filters + override def readSchema(): StructType = { requiredSchema } From 02218c4c73c32741390d9906b6190ef2124ce518 Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Thu, 12 Oct 2017 17:00:22 +0200 Subject: [PATCH 052/712] [SPARK-22251][SQL] Metric 'aggregate time' is incorrect when codegen is off ## What changes were proposed in this pull request? Adding the code for setting 'aggregate time' metric to non-codegen path in HashAggregateExec and to ObjectHashAggregateExces. ## How was this patch tested? Tested manually. Author: Ala Luszczak Closes #19473 from ala/fix-agg-time. 
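
For readers skimming the diff below: the fix wraps the per-partition aggregation setup in a simple wall-clock measurement and folds the elapsed milliseconds into the `aggTime` SQL metric, inside each operator's `mapPartitionsWithIndex` closure. The following is only an illustrative, self-contained sketch of that pattern, not the actual operator code; the `AtomicLong` stands in for Spark's `SQLMetric`, and `processPartition`/`buildResult` are hypothetical names introduced for the example.

```scala
import java.util.concurrent.atomic.AtomicLong

object AggTimeSketch {
  // Stand-in for the "aggTime" SQLMetric accumulated by the operator.
  val aggTimeMs = new AtomicLong(0L)

  // Time the construction of the per-partition result and record the elapsed milliseconds,
  // mirroring the pattern the patch adds to the non-codegen aggregation path.
  def processPartition[T](iter: Iterator[T])(buildResult: Iterator[T] => Iterator[T]): Iterator[T] = {
    val beforeAgg = System.nanoTime()
    val res = buildResult(iter)                                      // build the aggregation iterator
    aggTimeMs.addAndGet((System.nanoTime() - beforeAgg) / 1000000)   // nanoseconds -> milliseconds
    res
  }

  def main(args: Array[String]): Unit = {
    val out = processPartition((1 to 1000000).iterator)(it => Iterator.single(it.sum))
    println(s"result=${out.next()}, aggregate time ~ ${aggTimeMs.get()} ms")
  }
}
```

The actual change below applies this measurement around the construction of the aggregation iterators in `HashAggregateExec` and `ObjectHashAggregateExec`, so the metric is populated even when whole-stage codegen is disabled.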
--- .../sql/execution/aggregate/HashAggregateExec.scala | 6 +++++- .../execution/aggregate/ObjectHashAggregateExec.scala | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8b573fdcf25e1..43e5ff89afee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -95,11 +95,13 @@ case class HashAggregateExec( val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") val avgHashProbe = longMetric("avgHashProbe") + val aggTime = longMetric("aggTime") child.execute().mapPartitionsWithIndex { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. Iterator.empty @@ -128,6 +130,8 @@ case class HashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 6316e06a8f34e..ec3f9a05b5ccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -76,7 +76,8 @@ case class ObjectHashAggregateExec( aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows") + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time") ) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -96,11 +97,13 @@ case class ObjectHashAggregateExec( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") + val aggTime = longMetric("aggTime") val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input kvIterator is empty, // so return an empty kvIterator. Iterator.empty @@ -127,6 +130,8 @@ case class ObjectHashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } From 9104add4c7c6b578df15b64a8533a1266f90734e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 13 Oct 2017 08:40:26 +0900 Subject: [PATCH 053/712] [SPARK-22217][SQL] ParquetFileFormat to support arbitrary OutputCommitters ## What changes were proposed in this pull request? 
`ParquetFileFormat` to relax its requirement of output committer class from `org.apache.parquet.hadoop.ParquetOutputCommitter` or subclass thereof (and so implicitly Hadoop `FileOutputCommitter`) to any committer implementing `org.apache.hadoop.mapreduce.OutputCommitter` This enables output committers which don't write to the filesystem the way `FileOutputCommitter` does to save parquet data from a dataframe: at present you cannot do this. Before a committer which isn't a subclass of `ParquetOutputCommitter`, it checks to see if the context has requested summary metadata by setting `parquet.enable.summary-metadata`. If true, and the committer class isn't a parquet committer, it raises a RuntimeException with an error message. (It could downgrade, of course, but raising an exception makes it clear there won't be an summary. It also makes the behaviour testable.) Note that `SQLConf` already states that any `OutputCommitter` can be used, but that typically it's a subclass of ParquetOutputCommitter. That's not currently true. This patch will make the code consistent with the docs, adding tests to verify, ## How was this patch tested? The patch includes a test suite, `ParquetCommitterSuite`, with a new committer, `MarkingFileOutputCommitter` which extends `FileOutputCommitter` and writes a marker file in the destination directory. The presence of the marker file can be used to verify the new committer was used. The tests then try the combinations of Parquet committer summary/no-summary and marking committer summary/no-summary. | committer | summary | outcome | |-----------|---------|---------| | parquet | true | success | | parquet | false | success | | marking | false | success with marker | | marking | true | exception | All tests are happy. Author: Steve Loughran Closes #19448 from steveloughran/cloud/SPARK-22217-committer. --- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../parquet/ParquetFileFormat.scala | 12 +- .../parquet/ParquetCommitterSuite.scala | 152 ++++++++++++++++++ 3 files changed, 165 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 58323740b80cc..618d4a0d6148a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -306,8 +306,9 @@ object SQLConf { val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + - "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + - "of org.apache.parquet.hadoop.ParquetOutputCommitter.") + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. 
If it is not, then metadata summaries" + + "will never be created, irrespective of the value of parquet.enable.summary-metadata") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index e1e740500205a..c1535babbae1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -86,7 +86,7 @@ class ParquetFileFormat conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { logInfo("Using default output committer for Parquet: " + @@ -98,7 +98,7 @@ class ParquetFileFormat conf.setClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why @@ -138,6 +138,14 @@ class ParquetFileFormat conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) } + if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. " + + s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + } + new OutputWriterFactory { // This OutputWriterFactory instance is deserialized when writing Parquet files on the // executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala new file mode 100644 index 0000000000000..caa4f6d70c6a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.FileNotFoundException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} + +import org.apache.spark.{LocalSparkContext, SparkFunSuite} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +/** + * Test logic related to choice of output committers. + */ +class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils + with LocalSparkContext { + + private val PARQUET_COMMITTER = classOf[ParquetOutputCommitter].getCanonicalName + + protected var spark: SparkSession = _ + + /** + * Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + try { + if (spark != null) { + spark.stop() + spark = null + } + } finally { + super.afterAll() + } + } + + test("alternative output committer, merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = true, check = true) + } + + test("alternative output committer, no merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = false, check = true) + } + + test("Parquet output committer, merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = true, check = false) + } + + test("Parquet output committer, no merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = false, check = false) + } + + /** + * Write a trivial dataframe as Parquet, using the given committer + * and job summary option. + * @param committer committer to use + * @param summary create a job summary + * @param check look for a marker file + * @return if a marker file was sought, it's file status. + */ + private def writeDataFrame( + committer: String, + summary: Boolean, + check: Boolean): Option[FileStatus] = { + var result: Option[FileStatus] = None + withSQLConf( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, + ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + withTempPath { dest => + val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) + val destPath = new Path(dest.toURI) + df.write.format("parquet").save(destPath.toString) + if (check) { + result = Some(MarkingFileOutput.checkMarker( + destPath, + spark.sparkContext.hadoopConfiguration)) + } + } + } + result + } +} + +/** + * A file output committer which explicitly touches a file "marker"; this + * is how tests can verify that this committer was used. + * @param outputPath output path + * @param context task context + */ +private class MarkingFileOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + super.commitJob(context) + MarkingFileOutput.touch(outputPath, context.getConfiguration) + } +} + +private object MarkingFileOutput { + + val COMMITTER = classOf[MarkingFileOutputCommitter].getCanonicalName + + /** + * Touch the marker. 
+ * @param outputPath destination directory + * @param conf configuration to create the FS with + */ + def touch(outputPath: Path, conf: Configuration): Unit = { + outputPath.getFileSystem(conf).create(new Path(outputPath, "marker")).close() + } + + /** + * Get the file status of the marker + * + * @param outputPath destination directory + * @param conf configuration to create the FS with + * @return the status of the marker + * @throws FileNotFoundException if the marker is absent + */ + def checkMarker(outputPath: Path, conf: Configuration): FileStatus = { + outputPath.getFileSystem(conf).getFileStatus(new Path(outputPath, "marker")) + } +} From 3ff766f61afbd09dcc7a73eae02e68a39114ce3f Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 12 Oct 2017 18:47:16 -0700 Subject: [PATCH 054/712] [SPARK-22263][SQL] Refactor deterministic as lazy value ## What changes were proposed in this pull request? The method `deterministic` is frequently called in optimizer. Refactor `deterministic` as lazy value, in order to avoid redundant computations. ## How was this patch tested? Simple benchmark test over TPC-DS queries, run time from query string to optimized plan(continuous 20 runs, and get the average of last 5 results): Before changes: 12601 ms After changes: 11993ms This is 4.8% performance improvement. Also run test with Unit test. Author: Wang Gengliang Closes #19478 from gengliangwang/deterministicAsLazyVal. --- .../sql/catalyst/expressions/CallMethodViaReflection.scala | 2 +- .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++-- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/First.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/Last.scala | 2 +- .../spark/sql/catalyst/expressions/aggregate/collect.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 2 +- .../sql/execution/aggregate/TypedAggregateExpression.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 2 +- .../org/apache/spark/sql/TypedImperativeAggregateSuite.scala | 2 +- .../src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- 11 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index cd97304302e48..65bb9a8c642b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -76,7 +76,7 @@ case class CallMethodViaReflection(children: Seq[Expression]) } } - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = true override val dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c058425b4bc36..0e75ac88dc2b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -79,7 +79,7 @@ abstract class Expression extends TreeNode[Expression] { * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. 
* By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. */ - def deterministic: Boolean = children.forall(_.deterministic) + lazy val deterministic: Boolean = children.forall(_.deterministic) def nullable: Boolean @@ -265,7 +265,7 @@ trait NonSQLExpression extends Expression { * An expression that is nondeterministic. */ trait Nondeterministic extends Expression { - final override def deterministic: Boolean = false + final override lazy val deterministic: Boolean = false final override def foldable: Boolean = false @transient diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 527f1670c25e1..179853032035e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,7 +49,7 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { - override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index bfc58c22886cc..4e671e1f3e6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -44,7 +44,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true // First is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 96a6ec08a160a..0ccabb9d98914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -44,7 +44,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true // Last is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. 
override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 405c2065680f5..be972f006352e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -44,7 +44,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the // actual order of input rows. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def update(buffer: T, input: InternalRow): T = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ef293ff3f18ea..b86e271fe2958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -119,7 +119,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:on line.size.limit case class Uuid() extends LeafExpression { - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 717758fdf716f..aab8cc50b9526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -127,7 +127,7 @@ case class SimpleTypedAggregateExpression( nullable: Boolean) extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer @@ -221,7 +221,7 @@ case class ComplexTypedAggregateExpression( inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index fec1add18cbf2..72aa4adff4e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -340,7 +340,7 @@ case class ScalaUDAF( override def dataType: DataType = udaf.dataType - override def deterministic: Boolean = udaf.deterministic + override lazy val deterministic: Boolean = udaf.deterministic override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index b76f168220d84..c5fb17345222a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -268,7 +268,7 @@ object TypedImperativeAggregateSuite { } } - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = Seq(child) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index e9bdcf00b9346..68af99ea272a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -48,7 +48,7 @@ private[hive] case class HiveSimpleUDF( with Logging with UserDefinedExpression { - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def nullable: Boolean = true @@ -131,7 +131,7 @@ private[hive] case class HiveGenericUDF( override def nullable: Boolean = true - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] From ec122209fb35a65637df42eded64b0203e105aae Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Oct 2017 13:09:35 +0800 Subject: [PATCH 055/712] [SPARK-21165][SQL] FileFormatWriter should handle mismatched attribute ids between logical and physical plan ## What changes were proposed in this pull request? Due to optimizer removing some unnecessary aliases, the logical and physical plan may have different output attribute ids. FileFormatWriter should handle this when creating the physical sort node. ## How was this patch tested? new regression test. Author: Wenchen Fan Closes #19483 from cloud-fan/bug2. --- .../datasources/FileFormatWriter.scala | 7 +++++- .../datasources/FileFormatWriterSuite.scala | 2 +- .../apache/spark/sql/hive/InsertSuite.scala | 22 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 75b1695fbc275..1fac01a2c26c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -180,8 +180,13 @@ object FileFormatWriter extends Logging { val rdd = if (orderingMatched) { queryExecution.toRdd } else { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
+ val orderingExpr = requiredOrdering + .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns)) SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), + orderingExpr, global = false, child = queryExecution.executedPlan).execute() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index 6f8767db176aa..13f0e0bca86c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -32,7 +32,7 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext { } } - test("FileFormatWriter should respect the input query schema") { + test("SPARK-22252: FileFormatWriter should respect the input query schema") { withTable("t1", "t2", "t3", "t4") { spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index aa5cae33f5cd9..ab91727049ff5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -728,4 +728,26 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter assert(e.contains("mismatched input 'ROW'")) } } + + test("SPARK-21165: FileFormatWriter should only rely on attributes from analyzed plan") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + withTable("tab1", "tab2") { + Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1") + + spark.sql( + """ + |CREATE TABLE tab2 (word string, length int) + |PARTITIONED BY (first string) + """.stripMargin) + + spark.sql( + """ + |INSERT INTO TABLE tab2 PARTITION(first) + |SELECT word, length, cast(first as string) as first FROM tab1 + """.stripMargin) + + checkAnswer(spark.table("tab2"), Row("a", 3, "b")) + } + } + } } From 2f00a71a876321af02865d7cd53ada167e1ce2e3 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 12 Oct 2017 22:45:19 -0700 Subject: [PATCH 056/712] [SPARK-22257][SQL] Reserve all non-deterministic expressions in ExpressionSet ## What changes were proposed in this pull request? For non-deterministic expressions, they should be considered as not contained in the [[ExpressionSet]]. This is consistent with how we define `semanticEquals` between two expressions. Otherwise, combining expressions will remove non-deterministic expressions which should be reserved. E.g. Combine filters of ```scala testRelation.where(Rand(0) > 0.1).where(Rand(0) > 0.1) ``` should result in ```scala testRelation.where(Rand(0) > 0.1 && Rand(0) > 0.1) ``` ## How was this patch tested? Unit test Author: Wang Gengliang Closes #19475 from gengliangwang/non-deterministic-expressionSet. 
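One way to read the new contract: membership is only ever decided for deterministic expressions, while non-deterministic ones are always treated as distinct and simply appended. A small model of that policy over plain values — not the real `ExpressionSet` API; `Elem`, `key` and `isDeterministic` are illustrative stand-ins:

```scala
// Model of the policy: dedupe deterministic elements by a canonical key,
// always keep non-deterministic elements.
case class Elem(key: String, isDeterministic: Boolean)

final class DedupSet private (seen: Set[String], val elems: Vector[Elem]) {
  def this() = this(Set.empty, Vector.empty)
  def +(e: Elem): DedupSet =
    if (!e.isDeterministic) new DedupSet(seen, elems :+ e)   // always appended
    else if (seen.contains(e.key)) this                      // canonical duplicate dropped
    else new DedupSet(seen + e.key, elems :+ e)
  def size: Int = elems.size
}

val cond = Elem("a = 1", isDeterministic = true)
val rand = Elem("rand(0) > 0.1", isDeterministic = false)
val set = new DedupSet() + cond + rand + cond + rand
assert(set.size == 3)   // "a = 1" kept once, both copies of the rand predicate kept
```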
--- .../catalyst/expressions/ExpressionSet.scala | 23 ++++++--- .../expressions/ExpressionSetSuite.scala | 51 +++++++++++++++---- .../optimizer/FilterPushdownSuite.scala | 15 ++++++ 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index 305ac90e245b8..7e8e7b8cd5f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -30,8 +30,9 @@ object ExpressionSet { } /** - * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] - * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * A [[Set]] where membership is determined based on determinacy and a canonical representation of + * an [[Expression]] (i.e. one that attempts to ignore cosmetic differences). + * See [[Canonicalize]] for more details. * * Internally this set uses the canonical representation, but keeps also track of the original * expressions to ease debugging. Since different expressions can share the same canonical @@ -46,6 +47,10 @@ object ExpressionSet { * set.contains(1 + a) => true * set.contains(a + 2) => false * }}} + * + * For non-deterministic expressions, they are always considered as not contained in the [[Set]]. + * On adding a non-deterministic expression, simply append it to the original expressions. + * This is consistent with how we define `semanticEquals` between two expressions. */ class ExpressionSet protected( protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, @@ -53,7 +58,9 @@ class ExpressionSet protected( extends Set[Expression] { protected def add(e: Expression): Unit = { - if (!baseSet.contains(e.canonicalized)) { + if (!e.deterministic) { + originals += e + } else if (!baseSet.contains(e.canonicalized) ) { baseSet.add(e.canonicalized) originals += e } @@ -74,9 +81,13 @@ class ExpressionSet protected( } override def -(elem: Expression): ExpressionSet = { - val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) - val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) - new ExpressionSet(newBaseSet, newOriginals) + if (elem.deterministic) { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } else { + new ExpressionSet(baseSet.clone(), originals.clone()) + } } override def iterator: Iterator[Expression] = originals.iterator diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index a1000a0e80799..12eddf557109f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -175,20 +175,14 @@ class ExpressionSetSuite extends SparkFunSuite { aUpper > bUpper || aUpper <= Rand(1L) || aUpper <= 10, aUpper <= Rand(1L) || aUpper <= 10 || aUpper > bUpper) - // Partial reorder case: we don't reorder non-deterministic expressions, - // but we can reorder sub-expressions in deterministic 
AND/OR expressions. - // There are two predicates: - // (aUpper > bUpper || bUpper > 100) => we can reorder sub-expressions in it. - // (aUpper === Rand(1L)) - setTest(1, + // Keep all the non-deterministic expressions even they are semantically equal. + setTest(2, Rand(1L), Rand(1L)) + + setTest(2, (aUpper > bUpper || bUpper > 100) && aUpper === Rand(1L), (bUpper > 100 || aUpper > bUpper) && aUpper === Rand(1L)) - // There are three predicates: - // (Rand(1L) > aUpper) - // (aUpper <= Rand(1L) && aUpper > bUpper) - // (aUpper > 10 && bUpper > 10) => we can reorder sub-expressions in it. - setTest(1, + setTest(2, Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (bUpper > 10 && aUpper > 10)) @@ -219,4 +213,39 @@ class ExpressionSetSuite extends SparkFunSuite { assert((initialSet ++ setToAddWithSameExpression).size == 2) assert((initialSet ++ setToAddWithOutSameExpression).size == 3) } + + test("add single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet + (aUpper + 1)).size == 2) + assert((initialSet + Rand(0)).size == 3) + assert((initialSet + (aUpper + 2)).size == 3) + } + + test("remove single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet - (aUpper + 1)).size == 1) + assert((initialSet - Rand(0)).size == 2) + assert((initialSet - (aUpper + 2)).size == 2) + } + + test("add multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet ++ setToAddWithSameDeterministicExpression).size == 3) + assert((initialSet ++ setToAddWithOutSameExpression).size == 4) + } + + test("remove multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet -- setToRemoveWithSameDeterministicExpression).size == 1) + assert((initialSet -- setToRemoveWithOutSameExpression).size == 2) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 582b3ead5e54a..de0e7c7ee49ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -94,6 +94,21 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("combine redundant deterministic filters") { + val originalQuery = + testRelation + .where(Rand(0) > 0.1 && 'a === 1) + .where(Rand(0) > 0.1 && 'a === 1) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Rand(0) > 0.1 && 'a === 1 && Rand(0) > 0.1) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") { val originalQuery = 
testRelation From e6e36004afc3f9fc8abea98542248e9de11b4435 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 13 Oct 2017 23:09:12 +0800 Subject: [PATCH 057/712] [SPARK-14387][SPARK-16628][SPARK-18355][SQL] Use Spark schema to read ORC table instead of ORC file schema ## What changes were proposed in this pull request? Before Hive 2.0, ORC File schema has invalid column names like `_col1` and `_col2`. This is a well-known limitation and there are several Apache Spark issues with `spark.sql.hive.convertMetastoreOrc=true`. This PR ignores ORC File schema and use Spark schema. ## How was this patch tested? Pass the newly added test case. Author: Dongjoon Hyun Closes #19470 from dongjoon-hyun/SPARK-18355. --- .../spark/sql/hive/orc/OrcFileFormat.scala | 31 ++++++---- .../sql/hive/execution/SQLQuerySuite.scala | 62 ++++++++++++++++++- 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index c76f0ebb36a60..194e69c93e1a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -134,12 +134,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)) - if (maybePhysicalSchema.isEmpty) { + val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty + if (isEmptyFile) { Iterator.empty } else { - val physicalSchema = maybePhysicalSchema.get - OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) + OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -163,6 +162,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // Unwraps `OrcStruct`s to `UnsafeRow`s OrcRelation.unwrapOrcStructs( conf, + dataSchema, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), recordsIterator) @@ -272,25 +272,32 @@ private[orc] object OrcRelation extends HiveInspectors { def unwrapOrcStructs( conf: Configuration, dataSchema: StructType, + requiredSchema: StructType, maybeStructOI: Option[StructObjectInspector], iterator: Iterator[Writable]): Iterator[InternalRow] = { val deserializer = new OrcSerde - val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(dataSchema) + val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(requiredSchema) def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { - val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map { - case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal + val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map { + case (field, ordinal) => + var ref = oi.getStructFieldRef(field.name) + if (ref == null) { + ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name)) + } + ref -> ordinal }.unzip - val unwrappers = fieldRefs.map(unwrapperFor) + val unwrappers = 
fieldRefs.map(r => if (r == null) null else unwrapperFor(r)) iterator.map { value => val raw = deserializer.deserialize(value) var i = 0 val length = fieldRefs.length while (i < length) { - val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) + val fieldRef = fieldRefs(i) + val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) } else { @@ -306,8 +313,8 @@ private[orc] object OrcRelation extends HiveInspectors { } def setRequiredColumns( - conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) + conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { + val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 09c59000b3e3f..94fa43dec7313 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -2050,4 +2050,64 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + Seq("orc", "parquet").foreach { format => + test(s"SPARK-18355 Read data from a hive table with a new column - $format") { + val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + + Seq("true", "false").foreach { value => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> value, + HiveUtils.CONVERT_METASTORE_PARQUET.key -> value) { + withTempDatabase { db => + client.runSqlHive( + s""" + |CREATE TABLE $db.t( + | click_id string, + | search_id string, + | uid bigint) + |PARTITIONED BY ( + | ts string, + | hour string) + |STORED AS $format + """.stripMargin) + + client.runSqlHive( + s""" + |INSERT INTO TABLE $db.t + |PARTITION (ts = '98765', hour = '01') + |VALUES (12, 2, 12345) + """.stripMargin + ) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, ts, hour FROM $db.t"), + Row("12", "2", 12345, "98765", "01")) + + client.runSqlHive(s"ALTER TABLE $db.t ADD COLUMNS (dummy string)") + + checkAnswer( + sql(s"SELECT click_id, search_id FROM $db.t"), + Row("12", "2")) + + checkAnswer( + sql(s"SELECT search_id, click_id FROM $db.t"), + Row("2", "12")) + + checkAnswer( + sql(s"SELECT search_id FROM $db.t"), + Row("2")) + + checkAnswer( + sql(s"SELECT dummy, click_id FROM $db.t"), + Row(null, "12")) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, dummy, ts, hour FROM $db.t"), + Row("12", "2", 12345, null, "98765", "01")) + } + } + } + } + } } From 6412ea1759d39a2380c572ec24cfd8ae4f2d81f7 Mon Sep 17 00:00:00 2001 From: 
Dongjoon Hyun Date: Sat, 14 Oct 2017 00:35:12 +0800 Subject: [PATCH 058/712] [SPARK-21247][SQL] Type comparison should respect case-sensitive SQL conf ## What changes were proposed in this pull request? This is an effort to reduce the difference between Hive and Spark. Spark supports case-sensitivity in columns. Especially, for Struct types, with `spark.sql.caseSensitive=true`, the following is supported. ```scala scala> sql("select named_struct('a', 1, 'A', 2).a").show +--------------------------+ |named_struct(a, 1, A, 2).a| +--------------------------+ | 1| +--------------------------+ scala> sql("select named_struct('a', 1, 'A', 2).A").show +--------------------------+ |named_struct(a, 1, A, 2).A| +--------------------------+ | 2| +--------------------------+ ``` And vice versa, with `spark.sql.caseSensitive=false`, the following is supported. ```scala scala> sql("select named_struct('a', 1).A, named_struct('A', 1).a").show +--------------------+--------------------+ |named_struct(a, 1).A|named_struct(A, 1).a| +--------------------+--------------------+ | 1| 1| +--------------------+--------------------+ ``` However, types are considered different. For example, SET operations fail. ```scala scala> sql("SELECT named_struct('a',1) union all (select named_struct('A',2))").show org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. struct <> struct at the first column of the second table;; 'Union :- Project [named_struct(a, 1) AS named_struct(a, 1)#57] : +- OneRowRelation$ +- Project [named_struct(A, 2) AS named_struct(A, 2)#58] +- OneRowRelation$ ``` This PR aims to support case-insensitive type equality. For example, in Set operation, the above operation succeed when `spark.sql.caseSensitive=false`. ```scala scala> sql("SELECT named_struct('a',1) union all (select named_struct('A',2))").show +------------------+ |named_struct(a, 1)| +------------------+ | [1]| | [2]| +------------------+ ``` ## How was this patch tested? Pass the Jenkins with a newly add test case. Author: Dongjoon Hyun Closes #18460 from dongjoon-hyun/SPARK-21247. --- .../sql/catalyst/analysis/TypeCoercion.scala | 10 ++++ .../org/apache/spark/sql/types/DataType.scala | 7 ++- .../catalyst/analysis/TypeCoercionSuite.scala | 52 +++++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 38 ++++++++++++++ 4 files changed, 102 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 9ffe646b5e4ec..532d22dbf2321 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -100,6 +100,16 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. 
+ val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) + case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 30745c6a9d42a..d6e0df12218ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** @@ -80,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index d62e3b6dfe34f..793e04f66f0f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -131,14 +131,17 @@ class TypeCoercionSuite extends AnalysisTest { widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. 
- found = widenFunc(t2, t1) - assert(found == expected, - s"Expected $expected as wider common type for $t2 and $t1, found $found") + if (isSymmetric) { + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } } test("implicit type cast - ByteType") { @@ -385,6 +388,47 @@ class TypeCoercionSuite extends AnalysisTest { widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) + + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("b", IntegerType))), + None) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", DoubleType, nullable = false))), + None) + + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = false))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("A", IntegerType))), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkWidenType( + TypeCoercion.findTightestCommonType, + StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType))), + StructType(Seq(StructField("A", IntegerType), StructField("b", IntegerType))), + Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), + isSymmetric = false) + } } test("wider common type for decimal and array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 93a7777b70b46..f0c58e2e5bf45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2646,6 +2646,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-21247: Allow case-insensitive type equality in Set operation") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + + withTable("t", "S") { + sql("CREATE TABLE t(c struct) USING parquet") + sql("CREATE TABLE S(C struct) USING parquet") + Seq(("c", "C"), ("C", "c"), ("c.f", "C.F"), ("C.F", "c.f")).foreach { + case (left, right) => + checkAnswer(sql(s"SELECT * FROM t, S WHERE t.$left = S.$right"), Seq.empty) + } + } + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val m1 = intercept[AnalysisException] { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + }.message + assert(m1.contains("Union can 
only be performed on tables with the compatible column types")) + + val m2 = intercept[AnalysisException] { + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + }.message + assert(m2.contains("Except can only be performed on tables with the compatible column types")) + + withTable("t", "S") { + sql("CREATE TABLE t(c struct) USING parquet") + sql("CREATE TABLE S(C struct) USING parquet") + checkAnswer(sql("SELECT * FROM t, S WHERE t.c.f = S.C.F"), Seq.empty) + val m = intercept[AnalysisException] { + sql("SELECT * FROM t, S WHERE c = C") + }.message + assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) + } + } + } + test("SPARK-21335: support un-aliased subquery") { withTempView("v") { Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("v") From 3823dc88d3816c7d1099f9601426108acc90574c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Oct 2017 10:49:48 -0700 Subject: [PATCH 059/712] [SPARK-22252][SQL][FOLLOWUP] Command should not be a LeafNode ## What changes were proposed in this pull request? This is a minor folllowup of #19474 . #19474 partially reverted #18064 but accidentally introduced a behavior change. `Command` extended `LogicalPlan` before #18064 , but #19474 made it extend `LeafNode`. This is an internal behavior change as now all `Command` subclasses can't define children, and they have to implement `computeStatistic` method. This PR fixes this by making `Command` extend `LogicalPlan` ## How was this patch tested? N/A Author: Wenchen Fan Closes #19493 from cloud-fan/minor. --- .../org/apache/spark/sql/catalyst/plans/logical/Command.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 38f47081b6f55..ec5766e1f67f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command extends LeafNode { +trait Command extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty } From 1bb8b76045420b61a37806d6f4765af15c4052a7 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Fri, 13 Oct 2017 15:13:06 -0700 Subject: [PATCH 060/712] [MINOR][SS] keyWithIndexToNumValues" -> "keyWithIndexToValue" ## What changes were proposed in this pull request? This PR changes `keyWithIndexToNumValues` to `keyWithIndexToValue`. There will be directories on HDFS named with this `keyWithIndexToNumValues`. So if we ever want to fix this, let's fix it now. ## How was this patch tested? existing unit test cases. Author: Liwei Lin Closes #19435 from lw-lin/keyWithIndex. 
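The urgency comes from this string being baked into state store names, which in turn appear in checkpoint directories on HDFS. A rough illustration of why the label should be corrected before it is persisted anywhere (the exact checkpoint path layout shown is an assumption, not lifted from the code):

```scala
// Illustrative only: the store name combines the join side and the store type,
// mirroring getStateStoreName below, and then shows up in on-disk state paths.
def stateStoreName(joinSide: String, storeType: String): String = s"$joinSide-$storeType"

val name = stateStoreName("left", "keyWithIndexToValue")
// e.g. <checkpointDir>/state/<operatorId>/<partitionId>/left-keyWithIndexToValue/  (assumed layout)
```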
--- .../streaming/state/SymmetricHashJoinStateManager.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index d256fb578d921..6b386308c79fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -384,7 +384,7 @@ class SymmetricHashJoinStateManager( } /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */ - private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValuesType) { + private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValueType) { private val keyWithIndexExprs = keyAttributes :+ Literal(1L) private val keyWithIndexSchema = keySchema.add("index", LongType) private val indexOrdinalInKeyWithIndexRow = keyAttributes.size @@ -471,7 +471,7 @@ class SymmetricHashJoinStateManager( object SymmetricHashJoinStateManager { def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValuesType) + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { getStateStoreName(joinSide, stateStoreType) } @@ -483,8 +483,8 @@ object SymmetricHashJoinStateManager { override def toString(): String = "keyToNumValues" } - private case object KeyWithIndexToValuesType extends StateStoreType { - override def toString(): String = "keyWithIndexToNumValues" + private case object KeyWithIndexToValueType extends StateStoreType { + override def toString(): String = "keyWithIndexToValue" } private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { From 06df34d35ec088277445ef09cfb24bfe996f072e Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Fri, 13 Oct 2017 17:12:50 -0700 Subject: [PATCH 061/712] [SPARK-11034][LAUNCHER][MESOS] Launcher: add support for monitoring Mesos apps ## What changes were proposed in this pull request? Added Launcher support for monitoring Mesos apps in Client mode. SPARK-11033 can handle the support for Mesos/Cluster mode since the Standalone/Cluster and Mesos/Cluster modes use the same code at client side. ## How was this patch tested? I verified it manually by running launcher application, able to launch, stop and kill the mesos applications and also can invoke other launcher API's. Author: Devaraj K Closes #19385 from devaraj-kavali/SPARK-11034. 
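For reference, the client-side flow this enables is the usual launcher monitoring loop over a `SparkAppHandle`. A sketch of such a monitoring application is below; the Mesos master URL, jar path and main class are placeholders:

```scala
import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

object MonitorExample {
  def main(args: Array[String]): Unit = {
    val handle: SparkAppHandle = new SparkLauncher()
      .setMaster("mesos://mesos-master:5050")     // placeholder master URL
      .setAppResource("/path/to/app.jar")         // placeholder application jar
      .setMainClass("com.example.Main")           // placeholder main class
      .setDeployMode("client")
      .startApplication(new SparkAppHandle.Listener {
        override def stateChanged(h: SparkAppHandle): Unit =
          println(s"state=${h.getState} appId=${h.getAppId}")
        override def infoChanged(h: SparkAppHandle): Unit = ()
      })

    // The handle can also stop() or kill() the application, which is what this
    // patch wires up on the Mesos coarse-grained scheduler backend side.
    while (!handle.getState.isFinal) Thread.sleep(1000)
  }
}
```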
--- .../MesosCoarseGrainedSchedulerBackend.scala | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 80c0a041b7322..603c980cb268d 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -32,6 +32,7 @@ import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskStat import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointAddress @@ -89,6 +90,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false + private val launcherBackend = new LauncherBackend() { + override protected def onStopRequest(): Unit = { + stopSchedulerBackend() + setState(SparkAppHandle.State.KILLED) + } + } + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -182,6 +190,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def start() { super.start() + if (sc.deployMode == "client") { + launcherBackend.connect() + } val startedBefore = IdHelper.startedBefore.getAndSet(true) val suffix = if (startedBefore) { @@ -202,6 +213,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.conf.getOption("spark.mesos.driver.frameworkId").map(_ + suffix) ) + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) startScheduler(driver) } @@ -295,15 +307,21 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( this.mesosExternalShuffleClient.foreach(_.init(appId)) this.schedulerDriver = driver markRegistered() + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get >= maxCoresOption.getOrElse(0) * minRegisteredRatio } - override def disconnected(d: org.apache.mesos.SchedulerDriver) {} + override def disconnected(d: org.apache.mesos.SchedulerDriver) { + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) + } - override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) {} + override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) { + launcherBackend.setState(SparkAppHandle.State.RUNNING) + } /** * Method called by Mesos to offer resources on slaves. 
We respond by launching an executor, @@ -611,6 +629,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def stop() { + stopSchedulerBackend() + launcherBackend.setState(SparkAppHandle.State.FINISHED) + launcherBackend.close() + } + + private def stopSchedulerBackend() { // Make sure we're not launching tasks during shutdown stateLock.synchronized { if (stopCalled) { From e3536406ec6ff65a8b41ba2f2fd40517a760cfd6 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 13 Oct 2017 23:08:17 -0700 Subject: [PATCH 062/712] [SPARK-21762][SQL] FileFormatWriter/BasicWriteTaskStatsTracker metrics collection fails if a new file isn't yet visible ## What changes were proposed in this pull request? `BasicWriteTaskStatsTracker.getFileSize()` to catch `FileNotFoundException`, log info and then return 0 as a file size. This ensures that if a newly created file isn't visible due to the store not always having create consistency, the metric collection doesn't cause the failure. ## How was this patch tested? New test suite included, `BasicWriteTaskStatsTrackerSuite`. This not only checks the resilience to missing files, but verifies the existing logic as to how file statistics are gathered. Note that in the current implementation 1. if you call `Tracker..getFinalStats()` more than once, the file size count will increase by size of the last file. This could be fixed by clearing the filename field inside `getFinalStats()` itself. 2. If you pass in an empty or null string to `Tracker.newFile(path)` then IllegalArgumentException is raised, but only in `getFinalStats()`, rather than in `newFile`. There's a test for this behaviour in the new suite, as it verifies that only FNFEs get swallowed. Author: Steve Loughran Closes #18979 from steveloughran/cloud/SPARK-21762-missing-files-in-metrics. 
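The core pattern, pulled out of the tracker below as a standalone sketch against the Hadoop `FileSystem` API: a missing file yields `None` (plus a debug log) instead of failing the task.

```scala
import java.io.FileNotFoundException

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Size of a just-written file, or None if the store has not made it visible yet.
def fileSizeIfVisible(filePath: String, hadoopConf: Configuration): Option[Long] = {
  val path = new Path(filePath)
  val fs = path.getFileSystem(hadoopConf)
  try {
    Some(fs.getFileStatus(path).getLen())
  } catch {
    case _: FileNotFoundException =>
      // May arise against eventually consistent object stores; skip the metric.
      None
  }
}
```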
--- .../datasources/BasicWriteStatsTracker.scala | 49 +++- .../BasicWriteTaskStatsTrackerSuite.scala | 220 ++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 8 + 3 files changed, 265 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index b8f7d130d569f..11af0aaa7b206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.datasources +import java.io.FileNotFoundException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -44,20 +47,32 @@ case class BasicWriteTaskStats( * @param hadoopConf */ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) - extends WriteTaskStatsTracker { + extends WriteTaskStatsTracker with Logging { private[this] var numPartitions: Int = 0 private[this] var numFiles: Int = 0 + private[this] var submittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L - private[this] var curFile: String = null - + private[this] var curFile: Option[String] = None - private def getFileSize(filePath: String): Long = { + /** + * Get the size of the file expected to have been written by a worker. + * @param filePath path to the file + * @return the file size or None if the file was not found. + */ + private def getFileSize(filePath: String): Option[Long] = { val path = new Path(filePath) val fs = path.getFileSystem(hadoopConf) - fs.getFileStatus(path).getLen() + try { + Some(fs.getFileStatus(path).getLen()) + } catch { + case e: FileNotFoundException => + // may arise against eventually consistent object stores + logDebug(s"File $path is not yet visible", e) + None + } } @@ -70,12 +85,19 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def newFile(filePath: String): Unit = { - if (numFiles > 0) { - // we assume here that we've finished writing to disk the previous file by now - numBytes += getFileSize(curFile) + statCurrentFile() + curFile = Some(filePath) + submittedFiles += 1 + } + + private def statCurrentFile(): Unit = { + curFile.foreach { path => + getFileSize(path).foreach { len => + numBytes += len + numFiles += 1 + } + curFile = None } - curFile = filePath - numFiles += 1 } override def newRow(row: InternalRow): Unit = { @@ -83,8 +105,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - if (numFiles > 0) { - numBytes += getFileSize(curFile) + statCurrentFile() + if (submittedFiles != numFiles) { + logInfo(s"Expected $submittedFiles files, but only saw $numFiles. 
" + + "This could be due to the output format not writing empty files, " + + "or files being not immediately visible in the filesystem.") } BasicWriteTaskStats(numPartitions, numFiles, numBytes, numRows) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala new file mode 100644 index 0000000000000..bf3c8ede9a980 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.nio.charset.Charset + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils + +/** + * Test how BasicWriteTaskStatsTracker handles files. + * + * Two different datasets are written (alongside 0), one of + * length 10, one of 3. They were chosen to be distinct enough + * that it is straightforward to determine which file lengths were added + * from the sum of all files added. Lengths like "10" and "5" would + * be less informative. + */ +class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { + + private val tempDir = Utils.createTempDir() + private val tempDirPath = new Path(tempDir.toURI) + private val conf = new Configuration() + private val localfs = tempDirPath.getFileSystem(conf) + private val data1 = "0123456789".getBytes(Charset.forName("US-ASCII")) + private val data2 = "012".getBytes(Charset.forName("US-ASCII")) + private val len1 = data1.length + private val len2 = data2.length + + /** + * In teardown delete the temp dir. + */ + protected override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + } + + /** + * Assert that the stats match that expected. 
+ * @param tracker tracker to check + * @param files number of files expected + * @param bytes total number of bytes expected + */ + private def assertStats( + tracker: BasicWriteTaskStatsTracker, + files: Int, + bytes: Int): Unit = { + val stats = finalStatus(tracker) + assert(files === stats.numFiles, "Wrong number of files") + assert(bytes === stats.numBytes, "Wrong byte count of file size") + } + + private def finalStatus(tracker: BasicWriteTaskStatsTracker): BasicWriteTaskStats = { + tracker.getFinalStats().asInstanceOf[BasicWriteTaskStats] + } + + test("No files in run") { + val tracker = new BasicWriteTaskStatsTracker(conf) + assertStats(tracker, 0, 0) + } + + test("Missing File") { + val missing = new Path(tempDirPath, "missing") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(missing.toString) + assertStats(tracker, 0, 0) + } + + test("Empty filename is forwarded") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile("") + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("Null filename is only picked up in final status") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(null) + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("0 byte file") { + val file = new Path(tempDirPath, "file0") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + touch(file) + assertStats(tracker, 1, 0) + } + + test("File with data") { + val file = new Path(tempDirPath, "file-with-data") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + write1(file) + assertStats(tracker, 1, len1) + } + + test("Open file") { + val file = new Path(tempDirPath, "file-open") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + val stream = localfs.create(file, true) + try { + assertStats(tracker, 1, 0) + stream.write(data1) + stream.flush() + assert(1 === finalStatus(tracker).numFiles, "Wrong number of files") + } finally { + stream.close() + } + } + + test("Two files") { + val file1 = new Path(tempDirPath, "f-2-1") + val file2 = new Path(tempDirPath, "f-2-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + assertStats(tracker, 2, len1 + len2) + } + + test("Three files, last one empty") { + val file1 = new Path(tempDirPath, "f-3-1") + val file2 = new Path(tempDirPath, "f-3-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + tracker.newFile(file3.toString) + touch(file3) + assertStats(tracker, 3, len1 + len2) + } + + test("Three files, one not found") { + val file1 = new Path(tempDirPath, "f-4-1") + val file2 = new Path(tempDirPath, "f-4-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + // file 1 + tracker.newFile(file1.toString) + write1(file1) + + // file 2 is noted, but not created + tracker.newFile(file2.toString) + + // file 3 is noted & then created + tracker.newFile(file3.toString) + write2(file3) + + // the expected size is file1 + file3; only two files are reported + // as found + assertStats(tracker, 2, len1 + len2) + } + + /** + * Write a 0-byte file. 
+ * @param file file path + */ + private def touch(file: Path): Unit = { + localfs.create(file, true).close() + } + + /** + * Write a byte array. + * @param file path to file + * @param data data + * @return bytes written + */ + private def write(file: Path, data: Array[Byte]): Integer = { + val stream = localfs.create(file, true) + try { + stream.write(data) + } finally { + stream.close() + } + data.length + } + + /** + * Write a data1 array. + * @param file file + */ + private def write1(file: Path): Unit = { + write(file, data1) + } + + /** + * Write a data2 array. + * + * @param file file + */ + private def write2(file: Path): Unit = { + write(file, data2) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 94fa43dec7313..60935c3e85c43 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2110,4 +2110,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempDir { dir => + Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") + } + } + } } From e0503a7223410289d01bc4b20da3a451730577da Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 13 Oct 2017 23:24:36 -0700 Subject: [PATCH 063/712] [SPARK-22273][SQL] Fix key/value schema field names in HashMapGenerators. ## What changes were proposed in this pull request? When fixing schema field names using escape characters with `addReferenceMinorObj()` at [SPARK-18952](https://issues.apache.org/jira/browse/SPARK-18952) (#16361), double-quotes around the names were remained and the names become something like `"((java.lang.String) references[1])"`. ```java /* 055 */ private int maxSteps = 2; /* 056 */ private int numRows = 0; /* 057 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add("((java.lang.String) references[1])", org.apache.spark.sql.types.DataTypes.StringType); /* 058 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add("((java.lang.String) references[2])", org.apache.spark.sql.types.DataTypes.LongType); /* 059 */ private Object emptyVBase; ``` We should remove the double-quotes to refer the values in `references` properly: ```java /* 055 */ private int maxSteps = 2; /* 056 */ private int numRows = 0; /* 057 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1]), org.apache.spark.sql.types.DataTypes.StringType); /* 058 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2]), org.apache.spark.sql.types.DataTypes.LongType); /* 059 */ private Object emptyVBase; ``` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #19491 from ueshin/issues/SPARK-22273. 
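A tiny illustration of the quoting problem, outside of Spark's codegen: `keyName` holds a Java *expression* produced for the field name, so splicing it inside double quotes turns the expression itself into the column name, while splicing it bare lets the generated code evaluate it.

```scala
// Hypothetical value in the shape addReferenceMinorObj() produces for a field name.
val keyName = "((java.lang.String) references[1])"

// Before the fix: the expression is embedded as a string literal in the generated code.
val before = s""".add("$keyName", org.apache.spark.sql.types.DataTypes.StringType)"""

// After the fix: the expression is spliced bare and resolved at runtime.
val after = s""".add($keyName, org.apache.spark.sql.types.DataTypes.StringType)"""

println(before) // .add("((java.lang.String) references[1])", ...)
println(after)  // .add(((java.lang.String) references[1]), ...)
```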
--- .../execution/aggregate/RowBasedHashMapGenerator.scala | 8 ++++---- .../execution/aggregate/VectorizedHashMapGenerator.scala | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 9316ebcdf105c..3718424931b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -50,10 +50,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -63,10 +63,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 13f79275cac41..812d405d5ebfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -55,10 +55,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -68,10 +68,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") From 014dc8471200518d63005eed531777d30d8a6639 Mon Sep 17 00:00:00 2001 From: liulijia Date: Sat, 14 Oct 2017 17:37:33 +0900 Subject: [PATCH 064/712] [SPARK-22233][CORE] Allow user to filter out empty split in HadoopRDD ## What changes were proposed in this pull 
request? Add a flag spark.files.ignoreEmptySplits. When true, methods like that use HadoopRDD and NewHadoopRDD such as SparkContext.textFiles will not create a partition for input splits that are empty. Author: liulijia Closes #19464 from liutang123/SPARK-22233. --- .../spark/internal/config/package.scala | 6 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 12 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 13 ++- .../scala/org/apache/spark/FileSuite.scala | 95 +++++++++++++++++-- 4 files changed, 112 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 19336f854145f..ce013d69579c1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -270,6 +270,12 @@ package object config { .longConf .createWithDefault(4 * 1024 * 1024) + private[spark] val IGNORE_EMPTY_SPLITS = ConfigBuilder("spark.files.ignoreEmptySplits") + .doc("If true, methods that use HadoopRDD and NewHadoopRDD such as " + + "SparkContext.textFiles will not create a partition for input splits that are empty.") + .booleanConf + .createWithDefault(false) + private[spark] val SECRET_REDACTION_PATTERN = ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 23b344230e490..1f33c0a2b709f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -134,6 +134,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. 
protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value @@ -195,8 +197,12 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val inputFormat = getInputFormat(jobConf) - val inputSplits = inputFormat.getSplits(jobConf, minPartitions) + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { array(i) = new HadoopPartition(id, i, inputSplits(i)) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 482875e6c1ac5..db4eac1d0a775 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -21,6 +21,7 @@ import java.io.IOException import java.text.SimpleDateFormat import java.util.{Date, Locale} +import scala.collection.JavaConverters.asScalaBufferConverter import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} @@ -34,7 +35,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -89,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value if (shouldCloneJobConf) { @@ -121,8 +124,12 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = new JobContextImpl(_conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 02728180ac82d..4da4323ceb5c8 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ 
-347,10 +347,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } - test ("allow user to disable the output directory existence checking (old Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + test ("allow user to disable the output directory existence checking (old Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1) randomRDD.saveAsTextFile(tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-00000").exists() === true) @@ -380,9 +380,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test ("allow user to disable the output directory existence checking (new Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( @@ -510,4 +510,83 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } + test("spark.files.ignoreEmptySplits work correctly (old Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsHadoopFile[TextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-0000$i").exists() === true) + } + val hadoopRDD = sc.textFile(new File(output, "part-*").getPath) + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a"), ("key3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } + + test("spark.files.ignoreEmptySplits work correctly (new Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-r-0000$i").exists() === true) + } + val hadoopRDD = 
sc.newAPIHadoopFile(new File(output, "part-r-*").getPath, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[NewHadoopRDD[_, _]] + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "a"), ("3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "b")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } } From e8547ffb49071525c06876c856cecc0d4731b918 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sat, 14 Oct 2017 17:39:15 -0700 Subject: [PATCH 065/712] [SPARK-22238] Fix plan resolution bug caused by EnsureStatefulOpPartitioning ## What changes were proposed in this pull request? In EnsureStatefulOpPartitioning, we check that the inputRDD to a SparkPlan has the expected partitioning for Streaming Stateful Operators. The problem is that we are not allowed to access this information during planning. The reason we added that check was because CoalesceExec could actually create RDDs with 0 partitions. We should fix it such that when CoalesceExec says that there is a SinglePartition, there is in fact an inputRDD of 1 partition instead of 0 partitions. ## How was this patch tested? Regression test in StreamingQuerySuite Author: Burak Yavuz Closes #19467 from brkyvz/stateful-op. --- .../plans/physical/partitioning.scala | 15 +- .../execution/basicPhysicalOperators.scala | 27 +++- .../exchange/EnsureRequirements.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 2 +- .../streaming/IncrementalExecution.scala | 39 ++--- .../streaming/statefulOperators.scala | 11 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 + .../spark/sql/execution/PlannerSuite.scala | 17 +++ .../streaming/state/StateStoreRDDSuite.scala | 2 +- .../SymmetricHashJoinStateManagerSuite.scala | 2 +- .../sql/streaming/DeduplicateSuite.scala | 11 +- .../EnsureStatefulOpPartitioningSuite.scala | 138 ------------------ .../FlatMapGroupsWithStateSuite.scala | 6 +- .../sql/streaming/StatefulOperatorTest.scala | 49 +++++++ .../streaming/StreamingAggregationSuite.scala | 8 +- .../sql/streaming/StreamingJoinSuite.scala | 2 +- .../sql/streaming/StreamingQuerySuite.scala | 13 ++ 17 files changed, 160 insertions(+), 189 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 51d78dd1233fe..e57c842ce2a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -49,7 +49,9 @@ case object AllTuples extends Distribution * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. 
*/ -case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class ClusteredDistribution( + clustering: Seq[Expression], + numPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + @@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false + case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) case _ => true } @@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } @@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 63cd1691f4cd7..d15ece304cac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -590,10 +590,33 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN } protected override def doExecute(): RDD[InternalRow] = { - child.execute().coalesce(numPartitions, shuffle = false) + if (numPartitions == 1 && child.execute().getNumPartitions < 1) { + // Make sure we don't output an RDD with 0 partitions, when claiming that we have a + // `SinglePartition`. + new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions) + } else { + child.execute().coalesce(numPartitions, shuffle = false) + } } } +object CoalesceExec { + /** A simple RDD with no data, but with the given number of partitions. 
*/ + class EmptyRDDWithPartitions( + @transient private val sc: SparkContext, + numPartitions: Int) extends RDD[InternalRow](sc, Nil) { + + override def getPartitions: Array[Partition] = + Array.tabulate(numPartitions)(i => EmptyPartition(i)) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + Iterator.empty + } + } + + case class EmptyPartition(index: Int) extends Partition +} + /** * Physical plan for a subquery. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d28ce60e276d5..4e2ca37bc1a59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -44,13 +44,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { /** * Given a required distribution, returns a partitioning that satisfies that distribution. + * @param requiredDistribution The distribution that is required by the operator + * @param numPartitions Used when the distribution doesn't require a specific number of partitions */ private def createPartitioning( requiredDistribution: Distribution, numPartitions: Int): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case ClusteredDistribution(clustering, desiredPartitions) => + HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) case dist => sys.error(s"Do not know how to satisfy distribution $dist") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index aab06d611a5ea..c81f1a8142784 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -64,7 +64,7 @@ case class FlatMapGroupsWithStateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil /** Ordering needed for using GroupingIterator */ override def requiredChildOrdering: Seq[Seq[SortOrder]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 82f879c763c2b..2e378637727fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** @@ -61,6 +62,10 @@ class 
IncrementalExecution( StreamingDeduplicationStrategy :: Nil } + private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) + .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) + /** * See [SPARK-18339] * Walk the optimized logical plan and replace CurrentBatchTimestamp @@ -83,7 +88,11 @@ class IncrementalExecution( /** Get the state info of the next stateful operator */ private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { StatefulOperatorStateInfo( - checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + checkpointLocation, + runId, + statefulOperatorId.getAndIncrement(), + currentBatchId, + numStateStores) } /** Locates save/restore pairs surrounding aggregation. */ @@ -130,34 +139,8 @@ class IncrementalExecution( } } - override def preparations: Seq[Rule[SparkPlan]] = - Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } } - -object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { - // Needs to be transformUp to avoid extra shuffles - override def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case so: StatefulOperator => - val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions - val distributions = so.requiredChildDistribution - val children = so.children.zip(distributions).map { case (child, reqDistribution) => - val expectedPartitioning = reqDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions) - case _ => throw new AnalysisException("Unexpected distribution expected for " + - s"Stateful Operator: $so. 
Expect AllTuples or ClusteredDistribution but got " + - s"$reqDistribution.") - } - if (child.outputPartitioning.guarantees(expectedPartitioning) && - child.execute().getNumPartitions == expectedPartitioning.numPartitions) { - child - } else { - ShuffleExchangeExec(expectedPartitioning, child) - } - } - so.withNewChildren(children) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 0d85542928ee6..b9b07a2e688f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo( checkpointLocation: String, queryRunId: UUID, operatorId: Long, - storeVersion: Long) { + storeVersion: Long, + numPartitions: Int) { override def toString(): String = { s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " + - s"opId = $operatorId, ver = $storeVersion]" + s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]" } } @@ -239,7 +240,7 @@ case class StateStoreRestoreExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -386,7 +387,7 @@ case class StateStoreSaveExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -401,7 +402,7 @@ case class StreamingDeduplicateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ad461fa6144b3..50de2fd3bca8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( testData.select('key).coalesce(1).select('key), testData.select('key).collect().toSeq) + + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) } test("convert $\"attribute name\" into unresolved attribute") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 86066362da9dd..c25c90d0c70e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -425,6 +425,23 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements should respect ClusteredDistribution's num partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13)) + // Number of partitions differ + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( 
+ children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e } + assert(shuffle.size === 1) + assert(shuffle.head.newPartitioning === finalPartitioning) + } + test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index defb9ed63a881..65b39f0fbd73d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn path: String, queryRunId: UUID = UUID.randomUUID, version: Int = 0): StatefulOperatorStateInfo = { - StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5) } private val increment = (store: StateStore, iter: Iterator[String]) => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index d44af1d14c27a..c0216a2ef3e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -160,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter withTempDir { file => val storeConf = new StateStoreConf() - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0) + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) val manager = new SymmetricHashJoinStateManager( LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index e858b7d9998a8..caf2bab8a5859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class DeduplicateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with 
StatefulOperatorTest { import testImplicits._ @@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -58,6 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala deleted file mode 100644 index ed9823fbddfda..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming - -import java.util.UUID - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} -import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} -import org.apache.spark.sql.test.SharedSQLContext - -class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { - - import testImplicits._ - - private var baseDf: DataFrame = null - - override def beforeAll(): Unit = { - super.beforeAll() - baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") - } - - test("ClusteredDistribution generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("AllTuples generates Exchange with SinglePartition") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = true) - } - - test("AllTuples with coalesce(1) doesn't need Exchange") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = false) - } - - /** - * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan - * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to - * ensure the expected partitioning. - */ - private def testEnsureStatefulOpPartitioning( - inputPlan: SparkPlan, - requiredDistribution: Seq[Attribute] => Distribution, - expectedPartitioning: Seq[Attribute] => Partitioning, - expectShuffle: Boolean): Unit = { - val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) - val executed = executePlan(operator, OutputMode.Complete()) - if (expectShuffle) { - val exchange = executed.children.find(_.isInstanceOf[Exchange]) - if (exchange.isEmpty) { - fail(s"Was expecting an exchange but didn't get one in:\n$executed") - } - assert(exchange.get === - ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan), - s"Exchange didn't have expected properties:\n${exchange.get}") - } else { - assert(!executed.children.exists(_.isInstanceOf[Exchange]), - s"Unexpected exchange found in:\n$executed") - } - } - - /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. 
*/ - private def executePlan( - p: SparkPlan, - outputMode: OutputMode = OutputMode.Append()): SparkPlan = { - val execution = new IncrementalExecution( - spark, - null, - OutputMode.Complete(), - "chk", - UUID.randomUUID(), - 0L, - OffsetSeqMetadata()) { - override lazy val sparkPlan: SparkPlan = p transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - execution.executedPlan - } -} - -/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ -case class TestStatefulOperator( - child: SparkPlan, - requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { - override def output: Seq[Attribute] = child.output - override def doExecute(): RDD[InternalRow] = child.execute() - override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil - override def stateInfo: Option[StatefulOperatorStateInfo] = None -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index d2e8beb2f5290..aeb83835f981a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -41,7 +41,9 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) -class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with StatefulOperatorTest { import testImplicits._ import GroupStateImpl._ @@ -544,6 +546,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( + sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala new file mode 100644 index 0000000000000..45142278993bb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.streaming._ + +trait StatefulOperatorTest { + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator. + */ + protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + colNames: Seq[String]): Boolean = { + val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output + val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions + val groupingAttr = attr.filter(a => colNames.contains(a.name)) + checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) + } + + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator. + */ + protected def checkChildOutputPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + expectedPartitioning: Partitioning): Boolean = { + val operator = sq.asInstanceOf[StreamExecution].lastExecution + .executedPlan.collect { case p: T => p } + operator.head.children.forall( + _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index fe7efa69f7e31..1b4d8556f6ae5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions { + with BeforeAndAfterAll with Assertions with StatefulOperatorTest { override def afterAll(): Unit = { super.afterAll() @@ -281,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), @@ -455,8 +457,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest }, AddBlockData(inputSource), // create an empty trigger CheckLastBatch(1), - AssertOnQuery("Verify addition of exchange operator") { se => - checkAggregationChain(se, expectShuffling = true, 1) + AssertOnQuery("Verify that no exchange is required") { se => + checkAggregationChain(se, expectShuffling = false, 1) }, AddBlockData(inputSource, Seq(2, 3)), CheckLastBatch(3), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index a6593b71e51de..d32617275aadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -330,7 +330,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with val queryId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString - val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L) + val stateInfo 
= StatefulOperatorStateInfo(path, queryId, opId, 0L, 5) implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ab35079dca23f..c53889bb8566c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") { + val stream = MemoryStream[(Int, Int)] + val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'") + val otherDf = stream.toDF().toDF("num", "numSq") + .join(broadcast(baseDf), "num") + .groupBy('char) + .agg(sum('numSq)) + + testStream(otherDf, OutputMode.Complete())( + AddData(stream, (1, 1), (2, 4)), + CheckLastBatch(("A", 1))) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) From 13c1559587d0eb533c94f5a492390f81b048b347 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 15 Oct 2017 18:40:53 -0700 Subject: [PATCH 066/712] [SPARK-21549][CORE] Respect OutputFormats with no/invalid output directory provided ## What changes were proposed in this pull request? PR #19294 added support for null's - but spark 2.1 handled other error cases where path argument can be invalid. Namely: * empty string * URI parse exception while creating Path This is resubmission of PR #19487, which I messed up while updating my repo. ## How was this patch tested? Enhanced test to cover new support added. Author: Mridul Muralidharan Closes #19497 from mridulm/master. --- .../io/HadoopMapReduceCommitProtocol.scala | 24 +++++++------- .../spark/rdd/PairRDDFunctionsSuite.scala | 31 +++++++++++++------ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index a7e6859ef6b64..95c99d29c3a9c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import java.util.{Date, UUID} import scala.collection.mutable +import scala.util.Try import org.apache.hadoop.conf.Configurable import org.apache.hadoop.fs.Path @@ -47,6 +48,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ + /** + * Checks whether there are files to be committed to a valid output location. + * + * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, + * it is necessary to check whether a valid output path is specified. + * [[HadoopMapReduceCommitProtocol#path]] need not be a valid [[org.apache.hadoop.fs.Path]] for + * committers not writing to distributed file systems. 
+ */ + private val hasValidPath = Try { new Path(path) }.isSuccess + /** * Tracks files staged by this task for absolute output paths. These outputs are not managed by * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. @@ -60,15 +71,6 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) */ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) - /** - * Checks whether there are files to be committed to an absolute output location. - * - * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, - * it is necessary to check whether the output path is specified. Output path may not be required - * for committers not writing to distributed file systems. - */ - private def hasAbsPathFiles: Boolean = path != null - protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -142,7 +144,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) .foldLeft(Map[String, String]())(_ ++ _) logDebug(s"Committing files staged for absolute locations $filesToMove") - if (hasAbsPathFiles) { + if (hasValidPath) { val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) @@ -153,7 +155,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) - if (hasAbsPathFiles) { + if (hasValidPath) { val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) fs.delete(absPathStagingDir, true) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 07579c5098014..0a248b6064ee8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -568,21 +568,34 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") } - test("saveAsNewAPIHadoopDataset should respect empty output directory when " + + test("saveAsNewAPIHadoopDataset should support invalid output paths when " + "there are no files to be committed to an absolute output location") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) - val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) - job.setOutputKeyClass(classOf[Integer]) - job.setOutputValueClass(classOf[Integer]) - job.setOutputFormatClass(classOf[NewFakeFormat]) - val jobConfiguration = job.getConfiguration + def saveRddWithPath(path: String): Unit = { + val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) + job.setOutputKeyClass(classOf[Integer]) + job.setOutputValueClass(classOf[Integer]) + job.setOutputFormatClass(classOf[NewFakeFormat]) + if (null != path) { + job.getConfiguration.set("mapred.output.dir", path) + } else { + job.getConfiguration.unset("mapred.output.dir") + } + val jobConfiguration = job.getConfiguration + + // just test that the job does not fail with java.lang.IllegalArgumentException. 
+ pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + } - // just test that the job does not fail with - // java.lang.IllegalArgumentException: Can not create a Path from a null string - pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + saveRddWithPath(null) + saveRddWithPath("") + saveRddWithPath("::invalid::") } + // In spark 2.1, only null was supported - not other invalid paths. + // org.apache.hadoop.mapred.FileOutputFormat.getOutputPath fails with IllegalArgumentException + // for non-null invalid paths. test("saveAsHadoopDataset should respect empty output directory when " + "there are no files to be committed to an absolute output location") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) From 0ae96495dedb54b3b6bae0bd55560820c5ca29a2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Oct 2017 13:37:58 +0800 Subject: [PATCH 067/712] [SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle ## What changes were proposed in this pull request? `ObjectHashAggregateExec` should override `outputPartitioning` in order to avoid unnecessary shuffle. ## How was this patch tested? Added Jenkins test. Author: Liang-Chi Hsieh Closes #19501 from viirya/SPARK-22223. --- .../aggregate/ObjectHashAggregateExec.scala | 2 ++ .../spark/sql/DataFrameAggregateSuite.scala | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index ec3f9a05b5ccc..66955b8ef723c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -95,6 +95,8 @@ case class ObjectHashAggregateExec( } } + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") val aggTime = longMetric("aggTime") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8549eac58ee95..06848e4d2b297 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), Seq(Row(3, 4, 9))) } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + + val objHashAggDF = df + .withColumn("d", expr("(a, b, c)")) + .groupBy("a", "b").agg(collect_list("d").as("e")) + .withColumn("f", expr("(b, e)")) + 
.groupBy("a").agg(collect_list("f").as("g")) + val aggPlan = objHashAggDF.queryExecution.executedPlan + + val sortAggPlans = aggPlan.collect { + case sortAgg: SortAggregateExec => sortAgg + } + assert(sortAggPlans.isEmpty) + + val objHashAggPlans = aggPlan.collect { + case objHashAgg: ObjectHashAggregateExec => objHashAgg + } + assert(objHashAggPlans.nonEmpty) + + val exchangePlans = aggPlan.collect { + case shuffle: ShuffleExchangeExec => shuffle + } + assert(exchangePlans.length == 1) + } + } } From 0fa10666cf75e3c4929940af49c8a6f6ea874759 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 16 Oct 2017 22:15:50 +0800 Subject: [PATCH 068/712] [SPARK-22233][CORE][FOLLOW-UP] Allow user to filter out empty split in HadoopRDD ## What changes were proposed in this pull request? Update the config `spark.files.ignoreEmptySplits`, rename it and make it internal. This is followup of #19464 ## How was this patch tested? Exsiting tests. Author: Xingbo Jiang Closes #19504 from jiangxb1987/partitionsplit. --- .../org/apache/spark/internal/config/package.scala | 11 ++++++----- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- .../scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- .../test/scala/org/apache/spark/FileSuite.scala | 14 +++++++++----- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ce013d69579c1..efffdca1ea59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -270,11 +270,12 @@ package object config { .longConf .createWithDefault(4 * 1024 * 1024) - private[spark] val IGNORE_EMPTY_SPLITS = ConfigBuilder("spark.files.ignoreEmptySplits") - .doc("If true, methods that use HadoopRDD and NewHadoopRDD such as " + - "SparkContext.textFiles will not create a partition for input splits that are empty.") - .booleanConf - .createWithDefault(false) + private[spark] val HADOOP_RDD_IGNORE_EMPTY_SPLITS = + ConfigBuilder("spark.hadoopRDD.ignoreEmptySplits") + .internal() + .doc("When true, HadoopRDD/NewHadoopRDD will not create partitions for empty input splits.") + .booleanConf + .createWithDefault(false) private[spark] val SECRET_REDACTION_PATTERN = ConfigBuilder("spark.redaction.regex") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 1f33c0a2b709f..2480559a41b7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -134,7 +134,7 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) - private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input 
splits for Hadoop reads. protected def getJobConf(): JobConf = { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index db4eac1d0a775..e4dd1b6a82498 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -90,7 +90,7 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) - private val ignoreEmptySplits = sparkContext.getConf.get(IGNORE_EMPTY_SPLITS) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 4da4323ceb5c8..e9539dc73f6fa 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.internal.config.{IGNORE_CORRUPT_FILES, IGNORE_EMPTY_SPLITS} +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -510,9 +510,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } - test("spark.files.ignoreEmptySplits work correctly (old Hadoop API)") { + test("spark.hadoopRDD.ignoreEmptySplits work correctly (old Hadoop API)") { val conf = new SparkConf() - conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) sc = new SparkContext(conf) def testIgnoreEmptySplits( @@ -549,9 +551,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { expectedPartitionNum = 2) } - test("spark.files.ignoreEmptySplits work correctly (new Hadoop API)") { + test("spark.hadoopRDD.ignoreEmptySplits work correctly (new Hadoop API)") { val conf = new SparkConf() - conf.setAppName("test").setMaster("local").set(IGNORE_EMPTY_SPLITS, true) + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) sc = new SparkContext(conf) def testIgnoreEmptySplits( From 561505e2fc290fc2cee3b8464ec49df773dca5eb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 16 Oct 2017 11:27:08 -0700 Subject: [PATCH 069/712] [SPARK-22282][SQL] Rename OrcRelation to OrcFileFormat and remove ORC_COMPRESSION ## What changes were proposed in this pull request? This PR aims to - Rename `OrcRelation` to `OrcFileFormat` object. - Replace `OrcRelation.ORC_COMPRESSION` with `org.apache.orc.OrcConf.COMPRESS`. 
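For context (user-facing behaviour is unchanged by this PR), a minimal sketch of how the ORC codec is picked today, assuming an arbitrary DataFrame `df` and a hypothetical output path; the precedence is `compression` > `orc.compress` > `spark.sql.orc.compression.codec`:

```scala
// Lowest precedence: the session-wide SQL conf.
spark.conf.set("spark.sql.orc.compression.codec", "snappy")

df.write
  .option("orc.compress", "zlib")   // overrides the SQL conf (this key is OrcConf.COMPRESS)
  .option("compression", "lzo")     // highest precedence, overrides both of the above
  .orc("/tmp/orc-output")           // hypothetical path
```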
Since [SPARK-21422](https://issues.apache.org/jira/browse/SPARK-21422), we can use `OrcConf.COMPRESS` instead of Hive's. ```scala // The references of Hive's classes will be minimized. val ORC_COMPRESSION = "orc.compress" ``` ## How was this patch tested? Pass the Jenkins with the existing and updated test cases. Author: Dongjoon Hyun Closes #19502 from dongjoon-hyun/SPARK-22282. --- .../org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 18 ++++++++---------- .../apache/spark/sql/hive/orc/OrcOptions.scala | 8 +++++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 11 ++++++----- .../spark/sql/hive/orc/OrcSourceSuite.scala | 9 ++++++--- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 07347d2748544..c9e45436ed42f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -520,8 +520,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default is the value specified in `spark.sql.orc.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive * shorten names(`none`, `snappy`, `zlib`, and `lzo`). This will override - * `orc.compress` and `spark.sql.parquet.compression.codec`. If `orc.compress` is given, - * it overrides `spark.sql.parquet.compression.codec`.
  • + * `orc.compress` and `spark.sql.orc.compression.codec`. If `orc.compress` is given, + * it overrides `spark.sql.orc.compression.codec`. * * * @since 1.5.0 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 194e69c93e1a8..d26ec15410d95 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.orc.OrcConf.COMPRESS import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession @@ -72,7 +73,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val configuration = job.getConfiguration - configuration.set(OrcRelation.ORC_COMPRESSION, orcOptions.compressionCodec) + configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) configuration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) @@ -93,8 +94,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable override def getFileExtension(context: TaskAttemptContext): String = { val compressionExtension: String = { - val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcFileFormat.extensionsForCompressionCodecNames.getOrElse(name, "") } compressionExtension + ".orc" @@ -120,7 +121,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => - hadoopConf.set(OrcRelation.SARG_PUSHDOWN, f.toKryo) + hadoopConf.set(OrcFileFormat.SARG_PUSHDOWN, f.toKryo) hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } } @@ -138,7 +139,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (isEmptyFile) { Iterator.empty } else { - OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema) + OrcFileFormat.setRequiredColumns(conf, dataSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -160,7 +161,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcRelation.unwrapOrcStructs( + OrcFileFormat.unwrapOrcStructs( conf, dataSchema, requiredSchema, @@ -255,10 +256,7 @@ private[orc] class OrcOutputWriter( } } -private[orc] object OrcRelation extends HiveInspectors { - // The references of Hive's classes will be minimized. - val ORC_COMPRESSION = "orc.compress" - +private[orc] object OrcFileFormat extends HiveInspectors { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. 
private[orc] val SARG_PUSHDOWN = "sarg.pushdown" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 7f94c8c579026..6ce90c07b4921 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -40,9 +42,9 @@ private[orc] class OrcOptions( * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ val compressionCodec: String = { - // `compression`, `orc.compress`, and `spark.sql.orc.compression.codec` are - // in order of precedence from highest to lowest. - val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) + // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` + // are in order of precedence from highest to lowest. + val orcCompressionConf = parameters.get(COMPRESS.getAttribute) val codecName = parameters .get("compression") .orElse(orcCompressionConf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 60ccd996d6d58..1fa9091f967a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ @@ -176,11 +177,11 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("SPARK-16610: Respect orc.compress option when compression is unset") { - // Respect `orc.compress`. + test("SPARK-16610: Respect orc.compress (i.e., OrcConf.COMPRESS) when compression is unset") { + // Respect `orc.compress` (i.e., OrcConf.COMPRESS). 
withTempPath { file => spark.range(0, 10).write - .option("orc.compress", "ZLIB") + .option(COMPRESS.getAttribute, "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -191,7 +192,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => spark.range(0, 10).write .option("compression", "ZLIB") - .option("orc.compress", "SNAPPY") + .option(COMPRESS.getAttribute, "SNAPPY") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -598,7 +599,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val requestedSchema = StructType(Nil) val conf = new Configuration() val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) + OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) assert(maybeOrcReader.isDefined) val orcRecordReader = new SparkOrcNewRecordReader( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 781de6631f324..ef9e67c743837 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.hive.orc import java.io.File +import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} @@ -150,7 +152,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val conf = sqlContext.sessionState.conf - assert(new OrcOptions(Map("Orc.Compress" -> "NONE"), conf).compressionCodec == "NONE") + val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) + assert(option.compressionCodec == "NONE") } test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { @@ -205,8 +208,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec` withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") { assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE") - val map1 = Map("orc.compress" -> "zlib") - val map2 = Map("orc.compress" -> "zlib", "compression" -> "lzo") + val map1 = Map(COMPRESS.getAttribute -> "zlib") + val map2 = Map(COMPRESS.getAttribute -> "zlib", "compression" -> "lzo") assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB") assert(new OrcOptions(map2, conf).compressionCodec == "LZO") } From c09a2a76b52905a784d2767cb899dc886c330628 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 16 Oct 2017 16:16:34 -0700 Subject: [PATCH 070/712] [SPARK-22280][SQL][TEST] Improve StatisticsSuite to test `convertMetastore` properly ## What changes were proposed in this pull request? This PR aims to improve **StatisticsSuite** to test `convertMetastore` configuration properly. Currently, some test logic in `test statistics of LogicalRelation converted from Hive serde tables` depends on the default configuration. 
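To make the intent explicit, the replacement pins both `convertMetastore` settings and loops over the two values; a condensed sketch of the pattern (the complete test is in the diff below), assuming the usual `withSQLConf` test helper:

```scala
// Sketch: exercise both conversion paths explicitly instead of relying on defaults.
Seq(true, false).foreach { isConverted =>
  withSQLConf(
      HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted",
      HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") {
    // create the Hive serde table, run ANALYZE TABLE, and check its statistics here
  }
}
```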
New test case is shorter and covers both(true/false) cases explicitly. This test case was previously modified by SPARK-17410 and SPARK-17284 in Spark 2.3.0. - https://github.com/apache/spark/commit/a2460be9c30b67b9159fe339d115b84d53cc288a#diff-1c464c86b68c2d0b07e73b7354e74ce7R443 ## How was this patch tested? Pass the Jenkins with the improved test case. Author: Dongjoon Hyun Closes #19500 from dongjoon-hyun/SPARK-22280. --- .../spark/sql/hive/StatisticsSuite.scala | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9ff9ecf7f3677..b9a5ad7657134 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -937,26 +937,20 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } test("test statistics of LogicalRelation converted from Hive serde tables") { - val parquetTable = "parquetTable" - val orcTable = "orcTable" - withTable(parquetTable, orcTable) { - sql(s"CREATE TABLE $parquetTable (key STRING, value STRING) STORED AS PARQUET") - sql(s"CREATE TABLE $orcTable (key STRING, value STRING) STORED AS ORC") - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - sql(s"INSERT INTO TABLE $orcTable SELECT * FROM src") - - // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it - // for robustness - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { - checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) - } - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { - // We still can get tableSize from Hive before Analyze - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) - sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") { + withTable(format) { + sql(s"CREATE TABLE $format (key STRING, value STRING) STORED AS $format") + sql(s"INSERT INTO TABLE $format SELECT * FROM src") + + checkTableStats(format, hasSizeInBytes = !isConverted, expectedRowCounts = None) + sql(s"ANALYZE TABLE $format COMPUTE STATISTICS") + checkTableStats(format, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } + } } } } From e66cabb0215204605ca7928406d4787d41853dd1 Mon Sep 17 00:00:00 2001 From: Ben Barnard Date: Tue, 17 Oct 2017 09:36:09 +0200 Subject: [PATCH 071/712] [SPARK-20992][SCHEDULER] Add links in documentation to Nomad integration. ## What changes were proposed in this pull request? Adds links to the fork that provides integration with Nomad, in the same places the k8s integration is linked to. ## How was this patch tested? I clicked on the links to make sure they're correct ;) Author: Ben Barnard Closes #19354 from barnardb/link-to-nomad-integration. 
--- docs/cluster-overview.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index a2ad958959a50..c42bb4bb8377e 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -58,6 +58,9 @@ for providing container-centric infrastructure. Kubernetes support is being acti developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. For documentation, refer to that project's README. +A third-party project (not supported by the Spark project) exists to add support for +[Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. + # Submitting Applications Applications can be submitted to a cluster of any type using the `spark-submit` script. From 8148f19ca1f0e0375603cb4f180c1bad8b0b8042 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Oct 2017 09:41:23 +0200 Subject: [PATCH 072/712] [SPARK-22249][SQL] isin with empty list throws exception on cached DataFrame ## What changes were proposed in this pull request? As pointed out in the JIRA, there is a bug which causes an exception to be thrown if `isin` is called with an empty list on a cached DataFrame. The PR fixes it. ## How was this patch tested? Added UT. Author: Marco Gaido Closes #19494 from mgaido91/SPARK-22249. --- .../columnar/InMemoryTableScanExec.scala | 1 + .../columnar/InMemoryColumnarQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index af3636a5a2ca7..846ec03e46a12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -102,6 +102,7 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 + case In(_: AttributeReference, list: Seq[Expression]) if list.isEmpty => Literal.FalseLiteral case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 8d411eb191cd9..75d17bc79477d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -429,4 +429,19 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(agg_without_cache, agg_with_cache) } } + + test("SPARK-22249: IN should work also with cached DataFrame") { + val df = spark.range(10).cache() + // with an empty list + assert(df.filter($"id".isin()).count() == 0) + // with a non-empty list + assert(df.filter($"id".isin(2)).count() == 1) + assert(df.filter($"id".isin(2, 3)).count() == 2) + df.unpersist() + val dfNulls = spark.range(10).selectExpr("null as id").cache() + // with null as value for the attribute + assert(dfNulls.filter($"id".isin()).count() == 0) + assert(dfNulls.filter($"id".isin(2, 3)).count() == 0) + 
dfNulls.unpersist() + } } From 99e32f8ba5d908d5408e9857fd96ac1d7d7e5876 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 17 Oct 2017 17:58:45 +0800 Subject: [PATCH 073/712] [SPARK-22224][SQL] Override toString of KeyValue/Relational-GroupedDataset ## What changes were proposed in this pull request? #### before ```scala scala> val words = spark.read.textFile("README.md").flatMap(_.split(" ")) words: org.apache.spark.sql.Dataset[String] = [value: string] scala> val grouped = words.groupByKey(identity) grouped: org.apache.spark.sql.KeyValueGroupedDataset[String,String] = org.apache.spark.sql.KeyValueGroupedDataset65214862 ``` #### after ```scala scala> val words = spark.read.textFile("README.md").flatMap(_.split(" ")) words: org.apache.spark.sql.Dataset[String] = [value: string] scala> val grouped = words.groupByKey(identity) grouped: org.apache.spark.sql.KeyValueGroupedDataset[String,String] = [key: [value: string], value: [value: string]] ``` ## How was this patch tested? existing ut cc gatorsmile cloud-fan Author: Kent Yao Closes #19363 from yaooqinn/minor-dataset-tostring. --- .../spark/sql/KeyValueGroupedDataset.scala | 22 ++++++- .../spark/sql/RelationalGroupedDataset.scala | 19 +++++- .../org/apache/spark/sql/DatasetSuite.scala | 61 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 12 +--- 4 files changed, 100 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index cb42e9e4560cf..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,7 +24,6 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} @@ -564,4 +563,25 @@ class KeyValueGroupedDataset[K, V] private[sql]( encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } + + override def toString: String = { + val builder = new StringBuilder + val kFields = kExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + val vFields = vExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("KeyValueGroupedDataset: [key: [") + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append("], value: [") + builder.append(vFields.take(2).mkString(", ")) + if (vFields.length > 2) { + builder.append(" ... 
" + (vFields.length - 2) + " more field(s)") + } + builder.append("]]").toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index cd0ac1feffa51..33ec3a27110a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{NumericType, StructField, StructType} +import org.apache.spark.sql.types.{NumericType, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -465,6 +465,19 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + + override def toString: String = { + val builder = new StringBuilder + builder.append("RelationalGroupedDataset: [grouping expressions: [") + val kFields = groupingExprs.map(_.asInstanceOf[NamedExpression]).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append(s"], value: ${df.toString}, type: $groupType]").toString() + } } private[sql] object RelationalGroupedDataset { @@ -479,7 +492,9 @@ private[sql] object RelationalGroupedDataset { /** * The Grouping Type */ - private[sql] trait GroupType + private[sql] trait GroupType { + override def toString: String = getClass.getSimpleName.stripSuffix("$").stripSuffix("Type") + } /** * To indicate it's the GroupBy diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dace6825ee40e..1537ce3313c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1341,8 +1341,69 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), ("", Seq((1, 1), (2, 2)))) } + + test("Check RelationalGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").groupBy("id") + val expected = "RelationalGroupedDataset: [" + + "grouping expressions: [id: int], value: [id: int], type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check RelationalGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").groupBy("id") + val expected = "RelationalGroupedDataset:" + + " [grouping expressions: [id: int]," + + " value: [id: int, val1: string ... 
1 more field]," + + " type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + + test("Check KeyValueGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").as[SingleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset: [key: [id: int], value: [id: int]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Unnamed KV-pair") { + val kvDataset = (1 to 3).map(x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => (x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [_1: int, _2: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Named KV-pair") { + val kvDataset = (1 to 3).map( x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => DoubleData(x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").as[TripleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string ... 1 more field(s)]," + + " value: [id: int, val1: string ... 1 more field(s)]]" + val actual = kvDataset.toString + assert(expected === actual) + } } +case class SingleData(id: Int) +case class DoubleData(id: Int, val1: String) +case class TripleData(id: Int, val1: String, val2: Long) + case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f9808834df4a5..fcaca3d75b74f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,23 +17,13 @@ package org.apache.spark.sql -import java.util.{ArrayDeque, Locale, TimeZone} +import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.streaming.MemoryPlan -import org.apache.spark.sql.types.{Metadata, ObjectType} abstract class QueryTest extends PlanTest { From e1960c3d6f380b0dfbba6ee5d8ac6da4bc29a698 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Oct 2017 22:54:38 +0800 Subject: [PATCH 074/712] [SPARK-22062][CORE] Spill large block to disk in BlockManager's remote fetch to avoid OOM ## What changes were proposed in this pull request? 
In the current BlockManager's `getRemoteBytes`, it will call `BlockTransferService#fetchBlockSync` to get remote block. In the `fetchBlockSync`, Spark will allocate a temporary `ByteBuffer` to store the whole fetched block. This will potentially lead to OOM if block size is too big or several blocks are fetched simultaneously in this executor. So here leveraging the idea of shuffle fetch, to spill the large block to local disk before consumed by upstream code. The behavior is controlled by newly added configuration, if block size is smaller than the threshold, then this block will be persisted in memory; otherwise it will first spill to disk, and then read from disk file. To achieve this feature, what I did is: 1. Rename `TempShuffleFileManager` to `TempFileManager`, since now it is not only used by shuffle. 2. Add a new `TempFileManager` to manage the files of fetched remote blocks, the files are tracked by weak reference, will be deleted when no use at all. ## How was this patch tested? This was tested by adding UT, also manual verification in local test to perform GC to clean the files. Author: jerryshao Closes #19476 from jerryshao/SPARK-22062. --- .../shuffle/ExternalShuffleClient.java | 4 +- .../shuffle/OneForOneBlockFetcher.java | 12 +-- .../spark/network/shuffle/ShuffleClient.java | 10 +- ...eFileManager.java => TempFileManager.java} | 12 +-- .../scala/org/apache/spark/SparkConf.scala | 4 +- .../spark/internal/config/package.scala | 15 +-- .../spark/network/BlockTransferService.scala | 28 +++-- .../netty/NettyBlockTransferService.scala | 6 +- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 102 ++++++++++++++++-- .../spark/storage/BlockManagerMaster.scala | 6 ++ .../storage/BlockManagerMasterEndpoint.scala | 14 +++ .../spark/storage/BlockManagerMessages.scala | 7 ++ .../storage/ShuffleBlockFetcherIterator.scala | 8 +- .../org/apache/spark/DistributedSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 57 +++++++--- .../ShuffleBlockFetcherIteratorSuite.scala | 10 +- docs/configuration.md | 11 +- 18 files changed, 236 insertions(+), 74 deletions(-) rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/{TempShuffleFileManager.java => TempFileManager.java} (74%) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 77702447edb88..510017fee2db5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -91,7 +91,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { @@ -99,7 +99,7 @@ public void fetchBlocks( (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, tempShuffleFileManager).start(); + blockIds1, listener1, conf, tempFileManager).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 66b67e282c80d..3f2f20b4149f1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -58,7 +58,7 @@ public class OneForOneBlockFetcher { private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; private final TransportConf transportConf; - private final TempShuffleFileManager tempShuffleFileManager; + private final TempFileManager tempFileManager; private StreamHandle streamHandle = null; @@ -79,14 +79,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - this.tempShuffleFileManager = tempShuffleFileManager; + this.tempFileManager = tempFileManager; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -125,7 +125,7 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempShuffleFileManager != null) { + if (tempFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { @@ -164,7 +164,7 @@ private class DownloadCallback implements StreamCallback { private int chunkIndex; DownloadCallback(int chunkIndex) throws IOException { - this.targetFile = tempShuffleFileManager.createTempShuffleFile(); + this.targetFile = tempFileManager.createTempFile(); this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath())); this.chunkIndex = chunkIndex; } @@ -180,7 +180,7 @@ public void onComplete(String streamId) throws IOException { ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); - if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) { + if (!tempFileManager.registerTempFileToClean(targetFile)) { targetFile.delete(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 5bd4412b75275..18b04fedcac5b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -43,10 +43,10 @@ public void init(String appId) { } * @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. - * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files. - * If it's not null, the remote blocks will be streamed - * into temp shuffle files to reduce the memory usage, otherwise, - * they will be kept in memory. + * @param tempFileManager TempFileManager to create and clean temp files. 
+ * If it's not null, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage, otherwise, + * they will be kept in memory. */ public abstract void fetchBlocks( String host, @@ -54,7 +54,7 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager); + TempFileManager tempFileManager); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java similarity index 74% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java index 84a5ed6a276bd..552364d274f19 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java @@ -20,17 +20,17 @@ import java.io.File; /** - * A manager to create temp shuffle block files to reduce the memory usage and also clean temp + * A manager to create temp block files to reduce the memory usage and also clean temp * files when they won't be used any more. */ -public interface TempShuffleFileManager { +public interface TempFileManager { - /** Create a temp shuffle block file. */ - File createTempShuffleFile(); + /** Create a temp block file. */ + File createTempFile(); /** - * Register a temp shuffle file to clean up when it won't be used any more. Return whether the + * Register a temp file to clean up when it won't be used any more. Return whether the * file is registered successfully. If `false`, the caller should clean up the file by itself. */ - boolean registerTempShuffleFileToClean(File file); + boolean registerTempFileToClean(File file); } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e61f943af49f2..57b3744e9c30a 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -662,7 +662,9 @@ private[spark] object SparkConf extends Logging { "spark.yarn.jars" -> Seq( AlternateConfig("spark.yarn.jar", "2.0")), "spark.yarn.access.hadoopFileSystems" -> Seq( - AlternateConfig("spark.yarn.access.namenodes", "2.2")) + AlternateConfig("spark.yarn.access.namenodes", "2.2")), + "spark.maxRemoteBlockSizeFetchToMem" -> Seq( + AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")) ) /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index efffdca1ea59b..e7b406af8d9b1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -357,13 +357,15 @@ package object config { .checkValue(_ > 0, "The max no. 
of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) - private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = - ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") - .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + + private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = + ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") + .doc("Remote block will be fetched to disk when size of the block is " + "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note that this config can " + - "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" + - " service is disabled.") + "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + + "affect both shuffle fetch and block manager remote block fetch. For users who " + + "enabled external shuffle service, this feature can only be worked when external shuffle" + + " service is newer than Spark 2.2.") + .withAlternative("spark.reducer.maxReqSizeShuffleToMem") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -432,5 +434,4 @@ package object config { .stringConf .toSequence .createOptional - } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index fe5fd2da039bb..1d8a266d0079c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -25,8 +25,8 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit + tempFileManager: TempFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -87,7 +87,12 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { + def fetchBlockSync( + host: String, + port: Int, + execId: String, + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { // A monitor for the thread to wait on. 
val result = Promise[ManagedBuffer]() fetchBlocks(host, port, execId, Array(blockId), @@ -96,12 +101,17 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.failure(exception) } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result.success(new NioManagedBuffer(ret)) + data match { + case f: FileSegmentManagedBuffer => + result.success(f) + case _ => + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result.success(new NioManagedBuffer(ret)) + } } - }, tempShuffleFileManager = null) + }, tempFileManager) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 6a29e18bf3cbb..b7d8c35032763 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -32,7 +32,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -105,14 +105,14 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, - transportConf, tempShuffleFileManager).start() + transportConf, tempFileManager).start() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c8d1460300934..0562d45ff57c5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -52,7 +52,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() diff --git 
a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a98083df5bd84..e0276a4dc4224 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,8 +18,11 @@ package org.apache.spark.storage import java.io._ +import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.collection.mutable.HashMap @@ -39,7 +42,7 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.{SerializerInstance, SerializerManager} @@ -203,6 +206,13 @@ private[spark] class BlockManager( private var blockReplicationPolicy: BlockReplicationPolicy = _ + // A TempFileManager used to track all the files of remote blocks which above the + // specified memory threshold. Files will be deleted automatically based on weak reference. + // Exposed for test + private[storage] val remoteBlockTempFileManager = + new BlockManager.RemoteBlockTempFileManager(this) + private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -632,8 +642,8 @@ private[spark] class BlockManager( * Return a list of locations for the given block, prioritizing the local machine since * multiple block managers can share the same host, followed by hosts on the same rack. */ - private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - val locs = Random.shuffle(master.getLocations(blockId)) + private def sortLocations(locations: Seq[BlockManagerId]): Seq[BlockManagerId] = { + val locs = Random.shuffle(locations) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } blockManagerId.topologyInfo match { case None => preferredLocs ++ otherLocs @@ -653,7 +663,25 @@ private[spark] class BlockManager( require(blockId != null, "BlockId is null") var runningFailureCount = 0 var totalFailureCount = 0 - val locations = getLocations(blockId) + + // Because all the remote blocks are registered in driver, it is not necessary to ask + // all the slave executors to get block status. + val locationsAndStatus = master.getLocationsAndStatus(blockId) + val blockSize = locationsAndStatus.map { b => + b.status.diskSize.max(b.status.memSize) + }.getOrElse(0L) + val blockLocations = locationsAndStatus.map(_.locations).getOrElse(Seq.empty) + + // If the block size is above the threshold, we should pass our FileManger to + // BlockTransferService, which will leverage it to spill the block; if not, then passed-in + // null value means the block will be persisted in memory. 
+ val tempFileManager = if (blockSize > maxRemoteBlockToMem) { + remoteBlockTempFileManager + } else { + null + } + + val locations = sortLocations(blockLocations) val maxFetchFailures = locations.size var locationIterator = locations.iterator while (locationIterator.hasNext) { @@ -661,7 +689,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() } catch { case NonFatal(e) => runningFailureCount += 1 @@ -684,7 +712,7 @@ private[spark] class BlockManager( // take a significant amount of time. To get rid of these stale entries // we refresh the block locations after a certain number of fetch failures if (runningFailureCount >= maxFailuresBeforeLocationRefresh) { - locationIterator = getLocations(blockId).iterator + locationIterator = sortLocations(master.getLocations(blockId)).iterator logDebug(s"Refreshed locations from the driver " + s"after ${runningFailureCount} fetch failures.") runningFailureCount = 0 @@ -1512,6 +1540,7 @@ private[spark] class BlockManager( // Closing should be idempotent, but maybe not for the NioBlockTransferService. shuffleClient.close() } + remoteBlockTempFileManager.stop() diskBlockManager.stop() rpcEnv.stop(slaveEndpoint) blockInfoManager.clear() @@ -1552,4 +1581,65 @@ private[spark] object BlockManager { override val metricRegistry = new MetricRegistry metricRegistry.registerAll(metricSet) } + + class RemoteBlockTempFileManager(blockManager: BlockManager) + extends TempFileManager with Logging { + + private class ReferenceWithCleanup(file: File, referenceQueue: JReferenceQueue[File]) + extends WeakReference[File](file, referenceQueue) { + private val filePath = file.getAbsolutePath + + def cleanUp(): Unit = { + logDebug(s"Clean up file $filePath") + + if (!new File(filePath).delete()) { + logDebug(s"Fail to delete file $filePath") + } + } + } + + private val referenceQueue = new JReferenceQueue[File] + private val referenceBuffer = Collections.newSetFromMap[ReferenceWithCleanup]( + new ConcurrentHashMap) + + private val POLL_TIMEOUT = 1000 + @volatile private var stopped = false + + private val cleaningThread = new Thread() { override def run() { keepCleaning() } } + cleaningThread.setDaemon(true) + cleaningThread.setName("RemoteBlock-temp-file-clean-thread") + cleaningThread.start() + + override def createTempFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempFileToClean(file: File): Boolean = { + referenceBuffer.add(new ReferenceWithCleanup(file, referenceQueue)) + } + + def stop(): Unit = { + stopped = true + cleaningThread.interrupt() + cleaningThread.join() + } + + private def keepCleaning(): Unit = { + while (!stopped) { + try { + Option(referenceQueue.remove(POLL_TIMEOUT)) + .map(_.asInstanceOf[ReferenceWithCleanup]) + .foreach { ref => + referenceBuffer.remove(ref) + ref.cleanUp() + } + } catch { + case _: InterruptedException => + // no-op + case NonFatal(e) => + logError("Error in cleaning thread", e) + } + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 8b1dc0ba6356a..d24421b962774 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -84,6 +84,12 @@ class BlockManagerMaster( driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId)) } + /** Get locations as well as status of the blockId from the driver */ + def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + driverEndpoint.askSync[Option[BlockLocationsAndStatus]]( + GetLocationsAndStatus(blockId)) + } + /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]]( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index df0a5f5e229fb..56d0266b8edad 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -82,6 +82,9 @@ class BlockManagerMasterEndpoint( case GetLocations(blockId) => context.reply(getLocations(blockId)) + case GetLocationsAndStatus(blockId) => + context.reply(getLocationsAndStatus(blockId)) + case GetLocationsMultipleBlockIds(blockIds) => context.reply(getLocationsMultipleBlockIds(blockIds)) @@ -422,6 +425,17 @@ class BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } + private def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + val locations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty) + val status = locations.headOption.flatMap { bmId => blockManagerInfo(bmId).getStatus(blockId) } + + if (locations.nonEmpty && status.isDefined) { + Some(BlockLocationsAndStatus(locations, status.get)) + } else { + None + } + } + private def getLocationsMultipleBlockIds( blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 0c0ff144596ac..1bbe7a5b39509 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -93,6 +93,13 @@ private[spark] object BlockManagerMessages { case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster + case class GetLocationsAndStatus(blockId: BlockId) extends ToBlockManagerMaster + + // The response message of `GetLocationsAndStatus` request. 
+ case class BlockLocationsAndStatus(locations: Seq[BlockManagerId], status: BlockStatus) { + assert(locations.nonEmpty) + } + case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2d176b62f8b36..98b5a735a4529 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -69,7 +69,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { + extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -162,11 +162,11 @@ final class ShuffleBlockFetcherIterator( currentResult = null } - override def createTempShuffleFile(): File = { + override def createTempFile(): File = { blockManager.diskBlockManager.createTempLocalBlock()._2 } - override def registerTempShuffleFileToClean(file: File): Boolean = synchronized { + override def registerTempFileToClean(file: File): Boolean = synchronized { if (isZombie) { false } else { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index bea67b71a5a12..f8005610f7e4f 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -171,7 +171,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString) + blockId.toString, null) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index cfe89fde63f88..d45c194d31adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.storage -import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -45,14 +44,14 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import 
org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -512,8 +511,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3)) val blockManager = makeBlockManager(128, "exec", bmMaster) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) } @@ -535,8 +534,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) blockManager.blockManagerId = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) assert(locations.flatMap(_.topologyInfo) === Seq(localRack, localRack, localRack, otherRack, otherRack)) @@ -1274,13 +1273,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // so that we have a chance to do location refresh val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh) .map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) } - when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds) + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + Option(BlockLocationsAndStatus(blockManagerIds, BlockStatus.empty))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn( + blockManagerIds) + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, transferService = Option(mockBlockTransferService)) val block = store.getRemoteBytes("item") .asInstanceOf[Option[ByteBuffer]] assert(block.isDefined) - verify(mockBlockManagerMaster, times(2)).getLocations("item") + verify(mockBlockManagerMaster, times(1)).getLocationsAndStatus("item") + verify(mockBlockManagerMaster, times(1)).getLocations("item") } test("SPARK-17484: block status is properly updated 
following an exception in put()") { @@ -1371,8 +1375,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE server.close() } + test("fetch remote block to local disk if block size is larger than threshold") { + conf.set(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM, 1000L) + + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockTransferService = new MockBlockTransferService(0) + val blockLocations = Seq(BlockManagerId("id-0", "host-0", 1)) + val blockStatus = BlockStatus(StorageLevel.DISK_ONLY, 0L, 2000L) + + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + Option(BlockLocationsAndStatus(blockLocations, blockStatus))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockLocations) + + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, + transferService = Option(mockBlockTransferService)) + val block = store.getRemoteBytes("item") + .asInstanceOf[Option[ByteBuffer]] + + assert(block.isDefined) + assert(mockBlockTransferService.numCalls === 1) + // assert FileManager is not null if the block size is larger than threshold. + assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 + var tempFileManager: TempFileManager = null override def init(blockDataManager: BlockDataManager): Unit = {} @@ -1382,7 +1410,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1394,7 +1422,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def uploadBlock( hostname: String, - port: Int, execId: String, + port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel, @@ -1407,12 +1436,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE host: String, port: Int, execId: String, - blockId: String): ManagedBuffer = { + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { numCalls += 1 + this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId) + super.fetchBlockSync(host, port, execId, blockId, tempFileManager) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c371cbcf8dff5..5bfe9905ff17b 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.util.LimitedInputStream import 
org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -437,12 +437,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempShuffleFileManager: TempShuffleFileManager = null + var tempFileManager: TempFileManager = null when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager] + tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -472,13 +472,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(tempShuffleFileManager == null) + assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(tempShuffleFileManager != null) + assert(tempFileManager != null) } } diff --git a/docs/configuration.md b/docs/configuration.md index 7a777d3c6fa3d..bb06c8faaaed7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -547,13 +547,14 @@ Apart from these, the following properties are also available, and may be useful - spark.reducer.maxReqSizeShuffleToMem + spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The blocks of a shuffle request will be fetched to disk when size of the request is above - this threshold. This is to avoid a giant request takes too much memory. We can enable this - config by setting a specific value(e.g. 200m). Note that this config can be enabled only when - the shuffle shuffle service is newer than Spark-2.2 or the shuffle service is disabled. + The remote block will be fetched to disk when size of the block is above this threshold. + This is to avoid a giant request takes too much memory. We can enable this config by setting + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + and block manager remote block fetch. For users who enabled external shuffle service, + this feature can only be worked when external shuffle service is newer than Spark 2.2. From 75d666b95a711787355ca3895057dabadd429023 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 17 Oct 2017 12:26:53 -0700 Subject: [PATCH 075/712] [SPARK-22136][SS] Evaluate one-sided conditions early in stream-stream joins. ## What changes were proposed in this pull request? Evaluate one-sided conditions early in stream-stream joins. This is in addition to normal filter pushdown, because integrating it with the join logic allows it to take place in outer join scenarios. This means that rows which can never satisfy the join condition won't clog up the state. ## How was this patch tested? new unit tests Author: Jose Torres Closes #19452 from joseph-torres/SPARK-22136. 
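As a minimal sketch of the kind of query this change targets (all stream, column, and window names below are illustrative, not taken from the test suite; `leftStream`/`rightStream` are assumed to be streaming DataFrames with event-time columns):

```
import org.apache.spark.sql.functions.{col, expr, window}

// Assumed inputs: streaming DataFrames with event-time columns leftTime / rightTime.
val left = leftStream
  .withWatermark("leftTime", "10 seconds")
  .select(col("leftKey"), window(col("leftTime"), "10 seconds").as("leftWindow"), col("leftValue"))
val right = rightStream
  .withWatermark("rightTime", "10 seconds")
  .select(col("rightKey"), window(col("rightTime"), "10 seconds").as("rightWindow"), col("rightValue"))

// The conjunct `leftValue > 4` references only the left side. With this patch it is
// evaluated before a left row is buffered, so left rows that can never satisfy the
// join condition are not added to the state and, for the left outer join, emit their
// null-padded result immediately instead of waiting for the watermark.
val joined = left.join(
  right,
  expr("leftKey = rightKey AND leftWindow = rightWindow AND leftValue > 4"),
  "left_outer")
```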
--- .../streaming/IncrementalExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 134 ++++++++++------ .../StreamingSymmetricHashJoinHelper.scala | 70 +++++++- .../sql/streaming/StreamingJoinSuite.scala | 150 +++++++++++++++++- ...treamingSymmetricHashJoinHelperSuite.scala | 130 +++++++++++++++ 5 files changed, 433 insertions(+), 53 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 2e378637727fc..a10ed5f2df1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -133,7 +133,7 @@ class IncrementalExecution( eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( - j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition, + j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, Some(offsetSeqMetadata.batchWatermarkMs)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 9bd2127a28ff6..c351f658cb955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SessionState -import org.apache.spark.sql.types.{LongType, TimestampType} import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -115,7 +114,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} * @param leftKeys Expression to generate key rows for joining from left input * @param rightKeys Expression to generate key rows for joining from right input * @param joinType Type of join (inner, left outer, etc.) - * @param condition Optional, additional condition to filter output of the equi-join + * @param condition Conditions to filter rows, split by left, right, and joined. 
See + * [[JoinConditionSplitPredicates]] * @param stateInfo Version information required to read join state (buffered rows) * @param eventTimeWatermark Watermark of input event, same for both sides * @param stateWatermarkPredicates Predicates for removal of state, see @@ -127,7 +127,7 @@ case class StreamingSymmetricHashJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression], + condition: JoinConditionSplitPredicates, stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, @@ -141,8 +141,10 @@ case class StreamingSymmetricHashJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan) = { + this( - leftKeys, rightKeys, joinType, condition, stateInfo = None, eventTimeWatermark = None, + leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), + stateInfo = None, eventTimeWatermark = None, stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } @@ -161,6 +163,9 @@ case class StreamingSymmetricHashJoinExec( new SerializableConfiguration(SessionState.newHadoopConf( sparkContext.hadoopConfiguration, sqlContext.conf))) + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -206,10 +211,15 @@ case class StreamingSymmetricHashJoinExec( val updateStartTimeNs = System.nanoTime val joinedRow = new JoinedRow + + val postJoinFilter = + newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _ val leftSideJoiner = new OneSideHashJoiner( - LeftSide, left.output, leftKeys, leftInputIter, stateWatermarkPredicates.left) + LeftSide, left.output, leftKeys, leftInputIter, + condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left) val rightSideJoiner = new OneSideHashJoiner( - RightSide, right.output, rightKeys, rightInputIter, stateWatermarkPredicates.right) + RightSide, right.output, rightKeys, rightInputIter, + condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right) // Join one side input using the other side's buffered/state rows. Here is how it is done. // @@ -221,43 +231,28 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input) } - // Filter the joined rows based on the given condition. - val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ - // We need to save the time that the inner join output iterator completes, since outer join // output counts as both update and removal time. 
var innerOutputCompletionTimeNs: Long = 0 def onInnerOutputCompletion = { innerOutputCompletionTimeNs = System.nanoTime } - val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion) - - def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { - rightSideJoiner.get(leftKeyValue.key).exists( - rightValue => { - outputFilterFunction( - joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) - }) - } + // This is the iterator which produces the inner join rows. For outer joins, this will be + // prepended to a second iterator producing outer join rows; for inner joins, this is the full + // output. + val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter), onInnerOutputCompletion) - def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { - leftSideJoiner.get(rightKeyValue.key).exists( - leftValue => { - outputFilterFunction( - joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) - }) - } val outputIter: Iterator[InternalRow] = joinType match { case Inner => - filteredInnerOutputIter + innerOutputIter case LeftOuter => // We generate the outer join input by: // * Getting an iterator over the rows that have aged out on the left side. These rows are @@ -268,28 +263,37 @@ case class StreamingSymmetricHashJoinExec( // we know we can join with null, since there was never (including this batch) a match // within the watermark period. If it does, there must have been a match at some point, so // we know we can't join with null. - val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists { rightValue => + postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + } + } val removedRowIter = leftSideJoiner.removeOldState() val outerOutputIter = removedRowIter .filterNot(pair => matchesWithRightSideState(pair)) .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) - filteredInnerOutputIter ++ outerOutputIter + innerOutputIter ++ outerOutputIter case RightOuter => // See comments for left outer case. - val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists { leftValue => + postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + } + } val removedRowIter = rightSideJoiner.removeOldState() val outerOutputIter = removedRowIter .filterNot(pair => matchesWithLeftSideState(pair)) .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) - filteredInnerOutputIter ++ outerOutputIter + innerOutputIter ++ outerOutputIter case _ => throwBadJoinTypeException() } + val outputProjection = UnsafeProjection.create(left.output ++ right.output, output) val outputIterWithMetrics = outputIter.map { row => numOutputRows += 1 - row + outputProjection(row) } // Function to remove old state after all the input has been consumed and output generated @@ -349,14 +353,36 @@ case class StreamingSymmetricHashJoinExec( /** * Internal helper class to consume input rows, generate join output rows using other sides * buffered state rows, and finally clean up this sides buffered state rows + * + * @param joinSide The JoinSide - either left or right. 
+ * @param inputAttributes The input attributes for this side of the join. + * @param joinKeys The join keys. + * @param inputIter The iterator of input rows on this side to be joined. + * @param preJoinFilterExpr A filter over rows on this side. This filter rejects rows that could + * never pass the overall join condition no matter what other side row + * they're joined with. + * @param postJoinFilter A filter over joined rows. This filter completes the application of + * the overall join condition, assuming that preJoinFilter on both sides + * of the join has already been passed. + * Passed as a function rather than expression to avoid creating the + * predicate twice; we also need this filter later on in the parent exec. + * @param stateWatermarkPredicate The state watermark predicate. See + * [[StreamingSymmetricHashJoinExec]] for further description of + * state watermarks. */ private class OneSideHashJoiner( joinSide: JoinSide, inputAttributes: Seq[Attribute], joinKeys: Seq[Expression], inputIter: Iterator[InternalRow], + preJoinFilterExpr: Option[Expression], + postJoinFilter: (InternalRow) => Boolean, stateWatermarkPredicate: Option[JoinStateWatermarkPredicate]) { + // Filter the joined rows based on the given condition. + val preJoinFilter = + newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ + private val joinStateManager = new SymmetricHashJoinStateManager( joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) @@ -388,8 +414,8 @@ case class StreamingSymmetricHashJoinExec( */ def storeAndJoinWithOtherSide( otherSideJoiner: OneSideHashJoiner)( - generateJoinedRow: (UnsafeRow, UnsafeRow) => JoinedRow): Iterator[InternalRow] = { - + generateJoinedRow: (InternalRow, InternalRow) => JoinedRow): + Iterator[InternalRow] = { val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey)) val nonLateRows = WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match { @@ -402,17 +428,31 @@ case class StreamingSymmetricHashJoinExec( nonLateRows.flatMap { row => val thisRow = row.asInstanceOf[UnsafeRow] - val key = keyGenerator(thisRow) - val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => - generateJoinedRow(thisRow, thatRow) - } - val shouldAddToState = // add only if both removal predicates do not match - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) - if (shouldAddToState) { - joinStateManager.append(key, thisRow) - updatedStateRowsCount += 1 + // If this row fails the pre join filter, that means it can never satisfy the full join + // condition no matter what other side row it's matched with. This allows us to avoid + // adding it to the state, and generate an outer join row immediately (or do nothing in + // the case of inner join). 
+ if (preJoinFilter(thisRow)) { + val key = keyGenerator(thisRow) + val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => + generateJoinedRow(thisRow, thatRow) + }.filter(postJoinFilter) + val shouldAddToState = // add only if both removal predicates do not match + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + if (shouldAddToState) { + joinStateManager.append(key, thisRow) + updatedStateRowsCount += 1 + } + outputIter + } else { + joinSide match { + case LeftSide if joinType == LeftOuter => + Iterator(generateJoinedRow(thisRow, nullRight)) + case RightSide if joinType == RightOuter => + Iterator(generateJoinedRow(thisRow, nullLeft)) + case _ => Iterator() + } } - outputIter } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 64c7189f72ac3..167e991ca62f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -24,8 +24,9 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper -import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression import org.apache.spark.sql.execution.streaming.state.{StateStoreCoordinatorRef, StateStoreProvider, StateStoreProviderId} import org.apache.spark.sql.types._ @@ -66,6 +67,73 @@ object StreamingSymmetricHashJoinHelper extends Logging { } } + /** + * Wrapper around various useful splits of the join condition. + * left AND right AND joined is equivalent to full. + * + * Note that left and right do not necessarily contain *all* conjuncts which satisfy + * their condition. Any conjuncts after the first nondeterministic one are treated as + * nondeterministic for purposes of the split. + * + * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join. + * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join. + * @param bothSides Conjuncts which are nondeterministic, occur after a nondeterministic conjunct, + * or reference both left and right sides of the join. + * @param full The full join condition. 
+ */ + case class JoinConditionSplitPredicates( + leftSideOnly: Option[Expression], + rightSideOnly: Option[Expression], + bothSides: Option[Expression], + full: Option[Expression]) { + override def toString(): String = { + s"condition = [ leftOnly = ${leftSideOnly.map(_.toString).getOrElse("null")}, " + + s"rightOnly = ${rightSideOnly.map(_.toString).getOrElse("null")}, " + + s"both = ${bothSides.map(_.toString).getOrElse("null")}, " + + s"full = ${full.map(_.toString).getOrElse("null")} ]" + } + } + + object JoinConditionSplitPredicates extends PredicateHelper { + def apply(condition: Option[Expression], left: SparkPlan, right: SparkPlan): + JoinConditionSplitPredicates = { + // Split the condition into 3 parts: + // * Conjuncts that can be evaluated on only the left input. + // * Conjuncts that can be evaluated on only the right input. + // * Conjuncts that require both left and right input. + // + // Note that we treat nondeterministic conjuncts as though they require both left and right + // input. To maintain their semantics, they need to be evaluated exactly once per joined row. + val (leftCondition, rightCondition, joinedCondition) = { + if (condition.isEmpty) { + (None, None, None) + } else { + // Span rather than partition, because nondeterministic expressions don't commute + // across AND. + val (deterministicConjuncts, nonDeterministicConjuncts) = + splitConjunctivePredicates(condition.get).span(_.deterministic) + + val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(left.outputSet) + } + + val (rightConjuncts, nonRightConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(right.outputSet) + } + + ( + leftConjuncts.reduceOption(And), + rightConjuncts.reduceOption(And), + (nonLeftConjuncts.intersect(nonRightConjuncts) ++ nonDeterministicConjuncts) + .reduceOption(And) + ) + } + } + + JoinConditionSplitPredicates(leftCondition, rightCondition, joinedCondition, condition) + } + } + /** Get the predicates defining the state watermarks for both sides of the join */ def getStateWatermarkPredicates( leftAttributes: Seq[Attribute], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index d32617275aadc..54eb863dacc83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -365,6 +365,24 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with } } } + + test("join between three streams") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val input3 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") + val df2 = input2.toDF.select('value as "middleKey", ('value * 3) as "middleValue") + val df3 = input3.toDF.select('value as "rightKey", ('value * 5) as "rightValue") + + val joined = df1.join(df2, expr("leftKey = middleKey")).join(df3, expr("rightKey = middleKey")) + + testStream(joined)( + AddData(input1, 1, 5), + AddData(input2, 1, 5, 10), + AddData(input3, 5, 10), + CheckLastBatch((5, 10, 5, 15, 5, 25))) + } } class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { @@ -405,6 +423,130 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with (input1, input2, joined) } + test("left outer early 
state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with leftValue <= 4 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + + test("left outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with value <= 7 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + ) + } + + test("right outer early state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with value <= 4 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. 
+ AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + ) + } + + test("right outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with rightValue <= 7 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + test("windowed left outer join") { val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") @@ -495,7 +637,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // When the join condition isn't true, the outer null rows must be generated, even if the join // keys themselves have a match. - test("left outer join with non-key condition violated on left") { + test("left outer join with non-key condition violated") { val (leftInput, simpleLeftDf) = setupStream("left", 2) val (rightInput, simpleRightDf) = setupStream("right", 3) @@ -513,14 +655,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // leftValue <= 10 should generate outer join rows even though it matches right keys AddData(leftInput, 1, 2, 3), AddData(rightInput, 1, 2, 3), - CheckLastBatch(), + CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), AddData(leftInput, 20), AddData(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 5, updated = 2), AddData(rightInput, 20), CheckLastBatch( - Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows AddData(leftInput, 40, 41), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala new file mode 100644 index 0000000000000..2a854e37bf0df --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec, SparkPlan} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates +import org.apache.spark.sql.types._ + +class StreamingSymmetricHashJoinHelperSuite extends StreamTest { + import org.apache.spark.sql.functions._ + + val leftAttributeA = AttributeReference("a", IntegerType)() + val leftAttributeB = AttributeReference("b", IntegerType)() + val rightAttributeC = AttributeReference("c", IntegerType)() + val rightAttributeD = AttributeReference("d", IntegerType)() + val leftColA = new Column(leftAttributeA) + val leftColB = new Column(leftAttributeB) + val rightColC = new Column(rightAttributeC) + val rightColD = new Column(rightAttributeD) + + val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq()) + val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq()) + + test("empty") { + val split = JoinConditionSplitPredicates(None, left, right) + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.isEmpty) + } + + test("only literals") { + // Literal-only conjuncts end up on the left side because that's the first bucket they fit in. + // There's no semantic reason they couldn't be in any bucket. 
+ val predicate = (lit(1) < lit(5) && lit(6) < lit(7) && lit(0) === lit(-1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only left") { + val predicate = (leftColA > lit(1) && leftColB > lit(5) && leftColA < leftColB).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only right") { + val predicate = (rightColC > lit(1) && rightColD > lit(5) && rightColD < rightColC).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("mixed conjuncts") { + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC).expr)) + assert(split.full.contains(predicate)) + } + + test("conjuncts after nondeterministic") { + // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't + // commute across it. + val predicate = + (rand() > lit(0) + && leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.contains(predicate)) + assert(split.full.contains(predicate)) + } + + + test("conjuncts before nondeterministic") { + val randCol = rand() + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1) + && randCol > lit(0)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC && randCol > lit(0)).expr)) + assert(split.full.contains(predicate)) + } +} From 28f9f3f22511e9f2f900764d9bd5b90d2eeee773 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 17 Oct 2017 12:50:41 -0700 Subject: [PATCH 076/712] [SPARK-22271][SQL] mean overflows and returns null for some decimal variables ## What changes were proposed in this pull request? In Average.scala, it has ``` override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) Cast(Cast(sum, dt) / Cast(count, dt), resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } def setChild (newchild: Expression) = { child = newchild } ``` It is possible that Cast(count, dt), resultType) will make the precision of the decimal number bigger than 38, and this causes over flow. 
Since count is an integer and doesn't need a scale, I will cast it using DecimalType.bounded(38,0) ## How was this patch tested? In DataFrameSuite, I will add a test case. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Huaxin Gao Closes #19496 from huaxingao/spark-22271. --- .../sql/catalyst/expressions/aggregate/Average.scala | 3 ++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c423e17169e85..708bdbfc36058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), + resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50de2fd3bca8d..473c355cf3c7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2105,4 +2105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + + test("SPARK-22271: mean overflows and returns null for some decimal variables") { + val d = 0.034567890 + val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") + val result = df.select('DecimalCol cast DecimalType(38, 33)) + .select(col("DecimalCol")).describe() + val mean = result.select("DecimalCol").where($"summary" === "mean") + assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) + } } From 1437e344ec0c29a44a19f4513986f5f184c44695 Mon Sep 17 00:00:00 2001 From: Michael Mior Date: Tue, 17 Oct 2017 14:30:52 -0700 Subject: [PATCH 077/712] [SPARK-22050][CORE] Allow BlockUpdated events to be optionally logged to the event log ## What changes were proposed in this pull request? I see that block updates are not logged to the event log. This makes sense as a default for performance reasons. However, I find it helpful when trying to get a better understanding of caching for a job to be able to log these updates. This PR adds a configuration setting `spark.eventLog.blockUpdates` (defaulting to false) which allows block updates to be recorded in the log. This contribution is original work which is licensed to the Apache Spark project. ## How was this patch tested? Current and additional unit tests. Author: Michael Mior Closes #19263 from michaelmior/log-block-updates. 
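For anyone who wants to try this, a minimal way to turn the new behaviour on is sketched below. The key `spark.eventLog.logBlockUpdates.enabled` is the one introduced in the diff that follows; the event-log directory is only a placeholder.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val conf = new SparkConf()
  .set("spark.eventLog.enabled", "true")
  .set("spark.eventLog.logBlockUpdates.enabled", "true") // added by this patch, defaults to false
  .set("spark.eventLog.dir", "hdfs:///spark-events")     // placeholder path

// With the flag on, block updates (e.g. from rdd.cache()) are written to the
// event log as SparkListenerBlockUpdated events.
val spark = SparkSession.builder().config(conf).getOrCreate()
```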
--- .../spark/internal/config/package.scala | 23 +++++++++++++ .../scheduler/EventLoggingListener.scala | 18 ++++++---- .../org/apache/spark/util/JsonProtocol.scala | 34 +++++++++++++++++-- .../scheduler/EventLoggingListenerSuite.scala | 2 ++ .../apache/spark/util/JsonProtocolSuite.scala | 27 +++++++++++++++ docs/configuration.md | 8 +++++ 6 files changed, 104 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e7b406af8d9b1..0c36bdcdd2904 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -41,6 +41,29 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val EVENT_LOG_COMPRESS = + ConfigBuilder("spark.eventLog.compress") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_BLOCK_UPDATES = + ConfigBuilder("spark.eventLog.logBlockUpdates.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_TESTING = + ConfigBuilder("spark.eventLog.testing") + .internal() + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .bytesConf(ByteUnit.KiB) + .createWithDefaultString("100k") + + private[spark] val EVENT_LOG_OVERWRITE = + ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 9dafa0b7646bf..a77adc5ff3545 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -37,6 +37,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -45,6 +46,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * * Event logging is specified by the following configurable parameters: * spark.eventLog.enabled - Whether event logging is enabled. + * spark.eventLog.logBlockUpdates.enabled - Whether to log block updates * spark.eventLog.compress - Whether to compress logged events * spark.eventLog.overwrite - Whether to overwrite any existing files. * spark.eventLog.dir - Path to the directory in which events are logged. 
@@ -64,10 +66,11 @@ private[spark] class EventLoggingListener( this(appId, appAttemptId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) - private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) - private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) - private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) - private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 + private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS) + private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE) + private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES) + private val testing = sparkConf.get(EVENT_LOG_TESTING) + private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { @@ -216,8 +219,11 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } - // No-op because logging every update would be overkill - override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + if (shouldLogBlockUpdates) { + logEvent(event, flushLogger = true) + } + } // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8406826a228db..5e60218c5740b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -98,8 +98,8 @@ private[spark] object JsonProtocol { logStartToJson(logStart) case metricsUpdate: SparkListenerExecutorMetricsUpdate => executorMetricsUpdateToJson(metricsUpdate) - case blockUpdated: SparkListenerBlockUpdated => - throw new MatchError(blockUpdated) // TODO(ekl) implement this + case blockUpdate: SparkListenerBlockUpdated => + blockUpdateToJson(blockUpdate) case _ => parse(mapper.writeValueAsString(event)) } } @@ -246,6 +246,12 @@ private[spark] object JsonProtocol { }) } + def blockUpdateToJson(blockUpdate: SparkListenerBlockUpdated): JValue = { + val blockUpdatedInfo = blockUpdatedInfoToJson(blockUpdate.blockUpdatedInfo) + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockUpdate) ~ + ("Block Updated Info" -> blockUpdatedInfo) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -458,6 +464,14 @@ private[spark] object JsonProtocol { ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) } + def blockUpdatedInfoToJson(blockUpdatedInfo: BlockUpdatedInfo): JValue = { + ("Block Manager ID" -> blockManagerIdToJson(blockUpdatedInfo.blockManagerId)) ~ + ("Block ID" -> blockUpdatedInfo.blockId.toString) ~ + ("Storage Level" -> storageLevelToJson(blockUpdatedInfo.storageLevel)) ~ + ("Memory Size" -> blockUpdatedInfo.memSize) ~ + ("Disk Size" -> blockUpdatedInfo.diskSize) + } + /** ------------------------------ * * Util JSON serialization methods | * ------------------------------- */ @@ -515,6 +529,7 @@ private[spark] object JsonProtocol { val executorRemoved = 
Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + val blockUpdate = Utils.getFormattedClassName(SparkListenerBlockUpdated) } def sparkEventFromJson(json: JValue): SparkListenerEvent = { @@ -538,6 +553,7 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case `blockUpdate` => blockUpdateFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] } @@ -676,6 +692,11 @@ private[spark] object JsonProtocol { SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } + def blockUpdateFromJson(json: JValue): SparkListenerBlockUpdated = { + val blockUpdatedInfo = blockUpdatedInfoFromJson(json \ "Block Updated Info") + SparkListenerBlockUpdated(blockUpdatedInfo) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ @@ -989,6 +1010,15 @@ private[spark] object JsonProtocol { new ExecutorInfo(executorHost, totalCores, logUrls) } + def blockUpdatedInfoFromJson(json: JValue): BlockUpdatedInfo = { + val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") + val blockId = BlockId((json \ "Block ID").extract[String]) + val storageLevel = storageLevelFromJson(json \ "Storage Level") + val memorySize = (json \ "Memory Size").extract[Long] + val diskSize = (json \ "Disk Size").extract[Long] + BlockUpdatedInfo(blockManagerId, blockId, storageLevel, memorySize, diskSize) + } + /** -------------------------------- * * Util JSON deserialization methods | * --------------------------------- */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 6b42775ccb0f6..a9e92fa07b9dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -228,6 +228,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit SparkListenerStageCompleted, SparkListenerTaskStart, SparkListenerTaskEnd, + SparkListenerBlockUpdated, SparkListenerApplicationEnd).map(Utils.getFormattedClassName) Utils.tryWithSafeFinally { val logStart = SparkListenerLogStart(SPARK_VERSION) @@ -291,6 +292,7 @@ object EventLoggingListenerSuite { def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") + conf.set("spark.eventLog.logBlockUpdates.enabled", "true") conf.set("spark.eventLog.testing", "true") conf.set("spark.eventLog.dir", logDir.toString) compressionCodec.foreach { codec => diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a1a858765a7d4..4abbb8e7894f5 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -96,6 +96,9 @@ class JsonProtocolSuite extends SparkFunSuite { .zipWithIndex.map { case (a, i) => a.copy(id = i) } 
SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) } + val blockUpdated = + SparkListenerBlockUpdated(BlockUpdatedInfo(BlockManagerId("Stars", + "In your multitude...", 300), RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100L, 0L)) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -120,6 +123,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(nodeBlacklisted, nodeBlacklistedJsonString) testEvent(nodeUnblacklisted, nodeUnblacklistedJsonString) testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) + testEvent(blockUpdated, blockUpdatedJsonString) } test("Dependent Classes") { @@ -2007,6 +2011,29 @@ private[spark] object JsonProtocolSuite extends Assertions { |} """.stripMargin + private val blockUpdatedJsonString = + """ + |{ + | "Event": "SparkListenerBlockUpdated", + | "Block Updated Info": { + | "Block Manager ID": { + | "Executor ID": "Stars", + | "Host": "In your multitude...", + | "Port": 300 + | }, + | "Block ID": "rdd_0_0", + | "Storage Level": { + | "Use Disk": false, + | "Use Memory": true, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Memory Size": 100, + | "Disk Size": 0 + | } + |} + """.stripMargin + private val executorBlacklistedJsonString = s""" |{ diff --git a/docs/configuration.md b/docs/configuration.md index bb06c8faaaed7..7b9e16a382449 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -714,6 +714,14 @@ Apart from these, the following properties are also available, and may be useful + + + + + From f3137feecd30c74c47dbddb0e22b4ddf8cf2f912 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 17 Oct 2017 20:09:12 -0700 Subject: [PATCH 078/712] [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState ## What changes were proposed in this pull request? Complex state-updating and/or timeout-handling logic in mapGroupsWithState functions may require taking decisions based on the current event-time watermark and/or processing time. Currently, you can use the SQL function `current_timestamp` to get the current processing time, but it needs to be passed inserted in every row with a select, and then passed through the encoder, which isn't efficient. Furthermore, there is no way to get the current watermark. This PR exposes both of them through the GroupState API. Additionally, it also cleans up some of the GroupState docs. ## How was this patch tested? New unit tests Author: Tathagata Das Closes #19495 from tdas/SPARK-22278. 
--- .../apache/spark/sql/execution/objects.scala | 8 +- .../FlatMapGroupsWithStateExec.scala | 7 +- .../execution/streaming/GroupStateImpl.scala | 50 +++--- .../spark/sql/streaming/GroupState.scala | 92 ++++++---- .../FlatMapGroupsWithStateSuite.scala | 160 +++++++++++++++--- 5 files changed, 238 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c68975bea490f..d861109436a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.plans.logical.{FunctionUtils, LogicalGroupState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.streaming.GroupStateTimeout @@ -361,8 +361,12 @@ object MapGroupsExec { outputObjAttr: Attribute, timeoutConf: GroupStateTimeout, child: SparkPlan): MapGroupsExec = { + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } val f = (key: Any, values: Iterator[Any]) => { - func(key, values, GroupStateImpl.createForBatch(timeoutConf)) + func(key, values, GroupStateImpl.createForBatch(timeoutConf, watermarkPresent)) } new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index c81f1a8142784..29f38fab3f896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -61,6 +61,10 @@ case class FlatMapGroupsWithStateExec( private val isTimeoutEnabled = timeoutConf != NoTimeout val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -190,7 +194,8 @@ case class FlatMapGroupsWithStateExec( batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, - hasTimedOut) + hasTimedOut, + watermarkPresent) // Call function, get the returned objects and convert them to rows val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 4401e86936af9..7f65e3ea9dd5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -43,7 +43,8 @@ private[sql] class GroupStateImpl[S] private( batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - override val hasTimedOut: Boolean) extends GroupState[S] { + override val hasTimedOut: Boolean, + watermarkPresent: Boolean) extends GroupState[S] { private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -90,7 +91,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != ProcessingTimeTimeout) { throw new UnsupportedOperationException( "Cannot set timeout duration without enabling processing time timeout in " + - "map/flatMapGroupsWithState") + "[map|flatMap]GroupsWithState") } if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") @@ -102,10 +103,6 @@ private[sql] class GroupStateImpl[S] private( setTimeoutDuration(parseDuration(duration)) } - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long): Unit = { checkTimeoutTimestampAllowed() if (timestampMs <= 0) { @@ -119,32 +116,34 @@ private[sql] class GroupStateImpl[S] private( timeoutTimestamp = timestampMs } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) } - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime) } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) } + override def getCurrentWatermarkMs(): Long = { + if (!watermarkPresent) { + throw new UnsupportedOperationException( + "Cannot get event time watermark timestamp without setting watermark before " + + "[map|flatMap]GroupsWithState") + } + eventTimeWatermarkMs + } + + override def getCurrentProcessingTimeMs(): Long = { + batchProcessingTimeMs + } + override def toString: String = { s"GroupState(${getOption.map(_.toString).getOrElse("")})" } @@ -187,7 +186,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != EventTimeTimeout) { throw new UnsupportedOperationException( "Cannot set timeout timestamp without enabling event time timeout in " + - 
"map/flatMapGroupsWithState") + "[map|flatMapGroupsWithState") } } } @@ -202,17 +201,22 @@ private[sql] object GroupStateImpl { batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - hasTimedOut: Boolean): GroupStateImpl[S] = { + hasTimedOut: Boolean, + watermarkPresent: Boolean): GroupStateImpl[S] = { new GroupStateImpl[S]( - optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut) + optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, + timeoutConf, hasTimedOut, watermarkPresent) } - def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = { + def createForBatch( + timeoutConf: GroupStateTimeout, + watermarkPresent: Boolean): GroupStateImpl[Any] = { new GroupStateImpl[Any]( optionalValue = None, - batchProcessingTimeMs = NO_TIMESTAMP, + batchProcessingTimeMs = System.currentTimeMillis, eventTimeWatermarkMs = NO_TIMESTAMP, timeoutConf, - hasTimedOut = false) + hasTimedOut = false, + watermarkPresent) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 04a956b70b022..e9510c903acae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -205,11 +205,7 @@ trait GroupState[S] extends LogicalGroupState[S] { /** Get the state value as a scala Option. */ def getOption: Option[S] - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. - */ - @throws[IllegalArgumentException]("when updating with null") + /** Update the value of the state. */ def update(newState: S): Unit /** Remove this state. */ @@ -217,80 +213,114 @@ trait GroupState[S] extends LogicalGroupState[S] { /** * Whether the function has been called because the key has timed out. - * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`. + * @note This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. */ def hasTimedOut: Boolean + /** * Set the timeout duration in ms for this key. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(durationMs: Long): Unit + /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. 
*/ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(duration: String): Unit - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ + @throws[IllegalArgumentException]( + "if 'timestampMs' is not positive or less than the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]( + "if 'additionalDuration' is invalid or the final timeout timestamp is less than " + + "the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. 
*/ + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit + + + /** + * Get the current event time watermark as milliseconds in epoch time. + * + * @note In a streaming query, this can be called only when watermark is set before calling + * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. + */ + @throws[UnsupportedOperationException]( + "if watermark has not been set before in [map|flatMap]GroupsWithState") + def getCurrentWatermarkMs(): Long + + + /** + * Get the current processing time as milliseconds in epoch time. + * @note In a streaming query, this will return a constant value throughout the duration of a + * trigger, even if the trigger is re-executed. 
+ */ + def getCurrentProcessingTimeMs(): Long } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index aeb83835f981a..af08186aadbb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,6 +21,7 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -48,6 +49,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest import testImplicits._ import GroupStateImpl._ import GroupStateTimeout._ + import FlatMapGroupsWithStateSuite._ override def afterAll(): Unit = { super.afterAll() @@ -77,13 +79,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // === Tests for state in streaming queries === // Updating empty state - state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + None, 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some("2"), 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -104,8 +108,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with NoTimeout") { for (initValue <- Seq(None, Some(5))) { val states = Seq( - GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), - GroupStateImpl.createForBatch(NoTimeout) + GroupStateImpl.createForStreaming( + initValue, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false), + GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false) ) for (state <- states) { // for streaming queries @@ -122,7 +127,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - setTimeout - with ProcessingTimeTimeout") { // for streaming queries var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state @@ -143,7 +148,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch( + ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) @@ -160,7 +166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("GroupState - 
setTimeout - with EventTimeTimeout") { var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, EventTimeTimeout, false) + None, 1000, 1000, EventTimeTimeout, false, watermarkPresent = true) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) @@ -182,7 +188,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testTimeoutDurationNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch(EventTimeTimeout, watermarkPresent = false) + .asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) @@ -209,7 +216,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } @@ -227,7 +234,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } @@ -259,29 +266,92 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // for streaming queries for (initState <- Seq(None, Some(5))) { val state1 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = false) + initState, 1000, 1000, timeoutConf, hasTimedOut = false, watermarkPresent = false) assert(state1.hasTimedOut === false) val state2 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = true) + initState, 1000, 1000, timeoutConf, hasTimedOut = true, watermarkPresent = false) assert(state2.hasTimedOut === true) } // for batch queries - assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) + assert( + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent = false).hasTimedOut === false) + } + } + + test("GroupState - getCurrentWatermarkMs") { + def streamingState(timeoutConf: GroupStateTimeout, watermark: Option[Long]): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, 1000, watermark.getOrElse(-1), timeoutConf, + hasTimedOut = false, watermark.nonEmpty) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + def assertWrongTimeoutError(test: => Unit): Unit = { + val e = intercept[UnsupportedOperationException] { test } + assert(e.getMessage.contains( + "Cannot get event time watermark timestamp without setting watermark")) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + // Tests for getCurrentWatermarkMs in streaming queries + assertWrongTimeoutError { streamingState(timeoutConf, None).getCurrentWatermarkMs() } + assert(streamingState(timeoutConf, Some(1000)).getCurrentWatermarkMs() === 1000) + assert(streamingState(timeoutConf, Some(2000)).getCurrentWatermarkMs() === 2000) + + // Tests for getCurrentWatermarkMs in batch queries + assertWrongTimeoutError { + 
batchState(timeoutConf, watermarkPresent = false).getCurrentWatermarkMs() + } + assert(batchState(timeoutConf, watermarkPresent = true).getCurrentWatermarkMs() === -1) + } + } + + test("GroupState - getCurrentProcessingTimeMs") { + def streamingState( + timeoutConf: GroupStateTimeout, + procTime: Long, + watermarkPresent: Boolean): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, procTime, -1, timeoutConf, hasTimedOut = false, watermarkPresent = false) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + for (watermarkPresent <- Seq(false, true)) { + // Tests for getCurrentProcessingTimeMs in streaming queries + assert(streamingState(timeoutConf, NO_TIMESTAMP, watermarkPresent) + .getCurrentProcessingTimeMs() === -1) + assert(streamingState(timeoutConf, 1000, watermarkPresent) + .getCurrentProcessingTimeMs() === 1000) + assert(streamingState(timeoutConf, 2000, watermarkPresent) + .getCurrentProcessingTimeMs() === 2000) + + // Tests for getCurrentProcessingTimeMs in batch queries + val currentTime = System.currentTimeMillis() + assert(batchState(timeoutConf, watermarkPresent).getCurrentProcessingTimeMs >= currentTime) + } } } + test("GroupState - primitive type") { var intState = GroupStateImpl.createForStreaming[Int]( - None, 1000, 1000, NoTimeout, hasTimedOut = false) + None, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) intState = GroupStateImpl.createForStreaming[Int]( - Some(10), 1000, 1000, NoTimeout, hasTimedOut = false) + Some(10), 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -304,7 +374,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithData( testName + "no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change @@ -342,7 +416,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithData( s"$timeoutConf - $testName - no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -466,7 +544,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStateUpdateWithTimeout( s"$timeoutConf - should timeout - no update/remove", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change @@ -525,6 +607,8 @@ class FlatMapGroupsWithStateSuite 
extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -647,6 +731,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } @@ -660,6 +747,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.hasTimedOut) { state.remove() Iterator((key, "-1")) @@ -713,10 +803,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest test("flatMapGroupsWithState - streaming with event time timeout + watermark") { // Function to maintain the max event time // Returns the max event time in the state, or -1 if the state was removed by timeout - val stateFunc = ( - key: String, - values: Iterator[(String, Long)], - state: GroupState[Long]) => { + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + val timeoutDelay = 5 if (key != "a") { Iterator.empty @@ -760,6 +850,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -802,7 +894,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest // - no initial state // - timeouts operations work, does not throw any error [SPARK-20792] // - works with primitive state type + // - can get processing time val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") state.setTimeoutTimestamp(0, "1 hour") state.update(10) @@ -1090,4 +1186,24 @@ object FlatMapGroupsWithStateSuite { override def metrics: 
StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) override def hasCommitted: Boolean = true } + + def assertCanGetProcessingTime(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCanGetWatermark(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCannotGetWatermark(func: => Unit): Unit = { + try { + func + } catch { + case u: UnsupportedOperationException => + return + case _ => + throw new TestFailedException("Unexpected exception when trying to get watermark", 20) + } + throw new TestFailedException("Could get watermark when not expected", 20) + } } From 72561ecf4b611d68f8bf695ddd0c4c2cce3a29d9 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Oct 2017 20:59:40 +0800 Subject: [PATCH 079/712] [SPARK-22266][SQL] The same aggregate function was evaluated multiple times ## What changes were proposed in this pull request? To let the same aggregate function that appear multiple times in an Aggregate be evaluated only once, we need to deduplicate the aggregate expressions. The original code was trying to use a "distinct" call to get a set of aggregate expressions, but did not work, since the "distinct" did not compare semantic equality. And even if it did, further work should be done in result expression rewriting. In this PR, I changed the "set" to a map mapping the semantic identity of a aggregate expression to itself. Thus, later on, when rewriting result expressions (i.e., output expressions), the aggregate expression reference can be fixed. ## How was this patch tested? Added a new test in SQLQuerySuite Author: maryannxue Closes #19488 from maryannxue/spark-22266. --- .../sql/catalyst/planning/patterns.scala | 16 +++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 +++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8d034c21a4960..cc391aae55787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -205,14 +205,17 @@ object PhysicalAggregation { case logical.Aggregate(groupingExpressions, resultExpressions, child) => // A single aggregate expression might appear multiple times in resultExpressions. // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. + // build a set of semantically distinct aggregate expressions and re-write expressions so + // that they reference the single copy of the aggregate function which actually gets computed. + // Non-deterministic aggregate expressions are not deduplicated. + val equivalentAggregateExpressions = new EquivalentExpressions val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { - case agg: AggregateExpression => agg + // addExpr() always returns false for non-deterministic expressions and do not add them. 
+ case agg: AggregateExpression + if (!equivalentAggregateExpressions.addExpr(agg)) => agg } - }.distinct + } val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -236,7 +239,8 @@ object PhysicalAggregation { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, // so replace each aggregate expression by its corresponding attribute in the set: - ae.resultAttribute + equivalentAggregateExpressions.getEquivalentExprs(ae).headOption + .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f0c58e2e5bf45..caf332d050d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -2715,4 +2716,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 1, 1)) } } + + test("SRARK-22266: the same aggregate function was calculated multiple times") { + val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 1) + checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil) + } + + test("Non-deterministic aggregate functions should not be deduplicated") { + val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 2) + } } From 1f25d8683a84a479fd7fc77b5a1ea980289b681b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 18 Oct 2017 09:14:46 -0700 Subject: [PATCH 080/712] [SPARK-22249][FOLLOWUP][SQL] Check if list of value for IN is empty in the optimizer ## What changes were proposed in this pull request? This PR addresses the comments by gatorsmile on [the previous PR](https://github.com/apache/spark/pull/19494). ## How was this patch tested? Previous UT and added UT. Author: Marco Gaido Closes #19522 from mgaido91/SPARK-22249_FOLLOWUP. 
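For context, the kind of query the `list.nonEmpty` guard added in the diff below has to tolerate is sketched here; the table is made up, and an active SparkSession named `spark` is assumed. The point is an `In` filter with an empty value list reaching a cached relation.

```scala
import org.apache.spark.sql.functions.col

val cached = spark.range(10).toDF("id").cache()

// An empty isin() call produces In(id, Nil). buildFilter on the in-memory relation
// must simply skip such a predicate instead of failing; the query returns zero rows.
val rows = cached.filter(col("id").isin()).count()
```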
---
 .../execution/columnar/InMemoryTableScanExec.scala |  4 ++--
 .../columnar/InMemoryColumnarQuerySuite.scala      | 12 +++++++++++-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 846ec03e46a12..139da1c519da2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -102,8 +102,8 @@ case class InMemoryTableScanExec(
     case IsNull(a: Attribute) => statsFor(a).nullCount > 0
     case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
-    case In(_: AttributeReference, list: Seq[Expression]) if list.isEmpty => Literal.FalseLiteral
-    case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) =>
+    case In(a: AttributeReference, list: Seq[Expression])
+      if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
       list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 75d17bc79477d..2f249c850a088 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,8 +21,9 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}

 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
-import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.LocalTableScanExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -444,4 +445,13 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     assert(dfNulls.filter($"id".isin(2, 3)).count() == 0)
     dfNulls.unpersist()
   }
+
+  test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") {
+    val attribute = AttributeReference("a", IntegerType)()
+    val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY,
+      LocalTableScanExec(Seq(attribute), Nil), None)
+    val tableScanExec = InMemoryTableScanExec(Seq(attribute),
+      Seq(In(attribute, Nil)), testRelation)
+    assert(tableScanExec.partitionFilters.isEmpty)
+  }
 }

From 52facb0062a4253fa45ac0c633d0510a9b684a62 Mon Sep 17 00:00:00 2001
From: Valeriy Avanesov
Date: Wed, 18 Oct 2017 10:46:46 -0700
Subject: [PATCH 081/712] [SPARK-14371][MLLIB] OnlineLDAOptimizer should not collect stats for each doc in mini-batch to driver

# What changes were proposed in this pull request?

As proposed by jkbradley, ```gammat``` is no longer collected to the driver; the per-document statistics are aggregated on the executors instead.

# How was this patch tested?

Existing test suite.

Author: Valeriy Avanesov
Author: Valeriy Avanesov

Closes #18924 from akopich/master.
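The general pattern the LDA optimizer moves to is sketched below: per-partition statistics are merged with `treeAggregate` on the executors instead of collecting one entry per document back to the driver. The matrix shapes and the per-document update are simplified stand-ins, not the real variational-inference math.

```scala
import breeze.linalg.{DenseMatrix => BDM}
import org.apache.spark.rdd.RDD

// docs: one (simplified) per-document statistic; k topics, vocabSize terms.
def aggregateSufficientStats(docs: RDD[Array[Double]], k: Int, vocabSize: Int): BDM[Double] = {
  docs.treeAggregate(BDM.zeros[Double](k, vocabSize))(
    seqOp = (stat, doc) => {
      // fold this document's contribution into the partition-local matrix here
      stat
    },
    combOp = (a, b) => a += b) // element-wise, in-place merge of partition results
}
```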
--- .../spark/mllib/clustering/LDAOptimizer.scala | 82 +++++++++++++------ 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index d633893e55f55..693a2a31f026b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -26,6 +26,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.PeriodicGraphCheckpointer +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -259,7 +260,7 @@ final class EMLDAOptimizer extends LDAOptimizer { */ @Since("1.4.0") @DeveloperApi -final class OnlineLDAOptimizer extends LDAOptimizer { +final class OnlineLDAOptimizer extends LDAOptimizer with Logging { // LDA common parameters private var k: Int = 0 @@ -462,31 +463,61 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta) val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape - - val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val optimizeDocConcentration = this.optimizeDocConcentration + // If and only if optimizeDocConcentration is set true, + // we calculate logphat in the same pass as other statistics. + // No calculation of loghat happens otherwise. + val logphatPartOptionBase = () => if (optimizeDocConcentration) { + Some(BDV.zeros[Double](k)) + } else { + None + } + + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) val stat = BDM.zeros[Double](k, vocabSize) - var gammaPart = List[BDV[Double]]() + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids).toDenseMatrix + sstats - gammaPart = gammad :: gammaPart + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) } - Iterator((stat, gammaPart)) - }.persist(StorageLevel.MEMORY_AND_DISK) - val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))( - _ += _, _ += _) - val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( - stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) - stats.unpersist() + Iterator((stat, logphatPartOption, nonEmptyDocCount)) + } + + val elementWiseSum = ( + u: (BDM[Double], Option[BDV[Double]], Long), + v: (BDM[Double], Option[BDV[Double]], Long)) => { + u._1 += v._1 + u._2.foreach(_ += v._2.get) + (u._1, u._2, u._3 + v._3) + } + + val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN: Long) = stats + .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))( + elementWiseSum, elementWiseSum + ) + expElogbetaBc.destroy(false) - val batchResult = statsSum *:* expElogbeta.t + if (nonEmptyDocsN == 0) { + logWarning("No non-empty documents were submitted in the 
batch.") + // Therefore, there is no need to update any of the model parameters + return this + } + + val batchResult = statsSum *:* expElogbeta.t // Note that this is an optimization to avoid batch.count - updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) - if (optimizeDocConcentration) updateAlpha(gammat) + val batchSize = (miniBatchFraction * corpusSize).ceil.toInt + updateLambda(batchResult, batchSize) + + logphatOption.foreach(_ /= nonEmptyDocsN.toDouble) + logphatOption.foreach(updateAlpha(_, nonEmptyDocsN)) + this } @@ -503,21 +534,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * Update alpha based on `gammat`, the inferred topic distributions for documents in the - * current mini-batch. Uses Newton-Rhapson method. + * Update alpha based on `logphat`. + * Uses Newton-Rhapson method. * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf) + * @param logphat Expectation of estimated log-posterior distribution of + * topics in a document averaged over the batch. + * @param nonEmptyDocsN number of non-empty documents */ - private def updateAlpha(gammat: BDM[Double]): Unit = { + private def updateAlpha(logphat: BDV[Double], nonEmptyDocsN: Double): Unit = { val weight = rho() - val N = gammat.rows.toDouble val alpha = this.alpha.asBreeze.toDenseVector - val logphat: BDV[Double] = - sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)).t / N - val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat) - val c = N * trigamma(sum(alpha)) - val q = -N * trigamma(alpha) + val gradf = nonEmptyDocsN * (-LDAUtils.dirichletExpectation(alpha) + logphat) + + val c = nonEmptyDocsN * trigamma(sum(alpha)) + val q = -nonEmptyDocsN * trigamma(alpha) val b = sum(gradf / q) / (1D / c + sum(1D / q)) val dalpha = -(gradf - b) / q From 6f1d0dea1cdda558c998179789b386f6e52b9e36 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 19 Oct 2017 13:30:55 +0800 Subject: [PATCH 082/712] [SPARK-22300][BUILD] Update ORC to 1.4.1 ## What changes were proposed in this pull request? Apache ORC 1.4.1 is released yesterday. - https://orc.apache.org/news/2017/10/16/ORC-1.4.1/ Like ORC-233 (Allow `orc.include.columns` to be empty), there are several important fixes. This PR updates Apache ORC dependency to use the latest one, 1.4.1. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #19521 from dongjoon-hyun/SPARK-22300. 
--- dev/deps/spark-deps-hadoop-2.6 | 6 +++--- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- pom.xml | 6 +++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 76fcbd15869f1..6e2fc63d67108 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.3.jar +aircompressor-0.8.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -149,8 +149,8 @@ netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar opencsv-2.3.jar -orc-core-1.4.0-nohive.jar -orc-mapreduce-1.4.0-nohive.jar +orc-core-1.4.1-nohive.jar +orc-mapreduce-1.4.1-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index cb20072bf8b30..c2bbc253d723a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.3.jar +aircompressor-0.8.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -150,8 +150,8 @@ netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar opencsv-2.3.jar -orc-core-1.4.0-nohive.jar -orc-mapreduce-1.4.0-nohive.jar +orc-core-1.4.1-nohive.jar +orc-mapreduce-1.4.1-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 9fac8b1e53788..b9c972855204a 100644 --- a/pom.xml +++ b/pom.xml @@ -128,7 +128,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.0 + 1.4.1 nohive 1.6.0 9.3.20.v20170531 @@ -1712,6 +1712,10 @@ org.apache.hive hive-storage-api + + io.airlift + slice + From dc2714da50ecba1bf1fdf555a82a4314f763a76e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Oct 2017 14:56:48 +0800 Subject: [PATCH 083/712] [SPARK-22290][CORE] Avoid creating Hive delegation tokens when not necessary. Hive delegation tokens are only needed when the Spark driver has no access to the kerberos TGT. That happens only in two situations: - when using a proxy user - when using cluster mode without a keytab This change modifies the Hive provider so that it only generates delegation tokens in those situations, and tweaks the YARN AM so that it makes the proper user visible to the Hive code when running with keytabs, so that the TGT can be used instead of a delegation token. The effect of this change is that now it's possible to initialize multiple, non-concurrent SparkContext instances in the same JVM. Before, the second invocation would fail to fetch a new Hive delegation token, which then could make the second (or third or...) application fail once the token expired. With this change, the TGT will be used to authenticate to the HMS instead. This change also avoids polluting the current logged in user's credentials when launching applications. The credentials are copied only when running applications as a proxy user. This makes it possible to implement SPARK-11035 later, where multiple threads might be launching applications, and each app should have its own set of credentials. 
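As a hedged illustration of the rule just described (the helper name below is invented for this sketch and is not part of the patch), the decision reduces to:

```scala
// Sketch only: Hive delegation tokens are needed when the driver has no usable TGT of its own,
// i.e. when impersonating another user or when running in cluster mode without a keytab.
def needsHiveDelegationTokens(
    securityEnabled: Boolean,   // Hadoop (kerberos) security enabled
    metastoreUris: String,      // value of hive.metastore.uris
    isProxyUser: Boolean,       // running as a proxy user
    deployMode: String,         // "client" or "cluster"
    hasKeytab: Boolean          // spark.yarn.keytab configured
  ): Boolean =
  securityEnabled && metastoreUris.trim.nonEmpty &&
    (isProxyUser || (deployMode == "cluster" && !hasKeytab))
```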
Tested by verifying HDFS and Hive access in following scenarios: - client and cluster mode - client and cluster mode with proxy user - client and cluster mode with principal / keytab - long-running cluster app with principal / keytab - pyspark app that creates (and stops) multiple SparkContext instances through its lifetime Author: Marcelo Vanzin Closes #19509 from vanzin/SPARK-22290. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 17 +++-- .../HBaseDelegationTokenProvider.scala | 4 +- .../HadoopDelegationTokenManager.scala | 2 +- .../HadoopDelegationTokenProvider.scala | 2 +- .../HadoopFSDelegationTokenProvider.scala | 4 +- .../HiveDelegationTokenProvider.scala | 20 +++++- docs/running-on-yarn.md | 9 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 69 +++++++++++++++---- .../org/apache/spark/deploy/yarn/Client.scala | 5 +- .../org/apache/spark/deploy/yarn/config.scala | 4 ++ .../sql/hive/client/HiveClientImpl.scala | 6 -- 11 files changed, 110 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 53775db251bc6..1fa10ab943f34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -61,13 +61,17 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { + createSparkUser().doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) + } + + def createSparkUser(): UserGroupInformation = { val user = Utils.getCurrentUserName() - logDebug("running as user: " + user) + logDebug("creating UGI for user: " + user) val ugi = UserGroupInformation.createRemoteUser(user) transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) + ugi } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { @@ -417,6 +421,11 @@ class SparkHadoopUtil extends Logging { creds.readTokenStorageStream(new DataInputStream(tokensBuf)) creds } + + def isProxyUser(ugi: UserGroupInformation): Boolean = { + ugi.getAuthenticationMethod() == UserGroupInformation.AuthenticationMethod.PROXY + } + } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 78b0e6b2cbf39..5dcde4ec3a8a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -56,7 +56,9 @@ private[security] class HBaseDelegationTokenProvider None } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index c134b7ebe38fa..483d0deec8070 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ 
-115,7 +115,7 @@ private[spark] class HadoopDelegationTokenManager( hadoopConf: Configuration, creds: Credentials): Long = { delegationTokenProviders.values.flatMap { provider => - if (provider.delegationTokensRequired(hadoopConf)) { + if (provider.delegationTokensRequired(sparkConf, hadoopConf)) { provider.obtainDelegationTokens(hadoopConf, sparkConf, creds) } else { logDebug(s"Service ${provider.serviceName} does not require a token." + diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala index 1ba245e84af4b..ed0905088ab25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -37,7 +37,7 @@ private[spark] trait HadoopDelegationTokenProvider { * Returns true if delegation tokens are required for this service. By default, it is based on * whether Hadoop security is enabled. */ - def delegationTokensRequired(hadoopConf: Configuration): Boolean + def delegationTokensRequired(sparkConf: SparkConf, hadoopConf: Configuration): Boolean /** * Obtain delegation tokens for this service and get the time of the next renewal. diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 300773c58b183..21ca669ea98f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -69,7 +69,9 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration nextRenewalDate } - def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index b31cc595ed83b..ece5ce79c650d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -31,7 +31,9 @@ import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils private[security] class HiveDelegationTokenProvider @@ -55,9 +57,21 @@ private[security] class HiveDelegationTokenProvider } } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { + // Delegation tokens are needed only when: + // - trying to connect to a secure metastore + // - either deploying in cluster mode without a keytab, or impersonating another user + // + // Other modes (such as client with or without keytab, or cluster mode with keytab) do not need + // a delegation token, since there's a valid kerberos TGT for the right user available to the + // driver, which 
is the only process that connects to the HMS. + val deployMode = sparkConf.get("spark.submit.deployMode", "client") UserGroupInformation.isSecurityEnabled && - hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty + hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty && + (SparkHadoopUtil.get.isProxyUser(UserGroupInformation.getCurrentUser()) || + (deployMode == "cluster" && !sparkConf.contains(KEYTAB))) } override def obtainDelegationTokens( @@ -83,7 +97,7 @@ private[security] class HiveDelegationTokenProvider val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) - logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") + logDebug(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 432639588cc2b..9599d40c545b2 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -401,6 +401,15 @@ To use a custom metrics.properties for the application master and executors, upd Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) + + + + + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e227bff88f71d..f6167235f89e4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URI, URL} +import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} import scala.collection.mutable.HashMap @@ -28,6 +29,7 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -49,10 +51,7 @@ import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. */ -private[spark] class ApplicationMaster( - args: ApplicationMasterArguments, - client: YarnRMClient) - extends Logging { +private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Logging { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -62,6 +61,46 @@ private[spark] class ApplicationMaster( .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null + private val ugi = { + val original = UserGroupInformation.getCurrentUser() + + // If a principal and keytab were provided, log in to kerberos, and set up a thread to + // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL + // of the TGT, use a configuration to define how often to check that a relogin is necessary. + // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. 
+ val principal = sparkConf.get(PRINCIPAL).orNull + val keytab = sparkConf.get(KEYTAB).orNull + if (principal != null && keytab != null) { + UserGroupInformation.loginUserFromKeytab(principal, keytab) + + val renewer = new Thread() { + override def run(): Unit = Utils.tryLogNonFatalError { + while (true) { + TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) + UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() + } + } + } + renewer.setName("am-kerberos-renewer") + renewer.setDaemon(true) + renewer.start() + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. It also copies over any delegation tokens that might have been created by the + // client, which will then be transferred over when starting executors (until new ones + // are created by the periodic task). + val newUser = UserGroupInformation.getCurrentUser() + SparkHadoopUtil.get.transferCredentials(original, newUser) + newUser + } else { + SparkHadoopUtil.get.createSparkUser() + } + } + + private val client = ugi.doAs(new PrivilegedExceptionAction[YarnRMClient]() { + def run: YarnRMClient = new YarnRMClient() + }) + // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -201,6 +240,13 @@ private[spark] class ApplicationMaster( } final def run(): Int = { + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + def run: Unit = runImpl() + }) + exitCode + } + + private def runImpl(): Unit = { try { val appAttemptId = client.getAttemptId() @@ -254,11 +300,6 @@ private[spark] class ApplicationMaster( } } - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserApplication which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) - // If the credentials file config is present, we must periodically renew tokens. So create // a new AMDelegationTokenRenewer if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { @@ -284,6 +325,9 @@ private[spark] class ApplicationMaster( credentialRenewerThread.join() } + // Call this to force generation of secret so it gets populated into the Hadoop UGI. + val securityMgr = new SecurityManager(sparkConf) + if (isClusterMode) { runDriver(securityMgr) } else { @@ -297,7 +341,6 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, "Uncaught exception: " + e) } - exitCode } /** @@ -775,10 +818,8 @@ object ApplicationMaster extends Logging { sys.props(k) = v } } - SparkHadoopUtil.get.runAsSparkUser { () => - master = new ApplicationMaster(amArgs, new YarnRMClient) - System.exit(master.run()) - } + master = new ApplicationMaster(amArgs) + System.exit(master.run()) } private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 64b2b4d4db549..1fe25c4ddaabf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -394,7 +394,10 @@ private[spark] class Client( if (credentials != null) { // Add credentials to current user's UGI, so that following operations don't need to use the // Kerberos tgt to get delegations again in the client side. 
- UserGroupInformation.getCurrentUser.addCredentials(credentials) + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 187803cc6050b..e1af8ba087d6e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -347,6 +347,10 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(Long.MaxValue) + private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1m") + // The list of cache-related config entries. This is used by Client and the AM to clean // up the environment so that these settings do not appear on the web UI. private[yarn] val CACHE_CONFIGS = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a01c312d5e497..16c95c53b4201 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -111,12 +111,6 @@ private[hive] class HiveClientImpl( if (clientLoader.isolationOn) { // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) - // Set up kerberos credentials for UserGroupInformation.loginUser within current class loader - if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { - val principal = sparkConf.get("spark.yarn.principal") - val keytab = sparkConf.get("spark.yarn.keytab") - SparkHadoopUtil.get.loginUserFromKeytab(principal, keytab) - } try { newState() } finally { From 5a07aca4d464e96d75ea17bf6768e24b829872ec Mon Sep 17 00:00:00 2001 From: krishna-pandey Date: Thu, 19 Oct 2017 08:33:14 +0100 Subject: [PATCH 084/712] [SPARK-22188][CORE] Adding security headers for preventing XSS, MitM and MIME sniffing ## What changes were proposed in this pull request? The HTTP Strict-Transport-Security response header (often abbreviated as HSTS) is a security feature that lets a web site tell browsers that it should only be communicated with using HTTPS, instead of using HTTP. Note: The Strict-Transport-Security header is ignored by the browser when your site is accessed using HTTP; this is because an attacker may intercept HTTP connections and inject the header or remove it. When your site is accessed over HTTPS with no certificate errors, the browser knows your site is HTTPS capable and will honor the Strict-Transport-Security header. The HTTP X-XSS-Protection response header is a feature of Internet Explorer, Chrome and Safari that stops pages from loading when they detect reflected cross-site scripting (XSS) attacks. The HTTP X-Content-Type-Options response header is used to protect against MIME sniffing vulnerabilities. ## How was this patch tested? Checked on my system locally. screen shot 2017-10-03 at 6 49 20 pm Author: krishna-pandey Author: Krishna Pandey Closes #19419 from krishna-pandey/SPARK-22188. 
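A hedged example of how the new settings might be supplied through `SparkConf` follows; the property names come from the diff below, while the values shown are illustrative rather than recommendations.

```scala
import org.apache.spark.SparkConf

// Sketch only: enable the response headers added by this patch.
val conf = new SparkConf()
  .set("spark.ui.xXssProtection", "1; mode=block")              // same as the default value
  .set("spark.ui.xContentTypeOptions.enabled", "true")          // sends X-Content-Type-Options: nosniff
  .set("spark.ui.strictTransportSecurity", "max-age=31536000")  // HSTS, only honored over HTTPS
```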
--- .../spark/internal/config/package.scala | 18 +++++++ .../org/apache/spark/ui/JettyUtils.scala | 9 ++++ docs/security.md | 47 +++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0c36bdcdd2904..6f0247b73070d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -452,6 +452,24 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val UI_X_XSS_PROTECTION = + ConfigBuilder("spark.ui.xXssProtection") + .doc("Value for HTTP X-XSS-Protection response header") + .stringConf + .createWithDefaultString("1; mode=block") + + private[spark] val UI_X_CONTENT_TYPE_OPTIONS = + ConfigBuilder("spark.ui.xContentTypeOptions.enabled") + .doc("Set to 'true' for setting X-Content-Type-Options HTTP response header to 'nosniff'") + .booleanConf + .createWithDefault(true) + + private[spark] val UI_STRICT_TRANSPORT_SECURITY = + ConfigBuilder("spark.ui.strictTransportSecurity") + .doc("Value for HTTP Strict Transport Security Response Header") + .stringConf + .createOptional + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") .doc("Class names of listeners to add to SparkContext during initialization.") .stringConf diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 5ee04dad6ed4d..0adeb4058b6e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{pretty, render} import org.apache.spark.{SecurityManager, SparkConf, SSLOptions} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -89,6 +90,14 @@ private[spark] object JettyUtils extends Logging { val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") response.setHeader("X-Frame-Options", xFrameOptionsValue) + response.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION)) + if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) { + response.setHeader("X-Content-Type-Options", "nosniff") + } + if (request.getScheme == "https") { + conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach( + response.setHeader("Strict-Transport-Security", _)) + } response.getWriter.print(servletParams.extractFn(result)) } else { response.setStatus(HttpServletResponse.SC_FORBIDDEN) diff --git a/docs/security.md b/docs/security.md index 1d004003f9a32..15aadf07cf873 100644 --- a/docs/security.md +++ b/docs/security.md @@ -186,7 +186,54 @@ configure those ports.
    Property Name | Default | Meaning
    spark.eventLog.logBlockUpdates.enabled | false | Whether to log events for every block update, if spark.eventLog.enabled is true. *Warning*: This will increase the size of the event log considerably.
    spark.eventLog.compress | false
    spark.yarn.kerberos.relogin.period | 1m | How often to check whether the kerberos TGT should be renewed. This should be set to a value that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). The default value should be enough for most deployments.
    spark.yarn.config.gatewayPath | (none)
    +### HTTP Security Headers
    +
    +Apache Spark can be configured to include HTTP Headers which aids in preventing Cross
    +Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP
    +Strict Transport Security.
    +
    Property Name | Default | Meaning
    spark.ui.xXssProtection | 1; mode=block | Value for HTTP X-XSS-Protection response header. You can choose appropriate value from below:
    • 0 (Disables XSS filtering)
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, the browser will sanitize the page.)
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering of the page if an attack is detected.)
    spark.ui.xContentTypeOptions.enabled | true | When value is set to "true", X-Content-Type-Options HTTP response header will be set to "nosniff". Set "false" to disable.
    spark.ui.strictTransportSecurity | None | Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate value from below and set expire-time accordingly, when Spark is SSL/TLS enabled.
    • max-age=<expire-time>
    • max-age=<expire-time>; includeSubDomains
    • max-age=<expire-time>; preload
    + See the [configuration page](configuration.html) for more details on the security configuration parameters, and org.apache.spark.SecurityManager for implementation details about security. + From 7fae7995ba05e0333d1decb7ca74ddb7c1b448d7 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Fri, 20 Oct 2017 09:40:00 +0900 Subject: [PATCH 085/712] [SPARK-22268][BUILD] Fix lint-java ## What changes were proposed in this pull request? Fix java style issues ## How was this patch tested? Run `./dev/lint-java` locally since it's not run on Jenkins Author: Andrew Ash Closes #19486 from ash211/aash/fix-lint-java. --- .../unsafe/sort/UnsafeInMemorySorter.java | 9 ++++---- .../sort/UnsafeExternalSorterSuite.java | 21 +++++++++++-------- .../sort/UnsafeInMemorySorterSuite.java | 3 ++- .../v2/reader/SupportsPushDownFilters.java | 1 - 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 869ec908be1fb..3bb87a6ed653d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -172,10 +172,11 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); - // the call to consumer.allocateArray may trigger a spill - // which in turn access this instance and eventually re-enter this method and try to free the array again. - // by setting the array to null and its length to 0 we effectively make the spill code-path a no-op. - // setting the array to null also indicates that it has already been de-allocated which prevents a double de-allocation in free(). + // the call to consumer.allocateArray may trigger a spill which in turn access this instance + // and eventually re-enter this method and try to free the array again. by setting the array + // to null and its length to 0 we effectively make the spill code-path a no-op. setting the + // array to null also indicates that it has already been de-allocated which prevents a double + // de-allocation in free(). array = null; usableCapacity = 0; pos = 0; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 6c5451d0fd2a5..d0d0334add0bf 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -516,12 +516,13 @@ public void testOOMDuringSpill() throws Exception { for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) { insertNumber(sorter, i); } - // we expect the next insert to attempt growing the pointerssArray - // first allocation is expected to fail, then a spill is triggered which attempts another allocation - // which also fails and we expect to see this OOM here. - // the original code messed with a released array within the spill code - // and ended up with a failed assertion. 
- // we also expect the location of the OOM to be org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset + // we expect the next insert to attempt growing the pointerssArray first + // allocation is expected to fail, then a spill is triggered which + // attempts another allocation which also fails and we expect to see this + // OOM here. the original code messed with a released array within the + // spill code and ended up with a failed assertion. we also expect the + // location of the OOM to be + // org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset memoryManager.markconsequentOOM(2); try { insertNumber(sorter, 1024); @@ -530,9 +531,11 @@ public void testOOMDuringSpill() throws Exception { // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) catch (OutOfMemoryError oom){ String oomStackTrace = Utils.exceptionString(oom); - assertThat("expected OutOfMemoryError in org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", - oomStackTrace, - Matchers.containsString("org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); + assertThat("expected OutOfMemoryError in " + + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", + oomStackTrace, + Matchers.containsString( + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 1a3e11efe9787..594f07dd780f9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -179,7 +179,8 @@ public int compare( } catch (OutOfMemoryError oom) { // as expected } - // [SPARK-21907] this failed on NPE at org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) + // [SPARK-21907] this failed on NPE at + // org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) sorter.free(); // simulate a 'back to back' free. sorter.free(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d6f297c013375..6b0c9d417eeae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** From b034f2565f72aa73c9f0be1e49d148bb4cf05153 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Oct 2017 20:24:51 -0700 Subject: [PATCH 086/712] [SPARK-22026][SQL] data source v2 write path ## What changes were proposed in this pull request? A working prototype for data source v2 write path. The writing framework is similar to the reading framework. i.e. `WriteSupport` -> `DataSourceV2Writer` -> `DataWriterFactory` -> `DataWriter`. Similar to the `FileCommitPotocol`, the writing API has job and task level commit/abort to support the transaction. ## How was this patch tested? 
new tests Author: Wenchen Fan Closes #19269 from cloud-fan/data-source-v2-write. --- .../spark/sql/sources/v2/WriteSupport.java | 49 ++++ .../sources/v2/writer/DataSourceV2Writer.java | 88 +++++++ .../sql/sources/v2/writer/DataWriter.java | 92 +++++++ .../sources/v2/writer/DataWriterFactory.java | 50 ++++ .../v2/writer/SupportsWriteInternalRow.java | 44 ++++ .../v2/writer/WriterCommitMessage.java | 33 +++ .../apache/spark/sql/DataFrameWriter.scala | 38 ++- .../datasources/v2/DataSourceV2Strategy.scala | 11 +- .../datasources/v2/WriteToDataSourceV2.scala | 133 ++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 69 +++++ .../sources/v2/SimpleWritableDataSource.scala | 249 ++++++++++++++++++ 11 files changed, 842 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java new file mode 100644 index 0000000000000..a8a961598bde3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability and save the data to the data source. + */ +@InterfaceStability.Evolving +public interface WriteSupport { + + /** + * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * sources can return None if there is no writing needed to be done according to the save mode. + * + * @param jobId A unique string for the writing job. 
It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceV2Writer} should + * use this job id to distinguish itself with writers of other jobs. + * @param schema the schema of the data to be written. + * @param mode the save mode which determines what to do when the data are already in this data + * source, please refer to {@link SaveMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java new file mode 100644 index 0000000000000..8d8e33633fb0d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A data source writer that is returned by + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * It can mix in various writing optimization interfaces to speed up the data saving. The actual + * writing logic is delegated to {@link DataWriter}. + * + * The writing procedure is: + * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the + * partitions of the input data(RDD). + * 2. For each partition, create the data writer, and write the data of the partition with this + * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If + * exception happens during the writing, call {@link DataWriter#abort()}. + * 3. If all writers are successfully committed, call {@link #commit(WriterCommitMessage[])}. If + * some writers are aborted, or the job failed with an unknown reason, call + * {@link #abort(WriterCommitMessage[])}. + * + * Spark won't retry failed writing jobs, users should do it manually in their Spark applications if + * they want to retry. + * + * Please refer to the document of commit/abort methods for detailed specifications. 
+ * + * Note that, this interface provides a protocol between Spark and data sources for transactional + * data writing, but the transaction here is Spark-level transaction, which may not be the + * underlying storage transaction. For example, Spark successfully writes data to a Cassandra data + * source, but Cassandra may need some more time to reach consistency at storage level. + */ +@InterfaceStability.Evolving +public interface DataSourceV2Writer { + + /** + * Creates a writer factory which will be serialized and sent to executors. + */ + DataWriterFactory createWriterFactory(); + + /** + * Commits this writing job with a list of commit messages. The commit messages are collected from + * successful data writers and are produced by {@link DataWriter#commit()}. If this method + * fails(throw exception), this writing job is considered to be failed, and + * {@link #abort(WriterCommitMessage[])} will be called. The written data should only be visible + * to data source readers if this method succeeds. + * + * Note that, one partition may have multiple committed data writers because of speculative tasks. + * Spark will pick the first successful one and get its commit message. Implementations should be + * aware of this and handle it correctly, e.g., have a mechanism to make sure only one data writer + * can commit successfully, or have a way to clean up the data of already-committed writers. + */ + void commit(WriterCommitMessage[] messages); + + /** + * Aborts this writing job because some data writers are failed to write the records and aborted, + * or the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} + * fails. If this method fails(throw exception), the underlying data source may have garbage that + * need to be cleaned manually, but these garbage should not be visible to data source readers. + * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(WriterCommitMessage[] messages); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java new file mode 100644 index 0000000000000..14261419af6f6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A data writer returned by {@link DataWriterFactory#createWriter(int, int)} and is + * responsible for writing data for an input RDD partition. + * + * One Spark task has one exclusive data writer, so there is no thread-safe concern. + * + * {@link #write(Object)} is called for each record in the input RDD partition. If one record fails + * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will + * not be processed. If all records are successfully written, {@link #commit()} is called. + * + * If this data writer succeeds(all records are successfully written and {@link #commit()} + * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an + * exception will be sent to the driver side, and Spark will retry this writing task for some times, + * each time {@link DataWriterFactory#createWriter(int, int)} gets a different `attemptNumber`, + * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * + * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task + * takes too long to finish. Different from retried tasks, which are launched one by one after the + * previous one fails, speculative tasks are running simultaneously. It's possible that one input + * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * and data sources should guarantee that these data writers don't conflict and can work together. + * Implementations can coordinate with driver during {@link #commit()} to make sure only one of + * these data writers can commit successfully. Or implementations can allow all of them to commit + * successfully, and have a way to revert committed data writers without the commit message, because + * Spark only accepts the commit message that arrives first and ignore others. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers + * that mix in {@link SupportsWriteInternalRow}. + */ +@InterfaceStability.Evolving +public interface DataWriter { + + /** + * Writes one record. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + void write(T record); + + /** + * Commits this writer after all records are written successfully, returns a commit message which + * will be send back to driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * The written data should only be visible to data source readers after + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * do the final commitment via {@link WriterCommitMessage}. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + WriterCommitMessage commit(); + + /** + * Aborts this writer if it is failed. 
Implementations should clean up the data for already + * written records. + * + * This method will only be called if there is one record failed to write, or {@link #commit()} + * failed. + * + * If this method fails(throw exception), the underlying data source may have garbage that need + * to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, but + * these garbage should not be visible to data source readers. + */ + void abort(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java new file mode 100644 index 0000000000000..f812d102bda1a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * which is responsible for creating and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface DataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task + * failed, Spark launches a new task wth the same task id but different + * attempt number. Or a task is too slow, Spark launches new tasks wth the + * same task id but different attempt number, which means there are multiple + * tasks with the same task id running at the same time. Implementations can + * use this attempt number to distinguish writers of different task attempts. 
+ */ + DataWriter createWriter(int partitionId, int attemptNumber); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java new file mode 100644 index 0000000000000..a8e95901f3b07 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. + * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get + * changed in the future Spark versions. + */ + +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsWriteInternalRow extends DataSourceV2Writer { + + @Override + default DataWriterFactory createWriterFactory() { + throw new IllegalStateException( + "createWriterFactory should not be called with SupportsWriteInternalRow."); + } + + DataWriterFactory createInternalRowWriterFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java new file mode 100644 index 0000000000000..082d6b5dc409f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side + * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * This is an empty interface, data sources should define their own message class and use it in + * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * implementations. + */ +@InterfaceStability.Evolving +public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c9e45436ed42f..8d95b24c00619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import java.util.{Locale, Properties} +import java.text.SimpleDateFormat +import java.util.{Date, Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -29,7 +30,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport} import org.apache.spark.sql.types.StructType /** @@ -231,12 +234,33 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + val cls = DataSource.lookupDataSource(source) + if (classOf[DataSourceV2].isAssignableFrom(cls)) { + cls.newInstance() match { + case ds: WriteSupport => + val options = new DataSourceV2Options(extraOptions.asJava) + // Using a timestamp and a random UUID to distinguish different writing jobs. This is good + // enough as there won't be tons of writing jobs created at the same second. + val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + .format(new Date()) + "-" + UUID.randomUUID() + val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options) + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get(), df.logicalPlan) + } + } + + case _ => throw new AnalysisException(s"$cls does not support data writing.") + } + } else { + // Code path for data source v1. 
+ runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index f2cda002245e8..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,20 +18,17 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { - // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case DataSourceV2Relation(output, reader) => DataSourceV2ScanExec(output, reader) :: Nil + case WriteToDataSourceV2(writer, query) => + WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala new file mode 100644 index 0000000000000..92c1e1f4a3383 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/** + * The logical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} + +/** + * The physical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writeTask = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new RowToInternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + val messages = new Array[WriterCommitMessage](rdd.partitions.length) + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${messages.length} partitions.") + + try { + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter), + rdd.partitions.indices, + (index, message: WriterCommitMessage) => messages(index) = message + ) + + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } catch { + case cause: Throwable => + logError(s"Data source writer $writer is aborting.") + try { + writer.abort(messages) + } catch { + case t: Throwable => + logError(s"Data source writer $writer failed to abort.") + cause.addSuppressed(t) + throw new SparkException("Writing job failed.", cause) + } + logError(s"Data source writer $writer aborted.") + throw new SparkException("Writing job aborted.", cause) + } + + sparkContext.emptyRDD + } +} + +object DataWritingSparkTask extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): WriterCommitMessage = { + val dataWriter = writeTask.createWriter(context.partitionId(), context.attemptNumber()) + + // write the data and commit this writer. 
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + iter.foreach(dataWriter.write) + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + msg + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Writer for partition ${context.partitionId()} is aborting.") + dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } +} + +class RowToInternalRowDataWriterFactory( + rowWriterFactory: DataWriterFactory[Row], + schema: StructType) extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new RowToInternalRowDataWriter( + rowWriterFactory.createWriter(partitionId, attemptNumber), + RowEncoder.apply(schema).resolveAndBind()) + } +} + +class RowToInternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) + extends DataWriter[InternalRow] { + + override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) + + override def commit(): WriterCommitMessage = rowWriter.commit() + + override def abort(): Unit = rowWriter.abort() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index f238e565dc2fc..092702a1d5173 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,6 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -80,6 +81,74 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple writable data source") { + // TODO: java implementation. 
+ Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + // test with different save modes + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("append").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).union(spark.range(10)).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("ignore").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + val e = intercept[Exception] { + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("error").save() + } + assert(e.getMessage.contains("data already exists")) + + // test transaction + val failingUdf = org.apache.spark.sql.functions.udf { + var count = 0 + (id: Long) => { + if (count > 5) { + throw new RuntimeException("testing error") + } + count += 1 + id + } + } + // this input data will fail to read middle way. + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val e2 = intercept[SparkException] { + input.write.format(cls.getName).option("path", path).mode("overwrite").save() + } + assert(e2.getMessage.contains("Writing job aborted")) + // make sure we don't have partial data. + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + // test internal row writer + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).option("internal", "true").mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala new file mode 100644 index 0000000000000..6fb60f4d848d7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.text.SimpleDateFormat +import java.util.{Collections, Date, List => JList, Locale, Optional, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * A HDFS based transactional writable data source. + * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/jobId/` to `target`. + */ +class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { + + private val schema = new StructType().add("i", "long").add("j", "long") + + class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + override def readSchema(): StructType = schema + + override def createReadTasks(): JList[ReadTask[Row]] = { + val dataPath = new Path(path) + val fs = dataPath.getFileSystem(conf) + if (fs.exists(dataPath)) { + fs.listStatus(dataPath).filterNot { status => + val name = status.getPath.getName + name.startsWith("_") || name.startsWith(".") + }.map { f => + val serializableConf = new SerializableConfiguration(conf) + new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + }.toList.asJava + } else { + Collections.emptyList() + } + } + } + + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + override def createWriterFactory(): DataWriterFactory[Row] = { + new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + val finalPath = new Path(path) + val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + try { + for (file <- fs.listStatus(jobPath).map(_.getPath)) { + val dest = new Path(finalPath, file.getName) + if(!fs.rename(file, dest)) { + throw new IOException(s"failed to rename($file, $dest)") + } + } + } finally { + fs.delete(jobPath, true) + } + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + fs.delete(jobPath, true) + } + } + + class InternalRowWriter(jobId: String, path: String, conf: Configuration) + extends Writer(jobId, path, conf) with SupportsWriteInternalRow { + + override def createWriterFactory(): DataWriterFactory[Row] = { + throw new IllegalArgumentException("not expected!") + } + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + val path = new Path(options.get("path").get()) + val conf = SparkContext.getActive.get.hadoopConfiguration + new Reader(path.toUri.toString, conf) + } + + override def createWriter( + jobId: String, + schema: StructType, + mode: SaveMode, + options: DataSourceV2Options): 
Optional[DataSourceV2Writer] = { + assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) + assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) + + val path = new Path(options.get("path").get()) + val internal = options.get("internal").isPresent + val conf = SparkContext.getActive.get.hadoopConfiguration + val fs = path.getFileSystem(conf) + + if (mode == SaveMode.ErrorIfExists) { + if (fs.exists(path)) { + throw new RuntimeException("data already exists.") + } + } + if (mode == SaveMode.Ignore) { + if (fs.exists(path)) { + return Optional.empty() + } + } + if (mode == SaveMode.Overwrite) { + fs.delete(path, true) + } + + Optional.of(createWriter(jobId, path, conf, internal)) + } + + private def createWriter( + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + val pathStr = path.toUri.toString + if (internal) { + new InternalRowWriter(jobId, pathStr, conf) + } else { + new Writer(jobId, pathStr, conf) + } + } +} + +class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) + extends ReadTask[Row] with DataReader[Row] { + + @transient private var lines: Iterator[String] = _ + @transient private var currentLine: String = _ + @transient private var inputStream: FSDataInputStream = _ + + override def createReader(): DataReader[Row] = { + val filePath = new Path(path) + val fs = filePath.getFileSystem(conf.value) + inputStream = fs.open(filePath) + lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + this + } + + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } +} + +class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[Row] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new SimpleCSVDataWriter(fs, filePath) + } +} + +class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { + + private val out = fs.create(file) + + override def write(record: Row): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} + +class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new InternalRowCSVDataWriter(fs, filePath) + } +} + +class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { + + private val out = fs.create(file) + + override def write(record: InternalRow): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): 
Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} From b84f61cd79a365edd4cc893a1de416c628d9906b Mon Sep 17 00:00:00 2001 From: Eric Perry Date: Thu, 19 Oct 2017 23:57:41 -0700 Subject: [PATCH 087/712] [SQL] Mark strategies with override for clarity. ## What changes were proposed in this pull request? This is a very trivial PR, simply marking `strategies` in `SparkPlanner` with the `override` keyword for clarity since it is overriding `strategies` in `QueryPlanner` two levels up in the class hierarchy. I was reading through the code to learn a bit and got stuck on this fact for a little while, so I figured this may be helpful so that another developer new to the project doesn't get stuck where I was. I did not make a JIRA ticket for this because it is so trivial, but I'm happy to do so to adhere to the contribution guidelines if required. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Perry Closes #19537 from ericjperry/override-strategies. --- .../scala/org/apache/spark/sql/execution/SparkPlanner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b143d44eae17b..74048871f8d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -33,7 +33,7 @@ class SparkPlanner( def numPartitions: Int = conf.numShufflePartitions - def strategies: Seq[Strategy] = + override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( DataSourceV2Strategy :: From 673876b7eadc6f382afc26fc654b0e7916c9ac5c Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 20 Oct 2017 08:28:05 +0100 Subject: [PATCH 088/712] [SPARK-22309][ML] Remove unused param in `LDAModel.getTopicDistributionMethod` ## What changes were proposed in this pull request? Remove unused param in `LDAModel.getTopicDistributionMethod` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #19530 from zhengruifeng/lda_bc. --- mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 2 +- .../main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 3da29b1c816b1..4bab670cc159f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -458,7 +458,7 @@ abstract class LDAModel private[ml] ( if ($(topicDistributionCol).nonEmpty) { // TODO: Make the transformer natively in ml framework to avoid extra conversion. 
- val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) + val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 4ab420058f33d..b8a6e94248421 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -371,7 +371,7 @@ class LocalLDAModel private[spark] ( /** * Get a method usable as a UDF for `topicDistributions()` */ - private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + private[spark] def getTopicDistributionMethod: Vector => Vector = { val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape From e2fea8cd6058a807ff4841b496ea345ff0553044 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 20 Oct 2017 09:43:46 +0100 Subject: [PATCH 089/712] [CORE][DOC] Add event log conf. ## What changes were proposed in this pull request? Event Log Server has a total of five configuration parameters, and now the description of the other two configuration parameters on the doc, user-friendly access and use. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19242 from guoxiaolongzte/addEventLogConf. --- docs/configuration.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 7b9e16a382449..d3c358bb74173 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -748,6 +748,20 @@ Apart from these, the following properties are also available, and may be useful finished. + + spark.eventLog.overwrite + false + + Whether to overwrite any existing files. + + + + spark.eventLog.buffer.kb + 100k + + Buffer size in KB to use when writing to output streams. + + spark.ui.enabled true From 16c9cc68c5a70fd50e214f6deba591f0a9ae5cca Mon Sep 17 00:00:00 2001 From: CenYuhai Date: Fri, 20 Oct 2017 09:27:39 -0700 Subject: [PATCH 090/712] [SPARK-21055][SQL] replace grouping__id with grouping_id() ## What changes were proposed in this pull request? spark does not support grouping__id, it has grouping_id() instead. But it is not convenient for hive user to change to spark-sql so this pr is to replace grouping__id with grouping_id() hive user need not to alter their scripts ## How was this patch tested? test with SQLQuerySuite.scala Author: CenYuhai Closes #18270 from cenyuhai/SPARK-21055. 
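For illustration (this sketch is not part of the patch; it assumes an active `SparkSession` named `spark`, and the `courseSales` data is made up), after this change the Hive-style virtual column `grouping__id` is resolved to `grouping_id()` during analysis instead of failing with the old "grouping__id is deprecated" error:

```python
# Hypothetical data; any table with two grouping columns behaves the same way.
course_sales = spark.createDataFrame(
    [("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2013, 48000)],
    ["course", "year", "earnings"])
course_sales.createOrReplaceTempView("courseSales")

# Hive-style query, unchanged: grouping__id is rewritten to grouping_id() by the analyzer.
spark.sql("""
    SELECT course, year, grouping__id
    FROM courseSales
    GROUP BY CUBE(course, year)
    ORDER BY grouping__id, course, year
""").show()

# Equivalent native Spark SQL form.
spark.sql("""
    SELECT course, year, grouping_id()
    FROM courseSales
    GROUP BY CUBE(course, year)
    ORDER BY grouping_id(), course, year
""").show()
```

Both queries should return the same grouping id per row, so existing Hive scripts that reference `grouping__id` keep working without modification.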
--- .../sql/catalyst/analysis/Analyzer.scala | 15 +-- .../sql-tests/inputs/group-analytics.sql | 6 +- .../sql-tests/results/group-analytics.sql.out | 43 ++++--- .../sql/hive/execution/SQLQuerySuite.scala | 110 ++++++++++++++++++ 4 files changed, 148 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8edf575db7969..d6a962a14dc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -293,12 +293,6 @@ class Analyzer( Seq(Seq.empty) } - private def hasGroupingAttribute(expr: Expression): Boolean = { - expr.collectFirst { - case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u - }.isDefined - } - private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.collectFirst { case g: Grouping => g @@ -452,9 +446,6 @@ class Analyzer( // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. - case p if p.expressions.exists(hasGroupingAttribute) => - failAnalysis( - s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") // Ensure group by expressions and aggregate expressions have been resolved. case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) @@ -1174,6 +1165,10 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. 
+ case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => + withPosition(u) { + Alias(GroupingID(Nil), VirtualColumn.hiveGroupingIdName)() + } case u @ UnresolvedGenerator(name, children) => withPosition(u) { catalog.lookupFunction(name, children) match { diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index 8aff4cb524199..9721f8c60ebce 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -38,11 +38,11 @@ SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) GROUP BY CUBE(course, year); SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year; SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year; -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year); +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- GROUPING/GROUPING_ID in having clause SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0; +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0; SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0; @@ -54,7 +54,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index ce7a16a4d0c81..3439a05727f95 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -223,22 +223,29 @@ grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 16 -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 16 schema -struct<> +struct -- !query 16 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 0 +Java 2013 0 +dotNET 2012 0 +dotNET 2013 0 +Java NULL 1 +dotNET NULL 1 +NULL 2012 2 +NULL 2013 2 +NULL NULL 3 -- !query 17 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year -- !query 17 schema struct -- !query 17 output -Java NULL NULL NULL +Java NULL dotNET NULL @@ -263,10 +270,13 @@ 
grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 20 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 -- !query 20 schema -struct<> +struct -- !query 20 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java NULL +NULL 2012 +NULL 2013 +NULL NULL +dotNET NULL -- !query 21 @@ -322,12 +332,19 @@ grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 25 -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 25 schema -struct<> +struct -- !query 25 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 +Java 2013 +dotNET 2012 +dotNET 2013 +Java NULL +dotNET NULL +NULL 2012 +NULL 2013 +NULL NULL -- !query 26 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 60935c3e85c43..2476a440ad82c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1388,6 +1388,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + test("SPARK-8976 Wrong Result for Rollup #2") { checkAnswer(sql( """ @@ -1409,6 +1422,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for Rollup #3") { checkAnswer(sql( """ @@ -1430,6 +1464,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #3") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for CUBE #1") { checkAnswer(sql( "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), @@ -1443,6 +1498,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i 
=> Row(i._1, i._2, i._3))) } + test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + test("SPARK-8976 Wrong Result for CUBE #2") { checkAnswer(sql( """ @@ -1464,6 +1532,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-8976 Wrong Result for GroupingSet") { checkAnswer(sql( """ @@ -1485,6 +1574,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("SPARK-21055 replace grouping__id: Wrong Result for GroupingSet") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + ignore("SPARK-10562: partition by column with mixed case name") { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") From 568763bafb7acfcf5921d6492034d1f6f87875e2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Oct 2017 12:32:45 -0700 Subject: [PATCH 091/712] [INFRA] Close stale PRs. Closes #19541 Closes #19542 From b8624b06e5d531ebc14acb05da286f96f4bc9515 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 20 Oct 2017 12:44:30 -0700 Subject: [PATCH 092/712] [SPARK-20396][SQL][PYSPARK][FOLLOW-UP] groupby().apply() with pandas udf ## What changes were proposed in this pull request? This is a follow-up of #18732. This pr modifies `GroupedData.apply()` method to convert pandas udf to grouped udf implicitly. ## How was this patch tested? Exisiting tests. Author: Takuya UESHIN Closes #19517 from ueshin/issues/SPARK-20396/fup2. 
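To make the user-facing shape of this concrete, here is a rough PySpark sketch (not taken from the patch; the data and column names are made up, and an active `SparkSession` named `spark` plus the pandas/pyarrow dependencies are assumed). A one-argument `pandas_udf` whose return type is a `StructType` is passed straight to `groupby().apply()`, and the wrapping in `group.py` converts it into a grouped vectorized UDF (`PANDAS_GROUPED_UDF`) internally:

```python
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StructType, StructField, LongType, DoubleType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ["id", "v"])

@pandas_udf(StructType([
    StructField("id", LongType()),
    StructField("v", DoubleType())]))
def subtract_mean(pdf):
    # pdf is a pandas.DataFrame holding all rows of one group.
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(subtract_mean).show()
```

Anything else (a row-at-a-time `udf`, a zero- or two-argument `pandas_udf`, or a non-struct return type) is rejected by the checks added to `GroupedData.apply()` in this change.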
--- .../spark/api/python/PythonRunner.scala | 1 + python/pyspark/serializers.py | 1 + python/pyspark/sql/functions.py | 33 ++++++++++------ python/pyspark/sql/group.py | 14 ++++--- python/pyspark/sql/tests.py | 37 ++++++++++++++++++ python/pyspark/worker.py | 39 ++++++++----------- .../logical/pythonLogicalOperators.scala | 9 +++-- .../spark/sql/RelationalGroupedDataset.scala | 7 ++-- .../execution/python/ExtractPythonUDFs.scala | 6 ++- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../sql/execution/python/PythonUDF.scala | 2 +- .../python/UserDefinedPythonFunction.scala | 13 ++++++- .../python/BatchEvalPythonExecSuite.scala | 2 +- 13 files changed, 114 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 3688a149443c1..d417303bb147d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -36,6 +36,7 @@ private[spark] object PythonEvalType { val NON_UDF = 0 val SQL_BATCHED_UDF = 1 val SQL_PANDAS_UDF = 2 + val SQL_PANDAS_GROUPED_UDF = 3 } /** diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ad18bd0c81eaa..a0adeed994456 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -86,6 +86,7 @@ class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 + SQL_PANDAS_GROUPED_UDF = 3 class Serializer(object): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc12c3b7a162..9bc374b93a433 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class PythonUdfType(object): + # row-at-a-time UDFs + NORMAL_UDF = 0 + # scalar vectorized UDFs + PANDAS_UDF = 1 + # grouped vectorized UDFs + PANDAS_GROUPED_UDF = 2 + + class UserDefinedFunction(object): """ User defined function in Python .. 
versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None, vectorized=False): + def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2058,7 +2067,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self.vectorized = vectorized + self.pythonUdfType = pythonUdfType @property def returnType(self): @@ -2090,7 +2099,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.vectorized) + self._name, wrapped_func, jdt, self.pythonUdfType) return judf def __call__(self, *cols): @@ -2121,15 +2130,15 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType - wrapper.vectorized = self.vectorized + wrapper.pythonUdfType = self.pythonUdfType return wrapper -def _create_udf(f, returnType, vectorized): +def _create_udf(f, returnType, pythonUdfType): - def _udf(f, returnType=StringType(), vectorized=vectorized): - if vectorized: + def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): + if pythonUdfType == PythonUdfType.PANDAS_UDF: import inspect argspec = inspect.getargspec(f) if len(argspec.args) == 0 and argspec.varargs is None: @@ -2137,7 +2146,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): "0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) return udf_obj._wrapped() # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf @@ -2145,9 +2154,9 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) else: - return _udf(f=f, returnType=returnType, vectorized=vectorized) + return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) @since(1.3) @@ -2181,7 +2190,7 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, vectorized=False) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) @since(2.3) @@ -2252,7 +2261,7 @@ def pandas_udf(f=None, returnType=StringType()): .. note:: The user-defined function must be deterministic. 
""" - return _create_udf(f, returnType=returnType, vectorized=True) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 817d0bc83bb77..e11388d604312 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,6 +19,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import PythonUdfType, UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -235,11 +236,13 @@ def apply(self, udf): .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - from pyspark.sql.functions import pandas_udf + import inspect # Columns are special because hasattr always return True - if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: - raise ValueError("The argument to apply must be a pandas_udf") + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ + or len(inspect.getargspec(udf.func).args) != 1: + raise ValueError("The argument to apply must be a 1-arg pandas_udf") if not isinstance(udf.returnType, StructType): raise ValueError("The returnType of the pandas_udf must be a StructType") @@ -268,8 +271,9 @@ def wrapped(*cols): return [(result[result.columns[i]], arrow_type) for i, arrow_type in enumerate(arrow_return_types)] - wrapped_udf_obj = pandas_udf(wrapped, returnType) - udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + udf_obj = UserDefinedFunction( + wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) + udf_column = udf_obj(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bac2ef84ae7a7..685eebcafefba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3383,6 +3383,15 @@ def test_vectorized_udf_varargs(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, DateType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.select(f(col('dt'))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3492,6 +3501,18 @@ def normalize(pdf): expected = expected.assign(norm=expected.norm.astype('float64')) self.assertFramesEqual(expected, result) + def test_datatype_string(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + "id long, v int, v1 double, v2 long") + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf df = self.data @@ -3517,9 +3538,25 @@ def 
test_wrong_args(self): df.groupby('id').apply(sum(df.v)) with self.assertRaisesRegexp(ValueError, 'pandas_udf'): df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'returnType'): df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + def test_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType( + [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, df.schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.groupby('id').apply(f).collect() + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index eb6d48688dc0a..5e100e0a9a95d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type, StructType +from pyspark.sql.types import to_arrow_type from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,28 +74,19 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - # If the return_type is a StructType, it indicates this is a groupby apply udf, - # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. - # We can distinguish these two by return type because in groupby apply, we always specify - # returnType as a StructType, and in vectorized column udf, StructType is not supported. 
- # - # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs - if isinstance(return_type, StructType): - return lambda *a: f(*a) - else: - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type(return_type) - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " - "Pandas.Series, but is {}".format(type(result))) - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): @@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type): # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: return arg_offsets, wrap_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + # a groupby apply udf has already been wrapped under apply() + return arg_offsets, row_func else: return arg_offsets, wrap_udf(row_func, return_type) @@ -133,7 +127,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 8abab24bc9b44..254687ec00880 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre * This is used by DataFrame.groupby().apply(). */ case class FlatMapGroupsInPandas( - groupingAttributes: Seq[Attribute], - functionExpr: Expression, - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** * This is needed because output attributes are considered `references` when * passed through the constructor. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 33ec3a27110a8..6b45790d5ff6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF +import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * Applies a vectorized python user-defined function to each group of data. + * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results * for all groups are combined into a new [[DataFrame]]. @@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, + "Must pass a grouped vectorized python udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the vectorized python udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e3f952e221d53..d6825369f7378 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { + if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { + throw new IllegalArgumentException("Can not use grouped vectorized UDFs") + } + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.vectorized) match { + val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index b996b5bb38ba5..5ed88ada428cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, 
bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 84a6d9e5be59c..9c07c7638de57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - vectorized: Boolean) + pythonUdfType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index a30a80acf5c23..b2fe6c300846a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,6 +22,15 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType +private[spark] object PythonUdfType { + // row-at-a-time UDFs + val NORMAL_UDF = 0 + // scalar vectorized UDFs + val PANDAS_UDF = 1 + // grouped vectorized UDFs + val PANDAS_GROUPED_UDF = 2 +} + /** * A user-defined Python function. This is used by the Python API. */ @@ -29,10 +38,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - vectorized: Boolean) { + pythonUdfType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, vectorized) + PythonUDF(name, func, dataType, e, pythonUdfType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 153e6e1f88c70..95b21fc9f16ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - vectorized = false) + pythonUdfType = PythonUdfType.NORMAL_UDF) From d9f286d261c6ee9e8dcb46e78d4666318ea25af2 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Fri, 20 Oct 2017 20:58:55 -0700 Subject: [PATCH 093/712] [SPARK-22326][SQL] Remove unnecessary hashCode and equals methods ## What changes were proposed in this pull request? Plan equality should be computed by `canonicalized`, so we can remove unnecessary `hashCode` and `equals` methods. ## How was this patch tested? Existing tests. Author: Zhenhua Wang Closes #19539 from wzhfy/remove_equals. 
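As a loose illustration of why the hand-written methods are redundant (a toy Python sketch, not Catalyst code; the class and field names are invented): once plan comparison goes through a canonicalized copy of the node, the structural equality generated for the case class (or dataclass, below) is all that is needed.

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class Relation:
    """Stand-in for a leaf plan node such as LogicalRelation or HiveTableRelation."""
    table: str
    output: Tuple[str, ...]

    def canonicalized(self) -> "Relation":
        # The canonical form normalizes cosmetic details (here, attribute names),
        # roughly as QueryPlan.canonicalized normalizes expression ids.
        return replace(self, output=tuple("col%d" % i for i in range(len(self.output))))

def same_result(a: Relation, b: Relation) -> bool:
    # Comparison happens on the canonical forms, not via custom equals/hashCode.
    return a.canonicalized() == b.canonicalized()

# Nodes that differ only in attribute naming compare equal after canonicalization,
# so the node class itself needs no hand-written equality.
assert same_result(Relation("t", ("a", "b")), Relation("t", ("x", "y")))
```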
--- .../apache/spark/sql/catalyst/catalog/interface.scala | 11 ----------- .../sql/execution/datasources/LogicalRelation.scala | 11 ----------- 2 files changed, 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 975b084aa6188..1dbae4d37d8f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -22,8 +22,6 @@ import java.util.Date import scala.collection.mutable -import com.google.common.base.Objects - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -440,15 +438,6 @@ case class HiveTableRelation( def isPartitioned: Boolean = partitionCols.nonEmpty - override def equals(relation: Any): Boolean = relation match { - case other: HiveTableRelation => tableMeta == other.tableMeta && output == other.output - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode(tableMeta.identifier, output) - } - override lazy val canonicalized: HiveTableRelation = copy( tableMeta = tableMeta.copy( storage = CatalogStorageFormat.empty, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 17a61074d3b5c..3e98cb28453a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -34,17 +34,6 @@ case class LogicalRelation( override val isStreaming: Boolean) extends LeafNode with MultiInstanceRelation { - // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation, _, _, isStreaming) => - relation == otherRelation && output == l.output && isStreaming == l.isStreaming - case _ => false - } - - override def hashCode: Int = { - com.google.common.base.Objects.hashCode(relation, output) - } - // Only care about relation when canonicalizing. override lazy val canonicalized: LogicalPlan = copy( output = output.map(QueryPlan.normalizeExprId(_, output)), From d8cada8d1d3fce979a4bc1f9879593206722a3b9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Oct 2017 10:05:45 -0700 Subject: [PATCH 094/712] [SPARK-20331][SQL][FOLLOW-UP] Add a SQLConf for enhanced Hive partition pruning predicate pushdown ## What changes were proposed in this pull request? This is a follow-up PR of https://github.com/apache/spark/pull/17633. This PR is to add a conf `spark.sql.hive.advancedPartitionPredicatePushdown.enabled`, which can be used to turn the enhancement off. ## How was this patch tested? Add a test case Author: gatorsmile Closes #19547 from gatorsmile/Spark20331FollowUp. 
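For reference, a small usage sketch of the new flag from spark-shell. The conf key is the one added in this patch; the table and predicate are made up and assume a partitioned Hive table:

```
// Sketch: temporarily fall back to the basic (pre-SPARK-20331) filter conversion.
spark.conf.set("spark.sql.hive.advancedPartitionPredicatePushdown.enabled", "false")
spark.sql("SELECT * FROM logs WHERE dt = '2017-10-20' OR dt = '2017-10-21'").count()

// Re-enable the enhanced pushdown (the default).
spark.conf.set("spark.sql.hive.advancedPartitionPredicatePushdown.enabled", "true")
```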
--- .../apache/spark/sql/internal/SQLConf.scala | 10 +++++++ .../spark/sql/hive/client/HiveShim.scala | 29 ++++++++++++++++++ .../spark/sql/hive/client/FiltersSuite.scala | 30 +++++++++++++++---- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 618d4a0d6148a..4cfe53b2c115b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -173,6 +173,13 @@ object SQLConf { .intConf .createWithDefault(4) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = + buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") + .internal() + .doc("When true, advanced partition predicate pushdown into Hive metastore is enabled.") + .booleanConf + .createWithDefault(true) + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs") .doc("If the table statistics are not available from table metadata enable fall back to hdfs." + @@ -1092,6 +1099,9 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def advancedPartitionPredicatePushdownEnabled: Boolean = + getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index cde20da186acd..5c1ff2b76fdaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -585,6 +585,35 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { * Unsupported predicates are skipped. */ def convertFilters(table: Table, filters: Seq[Expression]): String = { + if (SQLConf.get.advancedPartitionPredicatePushdownEnabled) { + convertComplexFilters(table, filters) + } else { + convertBasicFilters(table, filters) + } + } + + private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + lazy val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + filters.collect { + case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + s"${a.name} ${op.symbol} $v" + case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + s"$v ${op.symbol} ${a.name}" + case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + if !varcharKeys.contains(a.name) => + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" + case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + if !varcharKeys.contains(a.name) => + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" + }.mkString(" and ") + } + + private def convertComplexFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
lazy val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 031c1a5ec0ec3..19765695fbcb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -26,13 +26,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * A set of tests for the filter conversion logic used when pushing partition pruning into the * metastore */ -class FiltersSuite extends SparkFunSuite with Logging { +class FiltersSuite extends SparkFunSuite with Logging with PlanTest { private val shim = new Shim_v0_13 private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") @@ -72,10 +74,28 @@ class FiltersSuite extends SparkFunSuite with Logging { private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { - val converted = shim.convertFilters(testTable, filters) - if (converted != result) { - fail( - s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { + val converted = shim.convertFilters(testTable, filters) + if (converted != result) { + fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'") + } + } + } + } + + test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) { + val filters = + (Literal(1) === a("intcol", IntegerType) || + Literal(2) === a("intcol", IntegerType)) :: Nil + val converted = shim.convertFilters(testTable, filters) + if (enabled) { + assert(converted == "(1 = intcol or 2 = intcol)") + } else { + assert(converted.isEmpty) + } } } } From a763607e4fc24f4dc0f455b67a63acba5be1c80a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Oct 2017 10:07:31 -0700 Subject: [PATCH 095/712] [SPARK-21055][SQL][FOLLOW-UP] replace grouping__id with grouping_id() ## What changes were proposed in this pull request? Simplifies the test cases that were added in the PR https://github.com/apache/spark/pull/18270. ## How was this patch tested? N/A Author: gatorsmile Closes #19546 from gatorsmile/backportSPARK-21055. 
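The mechanical pattern applied throughout the diff below is simply to run each existing test body once per grouping-id spelling, roughly (a sketch; `checkAnswer`/`sql` are the suite helpers, `expectedRows` stands in for the expected result):

```
// Sketch of the de-duplication pattern: one assertion body, two spellings.
Seq("grouping_id()", "grouping__id").foreach { gid =>
  checkAnswer(
    sql(s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key % 5 WITH ROLLUP"),
    expectedRows)
}
```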
--- .../sql/hive/execution/SQLQuerySuite.scala | 306 ++++++------------ 1 file changed, 104 insertions(+), 202 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2476a440ad82c..1cf1c5cd5a472 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1376,223 +1376,125 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-8976 Wrong Result for Rollup #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH ROLLUP"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for Rollup #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM src GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for Rollup #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM src GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for Rollup #3") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result 
for Rollup #3") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for CUBE #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping__id FROM src GROUP BY key%5 WITH CUBE"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for CUBE #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for CUBE #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for 
GroupingSet") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) - } - - test("SPARK-21055 replace grouping__id: Wrong Result for GroupingSet") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping__id AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } ignore("SPARK-10562: partition by column with mixed case name") { From ff8de99a1c7b4a291e661cd0ad12748f4321e43d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 22 Oct 2017 02:22:35 +0900 Subject: [PATCH 096/712] [SPARK-22302][INFRA] Remove manual backports for subprocess and print explicit message for < Python 2.7 ## What changes were proposed in this pull request? Seems there was a mistake - missing import for `subprocess.call`, while refactoring this script a long ago, which should be used for backports of some missing functions in `subprocess`, specifically in < Python 2.7. Reproduction is: ``` cd dev && python2.6 ``` ``` >>> from sparktestsupport import shellutils >>> shellutils.subprocess_check_call("ls") Traceback (most recent call last): File "", line 1, in File "sparktestsupport/shellutils.py", line 46, in subprocess_check_call retcode = call(*popenargs, **kwargs) NameError: global name 'call' is not defined ``` For Jenkins logs, please see https://amplab.cs.berkeley.edu/jenkins/job/NewSparkPullRequestBuilder/3950/console Since we dropped the Python 2.6.x support, looks better we remove those workarounds and print out explicit error messages in order to reduce the efforts to find out the root causes for such cases, for example, `https://github.com/apache/spark/pull/19513#issuecomment-337406734`. ## How was this patch tested? Manually tested: ``` ./dev/run-tests ``` ``` Python versions prior to 2.7 are not supported. ``` ``` ./dev/run-tests-jenkins ``` ``` Python versions prior to 2.7 are not supported. ``` Author: hyukjinkwon Closes #19524 from HyukjinKwon/SPARK-22302. 
--- dev/run-tests | 6 ++++++ dev/run-tests-jenkins | 7 ++++++- dev/sparktestsupport/shellutils.py | 31 ++---------------------------- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 257d1e8d50bb4..9cf93d000d0ea 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,10 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" +PYTHON_VERSION_CHECK=$(python -c 'import sys; print(sys.version_info < (2, 7, 0))') +if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then + echo "Python versions prior to 2.7 are not supported." + exit -1 +fi + exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f41f1ac79e381..03fd6ff0fba40 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -25,5 +25,10 @@ FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" -export PATH=/home/anaconda/bin:$PATH +PYTHON_VERSION_CHECK=$(python -c 'import sys; print(sys.version_info < (2, 7, 0))') +if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then + echo "Python versions prior to 2.7 are not supported." + exit -1 +fi + exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index 05af87189b18d..c7644da88f770 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -21,35 +21,8 @@ import subprocess import sys - -if sys.version_info >= (2, 7): - subprocess_check_output = subprocess.check_output - subprocess_check_call = subprocess.check_call -else: - # SPARK-8763 - # backported from subprocess module in Python 2.7 - def subprocess_check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - # backported from subprocess module in Python 2.7 - def subprocess_check_call(*popenargs, **kwargs): - retcode = call(*popenargs, **kwargs) - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise CalledProcessError(retcode, cmd) - return 0 +subprocess_check_output = subprocess.check_output +subprocess_check_call = subprocess.check_call def exit_from_command_with_retcode(cmd, retcode): From ca2a780e7c4c4df2488ef933241c6e65264f8d3c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 21 Oct 2017 18:01:45 -0700 Subject: [PATCH 097/712] [SPARK-21929][SQL] Support `ALTER TABLE table_name ADD COLUMNS(..)` for ORC data source ## What changes were proposed in this pull request? When [SPARK-19261](https://issues.apache.org/jira/browse/SPARK-19261) implements `ALTER TABLE ADD COLUMNS`, ORC data source is omitted due to SPARK-14387, SPARK-16628, and SPARK-18355. Now, those issues are fixed and Spark 2.3 is [using Spark schema to read ORC table instead of ORC file schema](https://github.com/apache/spark/commit/e6e36004afc3f9fc8abea98542248e9de11b4435). This PR enables `ALTER TABLE ADD COLUMNS` for ORC data source. ## How was this patch tested? Pass the updated and added test cases. Author: Dongjoon Hyun Closes #19545 from dongjoon-hyun/SPARK-21929. 
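For illustration, the DDL flow this unlocks for ORC, mirroring the new test helpers below; the table name is arbitrary and this is a sketch rather than part of the patch:

```
// Sketch: ADD COLUMNS on an ORC data source table, rejected before this change.
spark.sql("CREATE TABLE t_orc (c1 INT) USING orc")
spark.sql("INSERT INTO t_orc VALUES (1)")
spark.sql("ALTER TABLE t_orc ADD COLUMNS (c2 INT)")
spark.table("t_orc").show()   // the pre-existing row reads the new column as null
```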
--- .../spark/sql/execution/command/tables.scala | 3 +- .../sql/execution/command/DDLSuite.scala | 90 ++++++++++--------- .../sql/hive/execution/HiveDDLSuite.scala | 8 ++ 3 files changed, 58 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8d95ca6921cf8..38f91639c0422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -235,11 +235,10 @@ case class AlterTableAddColumnsCommand( DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { // For datasource table, this command can only support the following File format. // TextFileFormat only default to one column "value" - // OrcFileFormat can not handle difference between user-specified schema and - // inferred schema yet. TODO, once this issue is resolved , we can add Orc back. // Hive type is already considered as hive serde table, so the logic will not // come in here. case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") => case s => throw new AnalysisException( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4ed2cecc5faff..21a2c62929146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2202,56 +2202,64 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + protected def testAddColumn(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + + protected def testAddColumnPartitioned(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int) USING $provider") - sql("INSERT INTO t1 VALUES (1)") - sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 is null"), - Seq(Row(1, null)) - ) - - sql("INSERT INTO t1 VALUES (3, 2)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 2"), - Seq(Row(3, 2)) - ) - } + testAddColumn(provider) } } 
supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - partitioned - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") - sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") - sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null, 2)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 is null"), - Seq(Row(1, null, 2)) - ) - sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 = 3"), - Seq(Row(2, 3, 1)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 1"), - Seq(Row(2, 3, 1)) - ) - } + testAddColumnPartitioned(provider) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 02e26bbe876a0..d3465a641a1a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -166,6 +166,14 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA test("drop table") { testDropTable(isDatasourceTable = false) } + + test("alter datasource table add columns - orc") { + testAddColumn("orc") + } + + test("alter datasource table add columns - partitioned - orc") { + testAddColumnPartitioned("orc") + } } class HiveDDLSuite From 57accf6e3965ff69adc4408623916c5003918235 Mon Sep 17 00:00:00 2001 From: Steven Rand Date: Mon, 23 Oct 2017 09:43:45 +0800 Subject: [PATCH 098/712] [SPARK-22319][CORE] call loginUserFromKeytab before accessing hdfs In `SparkSubmit`, call `loginUserFromKeytab` before attempting to make RPC calls to the NameNode. I manually tested this patch by: 1. Confirming that my Spark application failed to launch with the error reported in https://issues.apache.org/jira/browse/SPARK-22319. 2. Applying this patch and confirming that the app no longer fails to launch, even when I have not manually run `kinit` on the host. Presumably we also want integration tests for secure clusters so that we catch this sort of thing. I'm happy to take a shot at this if it's feasible and someone can point me in the right direction. Author: Steven Rand Closes #19540 from sjrand/SPARK-22319. Change-Id: Ic306bfe7181107fbcf92f61d75856afcb5b6f761 --- .../org/apache/spark/deploy/SparkSubmit.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 135bbe93bf28e..b7e6d0ea021a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -342,6 +342,22 @@ object SparkSubmit extends CommandLineUtils with Logging { val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (args.principal != null) { + if (args.keytab != null) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. 
in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. They will be set as Java system + // properties and then loaded by SparkConf + sysProps.put("spark.yarn.keytab", args.keytab) + sysProps.put("spark.yarn.principal", args.principal) + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } + } + } + // Resolve glob path for different resources. args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull @@ -641,22 +657,6 @@ object SparkSubmit extends CommandLineUtils with Logging { } } - // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { - if (args.principal != null) { - if (args.keytab != null) { - require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") - // Add keytab and principal configurations in sysProps to make them available - // for later use; e.g. in spark sql, the isolated class loader used to talk - // to HiveMetastore will use these settings. They will be set as Java system - // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - } - } - } - if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { setRMPrincipal(sysProps) } From 5a5b6b78517b526771ee5b579d56aa1daa4b3ef1 Mon Sep 17 00:00:00 2001 From: Kohki Nishio Date: Mon, 23 Oct 2017 09:55:46 -0700 Subject: [PATCH 099/712] [SPARK-22303][SQL] Handle Oracle specific jdbc types in OracleDialect TIMESTAMP (-101), BINARY_DOUBLE (101) and BINARY_FLOAT (100) are handled in OracleDialect ## What changes were proposed in this pull request? When a oracle table contains columns whose type is BINARY_FLOAT or BINARY_DOUBLE, spark sql fails to load a table with SQLException ``` java.sql.SQLException: Unsupported type 101 at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.org$apache$spark$sql$execution$datasources$jdbc$JdbcUtils$$getCatalystType(JdbcUtils.scala:235) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$8.apply(JdbcUtils.scala:292) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$8.apply(JdbcUtils.scala:292) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.getSchema(JdbcUtils.scala:291) at org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$.resolveTable(JDBCRDD.scala:64) at org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation.(JDBCRelation.scala:113) at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:47) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:306) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:178) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:146) ``` ## How was this patch tested? I updated a UT which covers type conversion test for types (-101, 100, 101), on top of that I tested this change against actual table with those columns and it was able to read and write to the table. Author: Kohki Nishio Closes #19548 from taroplus/oracle_sql_types_101. 
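As a usage sketch (connection details and table are placeholders, not taken from this patch), such a table can now be loaded through the generic JDBC reader and the Oracle-specific types surface as Spark types instead of failing:

```
// Sketch: reading an Oracle table that has BINARY_FLOAT / BINARY_DOUBLE columns.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:oracle:thin:@//dbhost:1521/XE")
  .option("dbtable", "oracle_types")
  .option("user", "user")
  .option("password", "password")
  .load()
df.printSchema()   // BINARY_FLOAT maps to FloatType, BINARY_DOUBLE to DoubleType
```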
--- .../sql/jdbc/OracleIntegrationSuite.scala | 43 +++++++++++++++--- .../datasources/jdbc/JdbcUtils.scala | 1 - .../apache/spark/sql/jdbc/OracleDialect.scala | 44 +++++++++++-------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 +++ 4 files changed, 68 insertions(+), 26 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 7680ae3835132..90343182712ed 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp} import java.util.Properties import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SaveMode} import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -52,7 +52,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo import testImplicits._ override val db = new DatabaseOnDocker { - override val imageName = "wnameless/oracle-xe-11g:14.04.4" + override val imageName = "wnameless/oracle-xe-11g:16.04" override val env = Map( "ORACLE_ROOT_PASSWORD" -> "oracle" ) @@ -104,15 +104,18 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate(); + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate() conn.prepareStatement( - "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate(); - conn.commit(); + "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate() + conn.commit() } test("SPARK-16625 : Importing Oracle numeric types") { - val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties); + val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties) val rows = df.collect() assert(rows.size == 1) val row = rows(0) @@ -307,4 +310,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(values.getInt(1).equals(1)) assert(values.getBoolean(2).equals(false)) } + + test("SPARK-22303: handle BINARY_DOUBLE and BINARY_FLOAT as DoubleType and FloatType") { + val tableName = "oracle_types" + val schema = StructType(Seq( + StructField("d", DoubleType, true), + StructField("f", FloatType, true))) + val props = new Properties() + + // write it back to the table (append mode) + val data = spark.sparkContext.parallelize(Seq(Row(1.1, 2.2f))) + val dfWrite = spark.createDataFrame(data, schema) + dfWrite.write.mode(SaveMode.Append).jdbc(jdbcUrl, tableName, props) + + // read records from oracle_types + val dfRead = sqlContext.read.jdbc(jdbcUrl, tableName, new Properties) + val rows = dfRead.collect() + assert(rows.size == 1) + + // check data types + val types = dfRead.schema.map(field => field.dataType) + assert(types(0).equals(DoubleType)) + assert(types(1).equals(FloatType)) + + // check values + val values = rows(0) + assert(values.getDouble(0) === 1.1) + 
assert(values.getFloat(1) === 2.2f) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 71133666b3249..9debc4ff82748 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -230,7 +230,6 @@ object JdbcUtils extends Logging { case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => TimestampType - case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 3b44c1de93a61..e3f106c41c7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -23,30 +23,36 @@ import org.apache.spark.sql.types._ private case object OracleDialect extends JdbcDialect { + private[jdbc] val BINARY_FLOAT = 100 + private[jdbc] val BINARY_DOUBLE = 101 + private[jdbc] val TIMESTAMPTZ = -101 override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.NUMERIC) { - val scale = if (null != md) md.build().getLong("scale") else 0L - size match { - // Handle NUMBER fields that have no precision/scale in special way - // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale - // For more details, please see - // https://github.com/apache/spark/pull/8780#issuecomment-145598968 - // and - // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts - // this to NUMERIC with -127 scale - // Not sure if there is a more robust way to identify the field as a float (or other - // numeric types that do not specify a scale. - case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case _ => None - } - } else { - None + sqlType match { + case Types.NUMERIC => + val scale = if (null != md) md.build().getLong("scale") else 0L + size match { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts + // this to NUMERIC with -127 scale + // Not sure if there is a more robust way to identify the field as a float (or other + // numeric types that do not specify a scale. 
+ case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + case _ => None + } + case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle + case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT + case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE + case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 34205e0b2bf08..167b3e0190026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -815,6 +815,12 @@ class JDBCSuite extends SparkFunSuite Some(DecimalType(DecimalType.MAX_PRECISION, 10))) assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) == Some(DecimalType(DecimalType.MAX_PRECISION, 10))) + assert(oracleDialect.getCatalystType(OracleDialect.BINARY_FLOAT, "BINARY_FLOAT", 0, null) == + Some(FloatType)) + assert(oracleDialect.getCatalystType(OracleDialect.BINARY_DOUBLE, "BINARY_DOUBLE", 0, null) == + Some(DoubleType)) + assert(oracleDialect.getCatalystType(OracleDialect.TIMESTAMPTZ, "TIMESTAMP", 0, null) == + Some(TimestampType)) } test("table exists query by jdbc dialect") { From f6290aea24efeb238db88bdaef4e24d50740ca4c Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 23 Oct 2017 23:02:36 +0100 Subject: [PATCH 100/712] [SPARK-22285][SQL] Change implementation of ApproxCountDistinctForIntervals to TypedImperativeAggregate ## What changes were proposed in this pull request? The current implementation of `ApproxCountDistinctForIntervals` is `ImperativeAggregate`. The number of `aggBufferAttributes` is the number of total words in the hllppHelper array. Each hllppHelper has 52 words by default relativeSD. Since this aggregate function is used in equi-height histogram generation, and the number of buckets in histogram is usually hundreds, the number of `aggBufferAttributes` can easily reach tens of thousands or even more. This leads to a huge method in codegen and causes error: ``` org.codehaus.janino.JaninoRuntimeException: Code of method "apply(Lorg/apache/spark/sql/catalyst/InternalRow;)Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection" grows beyond 64 KB. ``` Besides, huge generated methods also result in performance regression. In this PR, we change its implementation to `TypedImperativeAggregate`. After the fix, `ApproxCountDistinctForIntervals` can deal with more than thousands endpoints without throwing codegen error, and improve performance from `20 sec` to `2 sec` in a test case of 500 endpoints. ## How was this patch tested? Test by an added test case and existing tests. Author: Zhenhua Wang Closes #19506 from wzhfy/change_forIntervals_typedAgg. 
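The key point of switching to `TypedImperativeAggregate` is that the whole per-group state becomes one opaque `Array[Long]` held in a single binary buffer slot (serialized only for shuffle/spill), instead of tens of thousands of individual `LongType` buffer attributes that all end up in generated code. A minimal, self-contained sketch of such a buffer round-trip — the patch itself uses `Platform` rather than `ByteBuffer`:

```
import java.nio.ByteBuffer

// Sketch only: pack/unpack an Array[Long] of HLL++ registers into bytes,
// which is all the serialize()/deserialize() pair below has to do.
def serialize(words: Array[Long]): Array[Byte] = {
  val bb = ByteBuffer.allocate(words.length * 8)
  words.foreach(w => bb.putLong(w))
  bb.array()
}

def deserialize(bytes: Array[Byte]): Array[Long] = {
  val bb = ByteBuffer.wrap(bytes)
  Array.fill(bytes.length / 8)(bb.getLong())
}

assert(deserialize(serialize(Array(1L, 2L, 3L))).sameElements(Array(1L, 2L, 3L)))
```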
--- .../ApproxCountDistinctForIntervals.scala | 97 ++++++++++--------- ...ApproxCountDistinctForIntervalsSuite.scala | 34 +++---- ...xCountDistinctForIntervalsQuerySuite.scala | 61 ++++++++++++ 3 files changed, 130 insertions(+), 62 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index 096d1b35a8620..d4421ca20a9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -22,9 +22,10 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform /** * This function counts the approximate number of distinct values (ndv) in @@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals( relativeSD: Double = 0.05, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with ExpectsInputTypes { - - def this(child: Expression, endpointsExpression: Expression) = { - this( - child = child, - endpointsExpression = endpointsExpression, - relativeSD = 0.05, - mutableAggBufferOffset = 0, - inputAggBufferOffset = 0) - } + extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes { def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = { this( @@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals( private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length /** Allocate enough words to store all registers. */ - override lazy val aggBufferAttributes: Seq[AttributeReference] = { - Seq.tabulate(totalNumWords) { i => - AttributeReference(s"MS[$i]", LongType)() - } + override def createAggregationBuffer(): Array[Long] = { + Array.fill(totalNumWords)(0L) } - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - /** Fill all words with zeros. 
*/ - override def initialize(buffer: InternalRow): Unit = { - var word = 0 - while (word < totalNumWords) { - buffer.setLong(mutableAggBufferOffset + word, 0) - word += 1 - } - } - - override def update(buffer: InternalRow, input: InternalRow): Unit = { + override def update(buffer: Array[Long], input: InternalRow): Array[Long] = { val value = child.eval(input) // Ignore empty rows if (value != null) { @@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals( // endpoints are sorted into ascending order already if (endpoints.head > doubleValue || endpoints.last < doubleValue) { // ignore if the value is out of the whole range - return + return buffer } val hllppIndex = findHllppIndex(doubleValue) - val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp - hllppArray(hllppIndex).update(buffer, offset, value, child.dataType) + val offset = hllppIndex * numWordsPerHllpp + hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType) } + buffer } // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value. @@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals( } } - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = { for (i <- hllppArray.indices) { hllppArray(i).merge( - buffer1 = buffer1, - buffer2 = buffer2, - offset1 = mutableAggBufferOffset + i * numWordsPerHllpp, - offset2 = inputAggBufferOffset + i * numWordsPerHllpp) + buffer1 = LongArrayInternalRow(buffer1), + buffer2 = LongArrayInternalRow(buffer2), + offset1 = i * numWordsPerHllpp, + offset2 = i * numWordsPerHllpp) } + buffer1 } - override def eval(buffer: InternalRow): Any = { + override def eval(buffer: Array[Long]): Any = { val ndvArray = hllppResults(buffer) // If the endpoints contains multiple elements with the same value, // we set ndv=1 for intervals between these elements. 
@@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals( new GenericArrayData(ndvArray) } - def hllppResults(buffer: InternalRow): Array[Long] = { + def hllppResults(buffer: Array[Long]): Array[Long] = { val ndvArray = new Array[Long](hllppArray.length) for (i <- ndvArray.indices) { - ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp) + ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp) } ndvArray } - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(mutableAggBufferOffset = newMutableAggBufferOffset) + } - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(inputAggBufferOffset = newInputAggBufferOffset) + } override def children: Seq[Expression] = Seq(child, endpointsExpression) @@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals( override def dataType: DataType = ArrayType(LongType) override def prettyName: String = "approx_count_distinct_for_intervals" + + override def serialize(obj: Array[Long]): Array[Byte] = { + val byteArray = new Array[Byte](obj.length * 8) + var i = 0 + while (i < obj.length) { + Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i)) + i += 1 + } + byteArray + } + + override def deserialize(bytes: Array[Byte]): Array[Long] = { + assert(bytes.length % 8 == 0) + val length = bytes.length / 8 + val longArray = new Array[Long](length) + var i = 0 + while (i < length) { + longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8) + i += 1 + } + longArray + } + + private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow { + override def getLong(offset: Int): Long = array(offset) + override def setLong(offset: Int, value: Long): Unit = { array(offset) = value } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala index d6c38c3608bf8..73f18d4feef3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala @@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType), MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType)))) wrongColumnTypes.foreach { dataType => - val wrongColumn = new ApproxCountDistinctForIntervals( + val wrongColumn = ApproxCountDistinctForIntervals( AttributeReference("a", dataType)(), endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_)))) assert( @@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { }) } - var wrongEndpoints = new ApproxCountDistinctForIntervals( + var wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = Literal(0.5d)) assert( @@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { 
case _ => false }) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)()))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The endpoints provided must be constant literals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array(10L).map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array("foobar").map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == @@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { private def createEstimator[T]( endpoints: Array[T], dt: DataType, - rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = { + rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = { val input = new SpecificInternalRow(Seq(dt)) val aggFunc = ApproxCountDistinctForIntervals( BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd) - val buffer = createBuffer(aggFunc) - (aggFunc, input, buffer) - } - - private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = { - val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType)) - aggFunc.initialize(buffer) - buffer + (aggFunc, input, aggFunc.createAggregationBuffer()) } test("merging ApproxCountDistinctForIntervals instances") { val (aggFunc, input, buffer1a) = createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType) - val buffer1b = createBuffer(aggFunc) - val buffer2 = createBuffer(aggFunc) + val buffer1b = aggFunc.createAggregationBuffer() + val buffer2 = aggFunc.createAggregationBuffer() // Add the lower half to `buffer1a`. var i = 0 @@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { } // Check if the buffers are equal. 
- assert(buffer2 == buffer1a, "Buffers should be equal") + assert(buffer2.sameElements(buffer1a), "Buffers should be equal") } test("test findHllppIndex(value) for values in the range") { @@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2) } + test("round trip serialization") { + val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType) + val longArray = (1L to 100L).toArray + val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray)) + assert(roundtrip.sameElements(longArray)) + } + test("basic operations: update, merge, eval...") { val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0) val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala new file mode 100644 index 0000000000000..c7d86bc955d67 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.test.SharedSQLContext + +class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height + // histogram usually contains hundreds of buckets. So we need to test + // ApproxCountDistinctForIntervals with large number of endpoints + // (the number of endpoints == the number of buckets + 1). + test("test ApproxCountDistinctForIntervals with large number of endpoints") { + val table = "approx_count_distinct_for_intervals_tbl" + withTable(table) { + (1 to 100000).toDF("col").createOrReplaceTempView(table) + // percentiles of 0, 0.001, 0.002 ... 0.999, 1 + val endpoints = (0 to 1000).map(_ * 100000 / 1000) + + // Since approx_count_distinct_for_intervals is not a public function, here we do + // the computation by constructing logical plan. 
+ val relation = spark.table(table).logicalPlan + val attr = relation.output.find(_.name == "col").get + val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_)))) + val aggExpr = aggFunc.toAggregateExpression() + val namedExpr = Alias(aggExpr, aggExpr.toString)() + val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation)) + .executedPlan.executeTake(1).head + val ndvArray = ndvsRow.getArray(0).toLongArray() + assert(endpoints.length == ndvArray.length + 1) + + // Each bucket has 100 distinct values. + val expectedNdv = 100 + for (i <- ndvArray.indices) { + val ndv = ndvArray(i) + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.") + } + } + } +} From 884d4f95f7ebfaa9d8c57cf770d10a2c6ab82d62 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 23 Oct 2017 17:21:49 -0700 Subject: [PATCH 101/712] [SPARK-21912][SQL][FOLLOW-UP] ORC/Parquet table should not create invalid column names ## What changes were proposed in this pull request? During [SPARK-21912](https://issues.apache.org/jira/browse/SPARK-21912), we skipped testing 'ADD COLUMNS' on ORC tables due to ORC limitation. Since [SPARK-21929](https://issues.apache.org/jira/browse/SPARK-21929) is resolved now, we can test both `ORC` and `PARQUET` completely. ## How was this patch tested? Pass the updated test case. Author: Dongjoon Hyun Closes #19562 from dongjoon-hyun/SPARK-21912-2. --- .../spark/sql/hive/execution/SQLQuerySuite.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1cf1c5cd5a472..39e918c3d5209 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2031,8 +2031,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-21912 ORC/Parquet table should not create invalid column names") { Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name => - withTable("t21912") { - Seq("ORC", "PARQUET").foreach { source => + Seq("ORC", "PARQUET").foreach { source => + withTable("t21912") { val m = intercept[AnalysisException] { sql(s"CREATE TABLE t21912(`col$name` INT) USING $source") }.getMessage @@ -2049,15 +2049,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { }.getMessage assert(m3.contains(s"contains invalid character(s)")) } - } - // TODO: After SPARK-21929, we need to check ORC, too. - Seq("PARQUET").foreach { source => sql(s"CREATE TABLE t21912(`col` INT) USING $source") - val m = intercept[AnalysisException] { + val m4 = intercept[AnalysisException] { sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)") }.getMessage - assert(m.contains(s"contains invalid character(s)")) + assert(m4.contains(s"contains invalid character(s)")) } } } From d9798c834f3fed060cfd18a8d38c398cb2efcc82 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Oct 2017 12:44:47 +0900 Subject: [PATCH 102/712] [SPARK-22313][PYTHON] Mark/print deprecation warnings as DeprecationWarning for deprecated APIs ## What changes were proposed in this pull request? This PR proposes to mark the existing warnings as `DeprecationWarning` and print out warnings for deprecated functions. 
This could be actually useful for Spark app developers. I use (old) PyCharm and this IDE can detect this specific `DeprecationWarning` in some cases: **Before** **After** For console usage, `DeprecationWarning` is usually disabled (see https://docs.python.org/2/library/warnings.html#warning-categories and https://docs.python.org/3/library/warnings.html#warning-categories): ``` >>> import warnings >>> filter(lambda f: f[2] == DeprecationWarning, warnings.filters) [('ignore', <_sre.SRE_Pattern object at 0x10ba58c00>, , <_sre.SRE_Pattern object at 0x10bb04138>, 0), ('ignore', None, , None, 0)] ``` so, it won't actually mess up the terminal much unless it is intended. If this is intendedly enabled, it'd should as below: ``` >>> import warnings >>> warnings.simplefilter('always', DeprecationWarning) >>> >>> from pyspark.sql import functions >>> functions.approxCountDistinct("a") .../spark/python/pyspark/sql/functions.py:232: DeprecationWarning: Deprecated in 2.1, use approx_count_distinct instead. "Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) ... ``` These instances were found by: ``` cd python/pyspark grep -r "Deprecated" . grep -r "deprecated" . grep -r "deprecate" . ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #19535 from HyukjinKwon/deprecated-warning. --- python/pyspark/ml/util.py | 8 ++- python/pyspark/mllib/classification.py | 2 +- python/pyspark/mllib/evaluation.py | 6 +-- python/pyspark/mllib/regression.py | 8 +-- python/pyspark/sql/dataframe.py | 3 ++ python/pyspark/sql/functions.py | 18 +++++++ python/pyspark/streaming/flume.py | 14 ++++- python/pyspark/streaming/kafka.py | 72 ++++++++++++++++++++++---- 8 files changed, 110 insertions(+), 21 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 67772910c0d38..c3c47bd79459a 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -175,7 +175,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jwrite.context(sqlContext._ssql_ctx) return self @@ -256,7 +258,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jread.context(sqlContext._ssql_ctx) return self diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e04eeb2b60d71..cce703d432b5a 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -311,7 +311,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ warnings.warn( "Deprecated in 2.0.0. 
Use ml.classification.LogisticRegression or " - "LogisticRegressionWithLBFGS.") + "LogisticRegressionWithLBFGS.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index fc2a0b3b5038a..2cd1da3fbf9aa 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -234,7 +234,7 @@ def precision(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("precision") else: return self.call("precision", float(label)) @@ -246,7 +246,7 @@ def recall(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("recall") else: return self.call("recall", float(label)) @@ -259,7 +259,7 @@ def fMeasure(self, label=None, beta=None): if beta is None: if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("fMeasure") else: return self.call("fMeasure", label) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 1b66f5b51044b..ea107d400621d 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -278,7 +278,8 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, A condition which decides iteration termination. (default: 0.001) """ - warnings.warn("Deprecated in 2.0.0. Use ml.regression.LinearRegression.") + warnings.warn( + "Deprecated in 2.0.0. Use ml.regression.LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -421,7 +422,8 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 1.0. " - "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.") + "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", + DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -566,7 +568,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 0.0. " "Note the default regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for " - "LinearRegression.") + "LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 38b01f0011671..c0b574e2b93a1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -130,6 +130,8 @@ def registerTempTable(self, name): .. note:: Deprecated in 2.0, use createOrReplaceTempView instead. """ + warnings.warn( + "Deprecated in 2.0, use createOrReplaceTempView instead.", DeprecationWarning) self._jdf.createOrReplaceTempView(name) @since(2.0) @@ -1308,6 +1310,7 @@ def unionAll(self, other): .. 
note:: Deprecated in 2.0, use :func:`union` instead. """ + warnings.warn("Deprecated in 2.0, use union instead.", DeprecationWarning) return self.union(other) @since(2.3) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc374b93a433..0d40368c9cd6e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -21,6 +21,7 @@ import math import sys import functools +import warnings if sys.version < "3": from itertools import imap as map @@ -44,6 +45,14 @@ def _(col): return _ +def _wrap_deprecated_function(func, message): + """ Wrap the deprecated function to print out deprecation warnings""" + def _(col): + warnings.warn(message, DeprecationWarning) + return func(col) + return functools.wraps(func)(_) + + def _create_binary_mathfunction(name, doc=""): """ Create a binary mathfunction by name""" def _(col1, col2): @@ -207,6 +216,12 @@ def _(): """returns the relative rank (i.e. percentile) of rows within a window partition.""", } +# Wraps deprecated functions (keys) with the messages (values). +_functions_deprecated = { + 'toDegrees': 'Deprecated in 2.1, use degrees instead.', + 'toRadians': 'Deprecated in 2.1, use radians instead.', +} + for _name, _doc in _functions.items(): globals()[_name] = since(1.3)(_create_function(_name, _doc)) for _name, _doc in _functions_1_4.items(): @@ -219,6 +234,8 @@ def _(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) for _name, _doc in _functions_2_1.items(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) +for _name, _message in _functions_deprecated.items(): + globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) del _name, _doc @@ -227,6 +244,7 @@ def approxCountDistinct(col, rsd=None): """ .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. """ + warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) return approx_count_distinct(col, rsd) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index 2fed5940b31ea..5a975d050b0d8 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -54,8 +54,13 @@ def createStream(ssc, hostname, port, :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) helper = FlumeUtils._get_helper(ssc._sc) jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) @@ -82,8 +87,13 @@ def createPollingStream(ssc, addresses, :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) hosts = [] ports = [] diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 4af4135c81958..fdb9308604489 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,6 +15,8 @@ # limitations under the License. 
# +import warnings + from py4j.protocol import Py4JJavaError from pyspark.rdd import RDD @@ -56,8 +58,13 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if kafkaParams is None: kafkaParams = dict() kafkaParams.update({ @@ -105,8 +112,13 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :return: A DStream object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if fromOffsets is None: fromOffsets = dict() if not isinstance(topics, list): @@ -159,8 +171,13 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, :return: An RDD object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if leaders is None: leaders = dict() if not isinstance(kafkaParams, dict): @@ -229,7 +246,8 @@ class OffsetRange(object): """ Represents a range of offsets from a single Kafka TopicAndPartition. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, fromOffset, untilOffset): @@ -240,6 +258,10 @@ def __init__(self, topic, partition, fromOffset, untilOffset): :param fromOffset: Inclusive starting offset. :param untilOffset: Exclusive ending offset. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.fromOffset = fromOffset @@ -270,7 +292,8 @@ class TopicAndPartition(object): """ Represents a specific topic and partition for Kafka. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition): @@ -279,6 +302,10 @@ def __init__(self, topic, partition): :param topic: Kafka topic name. :param partition: Kafka partition id. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self._topic = topic self._partition = partition @@ -303,7 +330,8 @@ class Broker(object): """ Represent the host and port info for a Kafka broker. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, host, port): @@ -312,6 +340,10 @@ def __init__(self, host, port): :param host: Broker's hostname. :param port: Broker's port. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. 
" + "See SPARK-21893.", + DeprecationWarning) self._host = host self._port = port @@ -323,10 +355,15 @@ class KafkaRDD(RDD): """ A Python wrapper of KafkaRDD, to provide additional information on normal RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jrdd, ctx, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) RDD.__init__(self, jrdd, ctx, jrdd_deserializer) def offsetRanges(self): @@ -345,10 +382,15 @@ class KafkaDStream(DStream): """ A Python wrapper of KafkaDStream - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jdstream, ssc, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) DStream.__init__(self, jdstream, ssc, jrdd_deserializer) def foreachRDD(self, func): @@ -383,10 +425,15 @@ class KafkaTransformedDStream(TransformedDStream): """ Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, prev, func): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) TransformedDStream.__init__(self, prev, func) @property @@ -405,7 +452,8 @@ class KafkaMessageAndMetadata(object): """ Kafka message and metadata information. Including topic, partition, offset and message - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, offset, key, message): @@ -419,6 +467,10 @@ def __init__(self, topic, partition, offset, key, message): :param message: actual message payload of this Kafka message, the return data is undecoded bytearray. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.offset = offset From c30d5cfc7117bdadd63bf730e88398139e0f65f4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 24 Oct 2017 08:46:22 +0100 Subject: [PATCH 103/712] [SPARK-20822][SQL] Generate code to directly get value from ColumnVector for table cache ## What changes were proposed in this pull request? This PR generates the Java code to directly get a value for a column in `ColumnVector` without using an iterator (e.g. at lines 54-69 in the generated code example) for table cache (e.g. `dataframe.cache`). This PR improves runtime performance by eliminating data copy from column-oriented storage to `InternalRow` in a `SpecificColumnarIterator` iterator for primitive type. Another PR will support primitive type array. 
Benchmark result: **1.2x** ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Int Sum with IntDelta cache: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ InternalRow codegen 731 / 812 43.0 23.2 1.0X ColumnVector codegen 616 / 772 51.0 19.6 1.2X ``` Benchmark program ``` intSumBenchmark(sqlContext, 1024 * 1024 * 30) def intSumBenchmark(sqlContext: SQLContext, values: Int): Unit = { import sqlContext.implicits._ val benchmarkPT = new Benchmark("Int Sum with IntDelta cache", values, 20) Seq(("InternalRow", "false"), ("ColumnVector", "true")).foreach { case (str, value) => withSQLConf(sqlContext, SQLConf. COLUMN_VECTOR_CODEGEN.key -> value) { // tentatively added for benchmarking val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1).toDF().cache() dfPassThrough.count() // force to create df.cache() benchmarkPT.addCase(s"$str codegen") { iter => dfPassThrough.agg(sum("value")).collect } dfPassThrough.unpersist(true) } } benchmarkPT.run() } ``` Motivating example ``` val dsInt = spark.range(3).cache dsInt.count // force to build cache dsInt.filter(_ > 0).collect ``` Generated code ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inmemorytablescan_input; /* 009 */ private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_numOutputRows; /* 010 */ private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_scanTime; /* 011 */ private long inmemorytablescan_scanTime1; /* 012 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch inmemorytablescan_batch; /* 013 */ private int inmemorytablescan_batchIdx; /* 014 */ private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector inmemorytablescan_colInstance0; /* 015 */ private UnsafeRow inmemorytablescan_result; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder inmemorytablescan_holder; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter inmemorytablescan_rowWriter; /* 018 */ private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows; /* 019 */ private UnsafeRow filter_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter; /* 022 */ /* 023 */ public GeneratedIterator(Object[] references) { /* 024 */ this.references = references; /* 025 */ } /* 026 */ /* 027 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 028 */ partitionIndex = index; /* 029 */ this.inputs = inputs; /* 030 */ inmemorytablescan_input = inputs[0]; /* 031 */ inmemorytablescan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 032 */ inmemorytablescan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 033 */ inmemorytablescan_scanTime1 = 0; /* 034 */ inmemorytablescan_batch = null; /* 035 */ inmemorytablescan_batchIdx = 0; /* 036 */ inmemorytablescan_colInstance0 = null; /* 037 */ inmemorytablescan_result = new 
UnsafeRow(1); /* 038 */ inmemorytablescan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(inmemorytablescan_result, 0); /* 039 */ inmemorytablescan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(inmemorytablescan_holder, 1); /* 040 */ filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 041 */ filter_result = new UnsafeRow(1); /* 042 */ filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 0); /* 043 */ filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 1); /* 044 */ /* 045 */ } /* 046 */ /* 047 */ protected void processNext() throws java.io.IOException { /* 048 */ if (inmemorytablescan_batch == null) { /* 049 */ inmemorytablescan_nextBatch(); /* 050 */ } /* 051 */ while (inmemorytablescan_batch != null) { /* 052 */ int inmemorytablescan_numRows = inmemorytablescan_batch.numRows(); /* 053 */ int inmemorytablescan_localEnd = inmemorytablescan_numRows - inmemorytablescan_batchIdx; /* 054 */ for (int inmemorytablescan_localIdx = 0; inmemorytablescan_localIdx < inmemorytablescan_localEnd; inmemorytablescan_localIdx++) { /* 055 */ int inmemorytablescan_rowIdx = inmemorytablescan_batchIdx + inmemorytablescan_localIdx; /* 056 */ int inmemorytablescan_value = inmemorytablescan_colInstance0.getInt(inmemorytablescan_rowIdx); /* 057 */ /* 058 */ boolean filter_isNull = false; /* 059 */ /* 060 */ boolean filter_value = false; /* 061 */ filter_value = inmemorytablescan_value > 1; /* 062 */ if (!filter_value) continue; /* 063 */ /* 064 */ filter_numOutputRows.add(1); /* 065 */ /* 066 */ filter_rowWriter.write(0, inmemorytablescan_value); /* 067 */ append(filter_result); /* 068 */ if (shouldStop()) { inmemorytablescan_batchIdx = inmemorytablescan_rowIdx + 1; return; } /* 069 */ } /* 070 */ inmemorytablescan_batchIdx = inmemorytablescan_numRows; /* 071 */ inmemorytablescan_batch = null; /* 072 */ inmemorytablescan_nextBatch(); /* 073 */ } /* 074 */ inmemorytablescan_scanTime.add(inmemorytablescan_scanTime1 / (1000 * 1000)); /* 075 */ inmemorytablescan_scanTime1 = 0; /* 076 */ } /* 077 */ /* 078 */ private void inmemorytablescan_nextBatch() throws java.io.IOException { /* 079 */ long getBatchStart = System.nanoTime(); /* 080 */ if (inmemorytablescan_input.hasNext()) { /* 081 */ org.apache.spark.sql.execution.columnar.CachedBatch inmemorytablescan_cachedBatch = (org.apache.spark.sql.execution.columnar.CachedBatch)inmemorytablescan_input.next(); /* 082 */ inmemorytablescan_batch = org.apache.spark.sql.execution.columnar.InMemoryRelation$.MODULE$.createColumn(inmemorytablescan_cachedBatch); /* 083 */ /* 084 */ inmemorytablescan_numOutputRows.add(inmemorytablescan_batch.numRows()); /* 085 */ inmemorytablescan_batchIdx = 0; /* 086 */ inmemorytablescan_colInstance0 = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) inmemorytablescan_batch.column(0); org.apache.spark.sql.execution.columnar.ColumnAccessor$.MODULE$.decompress(inmemorytablescan_cachedBatch.buffers()[0], (org.apache.spark.sql.execution.vectorized.WritableColumnVector) inmemorytablescan_colInstance0, org.apache.spark.sql.types.DataTypes.IntegerType, inmemorytablescan_cachedBatch.numRows()); /* 087 */ /* 088 */ } /* 089 */ inmemorytablescan_scanTime1 += System.nanoTime() - getBatchStart; /* 090 */ } /* 091 */ } ``` ## How was this patch tested? 
Add test cases into `DataFrameTungstenSuite` and `WholeStageCodegenSuite` Author: Kazuaki Ishizaki Closes #18747 from kiszk/SPARK-20822a. --- .../sql/execution/ColumnarBatchScan.scala | 3 - .../sql/execution/WholeStageCodegenExec.scala | 24 ++++---- .../execution/columnar/ColumnAccessor.scala | 8 +++ .../columnar/InMemoryTableScanExec.scala | 57 +++++++++++++++++-- .../spark/sql/DataFrameTungstenSuite.scala | 36 ++++++++++++ .../execution/WholeStageCodegenSuite.scala | 32 +++++++++++ 6 files changed, 141 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 1afe83ea3539e..eb01e126bcbef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType @@ -31,8 +30,6 @@ import org.apache.spark.sql.types.DataType */ private[sql] trait ColumnarBatchScan extends CodegenSupport { - val inMemoryTableScan: InMemoryTableScanExec = null - def vectorTypes: Option[Seq[String]] = None override lazy val metrics = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 1aaaf896692d1..e37d133ff336a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -282,6 +282,18 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp object WholeStageCodegenExec { val PIPELINE_DURATION_METRIC = "duration" + + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) + case _ => 1 + } + + def isTooManyFields(conf: SQLConf, dataType: DataType): Boolean = { + numOfNestedFields(dataType) > conf.wholeStageMaxNumFields + } } /** @@ -490,22 +502,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case _ => true } - private def numOfNestedFields(dataType: DataType): Int = dataType match { - case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum - case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) - case a: ArrayType => numOfNestedFields(a.elementType) - case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) - case _ => 1 - } - private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns val hasTooManyOutputFields = - numOfNestedFields(plan.schema) > 
conf.wholeStageMaxNumFields + WholeStageCodegenExec.isTooManyFields(conf, plan.schema) val hasTooManyInputFields = - plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields) + plan.children.exists(p => WholeStageCodegenExec.isTooManyFields(conf, p.schema)) !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 24c8ac81420cb..445933d98e9d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -163,4 +163,12 @@ private[sql] object ColumnAccessor { throw new RuntimeException("Not support non-primitive type now") } } + + def decompress( + array: Array[Byte], columnVector: WritableColumnVector, dataType: DataType, numRows: Int): + Unit = { + val byteBuffer = ByteBuffer.wrap(array) + val columnAccessor = ColumnAccessor(dataType, byteBuffer) + decompress(columnAccessor, columnVector, numRows) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 139da1c519da2..43386e7a03c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,21 +23,66 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.UserDefinedType +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.vectorized._ +import org.apache.spark.sql.types._ case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode { + extends LeafExecNode with ColumnarBatchScan { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override def vectorTypes: Option[Seq[String]] = + Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName)) + + /** + * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. 
+ * If false, get data from UnsafeRow build from ColumnVector + */ + override val supportCodegen: Boolean = { + // In the initial implementation, for ease of review + // support only primitive data types and # of fields is less than wholeStageMaxNumFields + relation.schema.fields.forall(f => f.dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType => true + case _ => false + }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) + } + + private val columnIndices = + attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray + + private val relationSchema = relation.schema.toArray + + private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) + + private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + val rowCount = cachedColumnarBatch.numRows + val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + val columnarBatch = new ColumnarBatch( + columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + columnarBatch.setNumRows(rowCount) + + for (i <- 0 until attributes.length) { + ColumnAccessor.decompress( + cachedColumnarBatch.buffers(columnIndices(i)), + columnarBatch.column(i).asInstanceOf[WritableColumnVector], + columnarBatchSchema.fields(i).dataType, rowCount) + } + columnarBatch + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + assert(supportCodegen) + val buffers = relation.cachedColumnBuffers + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + } override def output: Seq[Attribute] = attributes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index fe6ba83b4cbfb..0881212a64de8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -73,4 +73,40 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } + + test("primitive data type accesses in persist data") { + val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong, + 31.25.toFloat, 63.75, null) + val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, IntegerType) + val schemas = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data))) + val df = spark.createDataFrame(rdd, StructType(schemas)) + val row = df.persist.take(1).apply(0) + checkAnswer(df, row) + } + + test("access cache multiple times") { + val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache + df0.count + val df1 = df0.filter("x > 1") + checkAnswer(df1, Seq(Row(2), Row(3))) + val df2 = df0.filter("x > 2") + checkAnswer(df2, Row(3)) + + val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache + for (_ <- 0 to 2) { + val df11 = df10.filter("x > 5") + checkAnswer(df11, Row(6)) + } + } + + test("access only some column of the all of columns") { + val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d") + df.cache + df.count + assert(df.filter("d < 
3").count == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 098e4cfeb15b2..bc05dca578c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -117,6 +118,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } + test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { + import testImplicits._ + + val dsInt = spark.range(3).cache + dsInt.count + val dsIntFilter = dsInt.filter(_ > 0) + val planInt = dsIntFilter.queryExecution.executedPlan + assert(planInt.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined + ) + assert(dsIntFilter.collect() === Array(1, 2)) + + // cache for string type is not supported for InMemoryTableScanExec + val dsString = spark.range(3).map(_.toString).cache + dsString.count + val dsStringFilter = dsString.filter(_ == "1") + val planString = dsStringFilter.queryExecution.executedPlan + assert(planString.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec]).isDefined + ) + assert(dsStringFilter.collect() === Array("1")) + } + test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) From 8beeaed66bde0ace44495b38dc967816e16b3464 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 24 Oct 2017 13:56:10 +0100 Subject: [PATCH 104/712] [SPARK-21936][SQL][FOLLOW-UP] backward compatibility test framework for HiveExternalCatalog ## What changes were proposed in this pull request? Adjust Spark download in test to use Apache mirrors and respect its load balancer, and use Spark 2.1.2. This follows on a recent PMC list thread about removing the cloudfront download rather than update it further. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #19564 from srowen/SPARK-21936.2. 
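For reference, `closer.lua?preferred=true` answers with a single mirror base URL chosen by the ASF's load balancer, and the test now simply prepends that to the artifact path. A standalone sketch of the same lookup (an illustration only: it assumes `wget` is on the `PATH`, outbound network access is available, and the version shown is just an example):

```scala
import scala.sys.process._

// Ask the Apache load balancer for the preferred mirror, then build the URL
// the same way the updated downloadSpark() helper does.
val preferredMirror =
  Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim
val version = "2.1.2"
val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz"
println(url)
```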
--- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 305f5b533d592..5f8c9d5799662 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -53,7 +53,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def downloadSpark(version: String): Unit = { import scala.sys.process._ - val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" + val preferredMirror = + Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim + val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! @@ -142,7 +144,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.1", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") protected var spark: SparkSession = _ From 3f5ba968c5af7911a2f6c452500b6a629a3de8db Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Oct 2017 09:11:52 -0700 Subject: [PATCH 105/712] [SPARK-22301][SQL] Add rule to Optimizer for In with not-nullable value and empty list ## What changes were proposed in this pull request? For performance reason, we should resolve in operation on an empty list as false in the optimizations phase, ad discussed in #19522. ## How was this patch tested? Added UT cc gatorsmile Author: Marco Gaido Author: Marco Gaido Closes #19523 from mgaido91/SPARK-22301. --- .../sql/catalyst/optimizer/expressions.scala | 7 +++++-- .../sql/catalyst/optimizer/OptimizeInSuite.scala | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 273bc6ce27c5d..523b53b39d6b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -169,13 +169,16 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { /** * Optimize IN predicates: - * 1. Removes literal repetitions. - * 2. Replaces [[In (value, seq[Literal])]] with optimized version + * 1. Converts the predicate to false when the list is empty and + * the value is not nullable. + * 2. Removes literal repetitions. + * 3. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. 
*/ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index eaad1e32a8aba..d7acd139225cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -175,4 +175,20 @@ class OptimizeInSuite extends PlanTest { } } } + + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + + "when value is not nullable") { + val originalQuery = + testRelation + .where(In(Literal("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(Literal(false)) + .analyze + + comparePlans(optimized, correctAnswer) + } } From bc1e76632ddec8fc64726086905183d1f312bca4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 25 Oct 2017 06:33:44 +0100 Subject: [PATCH 106/712] [SPARK-22348][SQL] The table cache providing ColumnarBatch should also do partition batch pruning ## What changes were proposed in this pull request? We enable table cache `InMemoryTableScanExec` to provide `ColumnarBatch` now. But the cached batches are retrieved without pruning. In this case, we still need to do partition batch pruning. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #19569 from viirya/SPARK-22348. --- .../columnar/InMemoryTableScanExec.scala | 70 ++++++++++--------- .../columnar/InMemoryColumnarQuerySuite.scala | 27 ++++++- 2 files changed, 64 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 43386e7a03c32..2ae3f35eb1da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,7 +78,7 @@ case class InMemoryTableScanExec( override def inputRDDs(): Seq[RDD[InternalRow]] = { assert(supportCodegen) - val buffers = relation.cachedColumnBuffers + val buffers = filteredCachedBatches() // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) @@ -180,19 +180,11 @@ case class InMemoryTableScanExec( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - + private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. 
val schema = relation.partitionStatistics.schema val schemaIndex = schema.zipWithIndex - val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => @@ -201,35 +193,49 @@ case class InMemoryTableScanExec( schema) partitionFilter.initialize(index) + // Do partition batch pruning if enabled + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter.eval(cachedBatch.stats)) { + logDebug { + val statsString = schemaIndex.map { case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") + s"Skipping partition based on stats $statsString" + } + false + } else { + true + } + } + } else { + cachedBatchIterator + } + } + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => // Find the ordinals and data types of the requested columns. val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => relOutput.indexOf(a.exprId) -> a.dataType }.unzip - // Do partition batch pruning if enabled - val cachedBatchesToScan = - if (inMemoryPartitionPruningEnabled) { - cachedBatchIterator.filter { cachedBatch => - if (!partitionFilter.eval(cachedBatch.stats)) { - logDebug { - val statsString = schemaIndex.map { case (a, i) => - val value = cachedBatch.stats.get(i, a.dataType) - s"${a.name}: $value" - }.mkString(", ") - s"Skipping partition based on stats $statsString" - } - false - } else { - true - } - } - } else { - cachedBatchIterator - } - // update SQL metrics - val withMetrics = cachedBatchesToScan.map { batch => + val withMetrics = cachedBatchIterator.map { batch => if (enableAccumulators) { readBatches.add(1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2f249c850a088..e662e294228db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.LocalTableScanExec +import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -454,4 +454,29 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) } + + test("SPARK-22348: table cache should do partition batch pruning") { + Seq("true", "false").foreach { enabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) { + val df1 = Seq((1, 1), 
(1, 1), (2, 2)).toDF("x", "y") + df1.unpersist() + df1.cache() + + // Push predicate to the cached table. + val df2 = df1.where("y = 3") + + val planBeforeFilter = df2.queryExecution.executedPlan.collect { + case f: FilterExec => f.child + } + assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) + + val execPlan = if (enabled == "true") { + WholeStageCodegenExec(planBeforeFilter.head) + } else { + planBeforeFilter.head + } + assert(execPlan.executeCollectPublic().length == 0) + } + } + } } From 524abb996abc9970d699623c13469ea3b6d2d3fc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 24 Oct 2017 22:59:46 -0700 Subject: [PATCH 107/712] [SPARK-21101][SQL] Catch IllegalStateException when CREATE TEMPORARY FUNCTION ## What changes were proposed in this pull request? It must `override` [`public StructObjectInspector initialize(ObjectInspector[] argOIs)`](https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTF.java#L70) when create a UDTF. If you `override` [`public StructObjectInspector initialize(StructObjectInspector argOIs)`](https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTF.java#L49), `IllegalStateException` will throw. per: [HIVE-12377](https://issues.apache.org/jira/browse/HIVE-12377). This PR catch `IllegalStateException` and point user to `override` `public StructObjectInspector initialize(ObjectInspector[] argOIs)`. ## How was this patch tested? unit tests Source code and binary jar: [SPARK-21101.zip](https://github.com/apache/spark/files/1123763/SPARK-21101.zip) These two source code copy from : https://github.com/apache/hive/blob/release-2.0.0/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTFStack.java Author: Yuming Wang Closes #18527 from wangyum/SPARK-21101. --- .../spark/sql/hive/HiveSessionCatalog.scala | 11 ++++++-- .../src/test/resources/SPARK-21101-1.0.jar | Bin 0 -> 7439 bytes .../sql/hive/execution/SQLQuerySuite.scala | 26 ++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 sql/hive/src/test/resources/SPARK-21101-1.0.jar diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index b256ffc27b199..1f11adbd4f62e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -94,8 +94,15 @@ private[sql] class HiveSessionCatalog( } } catch { case NonFatal(e) => - val analysisException = - new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") + val noHandlerMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e" + val errorMsg = + if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + s"$noHandlerMsg\nPlease make sure your function overrides " + + "`public StructObjectInspector initialize(ObjectInspector[] args)`." 
+ } else { + noHandlerMsg + } + val analysisException = new AnalysisException(errorMsg) analysisException.setStackTrace(e.getStackTrace) throw analysisException } diff --git a/sql/hive/src/test/resources/SPARK-21101-1.0.jar b/sql/hive/src/test/resources/SPARK-21101-1.0.jar new file mode 100644 index 0000000000000000000000000000000000000000..768b2334db5c3aa8b1e4186af5fde86120d3231b GIT binary patch literal 7439 zcmb7J1z1#D*XAOFgmkw^#~@uINJ=*%LyqJK2t#+*05TE|HH4DV-Q7rsgi4oO8m=&+ zzz3gKz2EQkKKEbq%sJB?WmMT^==sBelT+RTUu57@i7{b9iv1 zQk!>DU~$cfTY0#TTLmbCb$vDaK>|5f8?#3}GD@37MO()ujkB1P7MD0)K%2~mWI+4q z@{Y2AvvS+A{HISWgnp4F`pUw6j<|P&J)E6QcuaWEzJ>L3^ca_6IXGE=5Bz5j+&?|Q zj$m^e%YSer`d>$9N3fIaKe&_qox3yC?jIo3zk=96-2N#t=6}RldRUsfxk;BBJ-@-KWJ^N#rinTf~4%X6gKbOjmDORleAB(l_w?F9>Q9~XJ@2E!6-9*Gl0>o zg@v=KaIdBFt>}1BJs8`8E}_Q2xeK9bsT^54w)1BiX(aZN9c3kyC&D@yWku->4&O&^ zx9Y)^B^Wwt*H%N72iv2hn@APtT1YzFJ5lr|h@q2U6qqs!tfFKRI|LCs)54NMP?+K^ zu`wbEma1?1ak6-@gkSwBZR>tVemCtCEIn{*|?j8tz4td~>TBI^Z{YnJb7)X8!xb6LfD8)fHC)(G&T}nuonup|g_zYt-?nkFHOP|`O zfvKM~21pUp5n{Q3u)SV+6@q`in;hHO&@4+#oIm@xyLvLfipDA3^hTX5Cdchm_K5hY z?WH8yyLWv*l6WC=%LiwMpWbOWH3mw)SFW)FO?fJUa? zRDl4`LpMl4mmk}`%&^1*T;;c5j2YI} z)X#G#nWwHePY{!4PmxK(i6CAt-Lw`0flCv%;yudCGi^^SMjw&ieURn7*2~|i`hg9@ z2b1piI~_@-C-X3J&@AVT_nuN_zFXMTr@cFy9L`a|ocY4Ajo_Fv|=$ z#qki151<*P>H+{-J83ZTf)-ZP0&5x%VL`R~HEm4pw4ylY0Z}-|;MTUt+htq>vf*ii z(I0XR2NTg%KD)cPNRq<}PO+8dwS@H5Hd9s#cA+Y74k?i2kP;PDLjjWNcpi75Nv62#&m!$q!VL|RtgL`K$BiZI;7HSewdKMiDEJ$;k)r{wfZ#+6s6d{u+OK)Ur7xt^9f^}8S74LE8dI#P=XjW0@j4$ zl4wfUIsu74G+j4Hhc{;Lx@BvVpqFScSl}1JJ=DpFBg3 z7{lL_#?Zwkljpz6c_qv*WI0e}Qj#;#kW6TkzT1|N<0JQ41hf13^VVLKZd_3=Ic3I{ z>rMuqPSuqDA{w3o6*oL#P41eA7fK4${(522x0dwWNtIlJ5sC?v*_s>2<2RE9+vN^v ztWFh>hXpsQ8~VKxsG%#y&=2Dl-f{g=Q`U~j*nS1?(W@V1lOl7DlBI&4bUtf%!%HC` zQa=}bx53BIMqAsLx^Us+iaoUp;E3T#?fwy(l%W28DiwPtmB9?-KH0DjigTmrfHou6 zvZd}^sHO<@%%A(zZlJz*8vPy5aMr*>a!VGwG+NKMM{B=dJ_ ztMX@QxDmJ+7nLu>qJCGQ--H@Z#jT3}J|5;(rIRHLf;O{_EALJ&g6(=&nxtzymsV}a zPT4_JzM~p<(MzPFXyHu8%AJ-MjAQIipu$djgXfB_kO9bY3IQ6H=jh|3Liy zn)yT3V*;y#Q67J9-}8O;;i(Nc-Wxg`+BhCalVoIelU8DP9L>Y`4{2=g za=02VYLcti-qL1_>D7y@c!fuLqL^9thPPx#Olr=1842IHCu4k#1QJ!m?m>eS)Cw2fWfI2RaRPPTyA>(kG(nou=B-+B>8VYe`h<$#uM)lD;;N_~`b~|8ujmd!_gZuLVfkdGH5}IT$ zV>woe4WrA0N3{=wrON?I<#e*k*vHTtQQ+*5U?M1Ht_L2XuJ6EdhCb#bE9j3fCADP~ z-l2F~EW_<3d)n1LHu@a9>rG>PR`6?LO;P3O^J3z1cDPQn_)o`{nM#48gIf4|M|E-U zq9{&?-@n`G_IQ^tSt)154eS_(uj2Dx0V8zbbp6?b8r$Va;##99*F+$ceTeTdgC)!PrGp^_^8=^;_pQDV%mqvI-l zg^t>#M~)>=dn)U63zPKuv&+K@R2qlwe^R49VkJFqfTTH%Jkd*k`q~qkYR%$DzB=J0 z;GHacCRCiLti63iqFp34G6ZLu=eQN#%q-NAyxAu6wp~m78kVywpI@7g{(l)?2V+oaKej1+S+9we<#yOXLqZb~ORj@l zX6V5XZwm|IIPwq8ak7WgQdvW`Fr>8=@mcQH4p}}8-vnZPjaGIM7P;hmRuivoV*nJg zXU-LxtEq%RoaYtm!m?-@iH}#KV-HmahAJx3x}Gg)AWpp->-}9`yA9jK4D$uzvEfql zr#$BwUQhTKBYifd=``cGGuX652$>*ezZ;X42JEn>;~8u=)`7eTh^yl6+#Iqgv2{#U zt+wB_nc7snm8Zcnj#V5AFS< zsy|u_Id5w%>f1@&>KGHd9wt%in``MchaULQ$LQva?UqDN^;xdvjrDa_JZs%2%%d|7 zc~ygT>q!-+q?pZ+_&aO}bZBLYK&m4o&!M#3Ec*lXJLPCP>K&Scs!8tZsa*b=0lb&t z6xs?8@EgHR{Vf>gkExCf@LP!Q)GIr6>DT#nIXiKbI#F~80p}hP5-J5!PgwBD<^q&F zLH=gd>@0z5Oj!hvo#Ql!T333LDJ6C@#$9=v56L`9paN zXy}B{tD4>lJB=eM=22*3e!QG2E&uxj^V5*5mznbLBNyZPyIY>Xc=Of`-Gv#my;n(Z zUgT1W85og106p&HW852*`0`N<>-+OQBYgBTY%a}Q@wbKnIQyJ&>?v6oZ)A;Fbu|xDX3y25n zX|J>5>1DCudaLNaw#)7rugT*?;kOhndW$NtI`)0Mp#~>8*AacS_!$)#a0>C%XeT&8 z+|$poRrMQ)B~m@{bG6o|q#$j2%D{o*o5xvULua}OXD(UHLnULNoB2jG1xD6ro#kgp3_HpoB8QGXM<03M2k2VPu7OY{btVwLN+&$ zFF2BU;|~f%T)McqJW>A3Odo92Z3(RHN2QnqdqJyAJILOwgc*ieeVcCk&w9L6*M$5<} z#PnHn)|WVsqB8j*U`hU3w(ZWe?d_LK)W53>DwJlPkL&~*`v*GxX0PpnhbL&Jw(?eD z3cAQSg*B%t 
zGcwZ{$F-V-{pMTpO`Q)uP^Y{)PnQ72TyHRpiodO=C8Gq)6c=$qN7bW`t#Ys8CK@?QUPv@LUYV8SM#lY8kF zK75w5*G^C)A2%t}2WR9!Lc@90bK#^BTS8rm%%_#S*B5@>CP`<%POxy*x+W~iK4h1uxFTnlXz4`g;+B)A9&BC&p=~^ zYHCZ=EwUM#f|}Jhd*4$zD?xZgWzfg<0`P60@^f{K;Buexy?SU5cC-9Zh5Ff<{AXu# z2Ya5!7T+sZKVxqF5~HIjt*rv&=i}q$`;Ny#dvIB|^zto}eT$BU=5kr_Fm{;WCH37ajo>sf7umUIbI8w#LHvls9MDxr(fF!yGR7`%Cj_;JVp5M61kbIU z>SlQFejKW2z;2MVi?X?&^C-wvjf6Kzp$X7*^W7$h6xu%XBRIzGO&zhq{^ILhDHvl&D^Mc1 zPL-zqa$g`O@wI`2Tv+wM?yO-JUyn2qK~K!}LC=(wfzbzqpf88;3~`zuy1_Unj)Fs6 zC4q7PO^}(WAq^s>*fDnq>UpwsZY!Y?XAoqDTU!NM5-g|K29>MDz(qwK=h1KzNzfO^ zt^m~%SrG8C!_R8EvCkuSeOL%xdSFqE@=8O#x;g4?n!tU{@*p%8&g3x-S3MnlUFV4l z(K+5YL`TSm7~yjsuDj%q6E-^S_wj?X2z$zdJ;td?7++%|gWb)@tu4;TjRF8ccb#`S z-Kxk$H@&K!J^;=<$DJ)An~JuVmPSRnO96_93;3_t!$S}D?mY|sPJc1Ilivb{jk-W(9_T~hN< zTp`lzjbJm=P(A}3Eh0{+-B8}LPtnlE){4^2E`Zf`3X~snre3odX9v_)K7a+}V6%Ex zq27Y`#RVgcv50O$1YG+G`oS(6A!~cx^AVD4P!Q^Wr`o6^v+2`BT2*UWF&Vu#`Mjp>jiMM9Z74wRyi>*SEKsC&Yy$ zS6`M7zLIo(b>N5ZCL3v+L|j=WMGHc#a}| zor^K;%e8Ma`0JO3S*gU@qsxh5arx){*Av9i$-&Xm$pvERtmCFJqykjtR#MYaEYRIo z<5n43(Z;`i*re=CdP9}$23}D>-=t={n5bNPU)w0Mkh?lxR6uHEN^>NDtC_vEs!gC< zSWHU)HdiQDc|NlMkgpztotSo!a<$mj!p6p|4{%xW{|%k=+OHLh%b+if5N${L>im5L z`T_fU1>?&3%CZC5E|1?~f6+63F#Nn0U5O1}?W*CH%ge9Q!>>cV5+A)Vy!gSB z^^ZJ%l`_6D{=Jy-jWauN?O!tfUuomVK>z*M_m373m!tTXVf)(qSKasDXa5oN_ZGf> zi~nfg=QrspZGJUEw2KSLOL6V{yZ9}MetZ6(!B=VY2iW#!@b{$pYmTe5`hz3mmmL3_ zWWVwKeVYBoxD&YWulW8 true, "udtf_stack2" -> true) { + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack1 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val cnt = + sql("SELECT udtf_stack1(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')").count() + assert(cnt === 2) + + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack2 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack2' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql("SELECT udtf_stack2(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')") + } + assert( + e.getMessage.contains("public StructObjectInspector initialize(ObjectInspector[] args)")) + } + } + test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") { val table = "test21721" withTable(table) { From 427359f077ad469d78c97972d021535f30a1e418 Mon Sep 17 00:00:00 2001 From: Ruben Berenguel Montoro Date: Tue, 24 Oct 2017 23:02:11 -0700 Subject: [PATCH 108/712] [SPARK-13947][SQL] The error message from using an invalid column reference is not clear ## What changes were proposed in this pull request? Rewritten error message for clarity. Added extra information in case of attribute name collision, hinting the user to double-check referencing two different tables ## How was this patch tested? No functional changes, only final message has changed. It has been tested manually against the situation proposed in the JIRA ticket. Automated tests in repository pass. This PR is original work from me and I license this work to the Spark project Author: Ruben Berenguel Montoro Author: Ruben Berenguel Montoro Author: Ruben Berenguel Closes #17100 from rberenguel/SPARK-13947-error-message. 
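To make the targeted confusion concrete, here is a hedged, spark-shell style sketch (the data and column names are invented for illustration) of how a query can hit this message when a column from a different DataFrame with the same schema is used by mistake:

```scala
import org.apache.spark.sql.functions.sum
import spark.implicits._

// Two DataFrames whose columns share names but carry different internal expression IDs.
val sales2016 = Seq((1, 10L)).toDF("store", "amount")
val sales2017 = Seq((1, 20L)).toDF("store", "amount")

// Aggregating sales2017's "amount" against sales2016's plan references an attribute
// that is missing from the child, so analysis typically fails with
// "Resolved attribute(s) amount#N missing from store#M,amount#K ...". With this change
// the message also adds: "Attribute(s) with the same name appear in the operation:
// amount. Please check if the right attribute(s) are used."
sales2016.groupBy(sales2016("store")).agg(sum(sales2017("amount")))
```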
--- .../sql/catalyst/analysis/CheckAnalysis.scala | 19 ++++++++++++--- .../analysis/AnalysisErrorSuite.scala | 23 +++++++++++++------ .../invalid-correlation.sql.out | 2 +- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d9906bb6e6ede..b5e8bdd79869e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -272,10 +272,23 @@ trait CheckAnalysis extends PredicateHelper { case o if o.children.nonEmpty && o.missingInput.nonEmpty => val missingAttributes = o.missingInput.mkString(",") val input = o.inputSet.mkString(",") + val msgForMissingAttributes = s"Resolved attribute(s) $missingAttributes missing " + + s"from $input in operator ${operator.simpleString}." - failAnalysis( - s"resolved attribute(s) $missingAttributes missing from $input " + - s"in operator ${operator.simpleString}") + val resolver = plan.conf.resolver + val attrsWithSameName = o.missingInput.filter { missing => + o.inputSet.exists(input => resolver(missing.name, input.name)) + } + + val msg = if (attrsWithSameName.nonEmpty) { + val sameNames = attrsWithSameName.map(_.name).mkString(",") + s"$msgForMissingAttributes Attribute(s) with the same name appear in the " + + s"operation: $sameNames. Please check if the right attribute(s) are used." + } else { + msgForMissingAttributes + } + + failAnalysis(msg) case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 884e113537c93..5d2f8e735e3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -408,16 +408,25 @@ class AnalysisErrorSuite extends AnalysisTest { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept // LongType, DoubleType, and DecimalType. We use LongType as the type of a. - val plan = - Aggregate( - Nil, - Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, - LocalRelation( - AttributeReference("a", LongType)(exprId = ExprId(2)))) + val attrA = AttributeReference("a", LongType)(exprId = ExprId(1)) + val otherA = AttributeReference("a", LongType)(exprId = ExprId(2)) + val attrC = AttributeReference("c", LongType)(exprId = ExprId(3)) + val aliases = Alias(sum(attrA), "b")() :: Alias(sum(attrC), "d")() :: Nil + val plan = Aggregate( + Nil, + aliases, + LocalRelation(otherA)) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) + val resolved = s"${attrA.toString},${attrC.toString}" + + val errorMsg = s"Resolved attribute(s) $resolved missing from ${otherA.toString} " + + s"in operator !Aggregate [${aliases.mkString(", ")}]. " + + s"Attribute(s) with the same name appear in the operation: a. " + + "Please check if the right attribute(s) are used." 
+ + assertAnalysisError(plan, errorMsg :: Nil) } test("error test for self-join") { diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index e4b1a2dbc675c..2586f26f71c35 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -63,7 +63,7 @@ WHERE t1a IN (SELECT min(t2a) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); +Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]).; -- !query 5 From 6c6950839da991bd41accdb8fb03fbc3b588c1e4 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 25 Oct 2017 12:51:20 +0100 Subject: [PATCH 109/712] [SPARK-22322][CORE] Update FutureAction for compatibility with Scala 2.12 Future ## What changes were proposed in this pull request? Scala 2.12's `Future` defines two new methods to implement, `transform` and `transformWith`. These can be implemented naturally in Spark's `FutureAction` extension and subclasses, but, only in terms of the new methods that don't exist in Scala 2.11. To support both at the same time, reflection is used to implement these. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #19561 from srowen/SPARK-22322. --- .../scala/org/apache/spark/FutureAction.scala | 59 ++++++++++++++++++- pom.xml | 2 +- .../FlatMapGroupsWithStateSuite.scala | 3 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- .../util/FileBasedWriteAheadLog.scala | 2 +- 5 files changed, 62 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1034fdcae8e8c..036c9a60630ea 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -89,7 +89,11 @@ trait FutureAction[T] extends Future[T] { */ override def value: Option[Try[T]] - // These two methods must be implemented in Scala 2.12, but won't be used by Spark + // These two methods must be implemented in Scala 2.12. They're implemented as a no-op here + // and then filled in with a real implementation in the two subclasses below. The no-op exists + // here so that those implementations can declare "override", necessary in 2.12, while working + // in 2.11, where the method doesn't exist in the superclass. + // After 2.11 support goes away, remove these two: def transform[S](f: (Try[T]) => Try[S])(implicit executor: ExecutionContext): Future[S] = throw new UnsupportedOperationException() @@ -113,6 +117,42 @@ trait FutureAction[T] extends Future[T] { } +/** + * Scala 2.12 defines the two new transform/transformWith methods mentioned above. Impementing + * these for 2.12 in the Spark class here requires delegating to these same methods in an + * underlying Future object. But that only exists in 2.12. But these methods are only called + * in 2.12. So define helper shims to access these methods on a Future by reflection. 
+ */ +private[spark] object FutureAction { + + private val transformTryMethod = + try { + classOf[Future[_]].getMethod("transform", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private val transformWithTryMethod = + try { + classOf[Future[_]].getMethod("transformWith", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private[spark] def transform[T, S]( + future: Future[T], + f: (Try[T]) => Try[S], + executor: ExecutionContext): Future[S] = + transformTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + + private[spark] def transformWith[T, S]( + future: Future[T], + f: (Try[T]) => Future[S], + executor: ExecutionContext): Future[S] = + transformWithTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + +} + /** * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include @@ -153,6 +193,18 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) + + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) } @@ -246,6 +298,11 @@ class ComplexFutureAction[T](run : JobSubmitter => Future[T]) def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform(p.future, f, e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith(p.future, f, e) } diff --git a/pom.xml b/pom.xml index b9c972855204a..2d59f06811a82 100644 --- a/pom.xml +++ b/pom.xml @@ -2692,7 +2692,7 @@ scala-2.12 - 2.12.3 + 2.12.4 2.12 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index af08186aadbb0..b906393a379ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} -import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -1201,7 +1200,7 @@ object FlatMapGroupsWithStateSuite { } catch { case u: UnsupportedOperationException => return - case _ => + case _: Throwable => throw new TestFailedException("Unexpected exception when trying to get watermark", 20) } throw new TestFailedException("Could get watermark 
when not expected", 20) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index c53889bb8566c..cc693909270f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -744,7 +744,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(returnedValue === expectedReturnValue, "Returned value does not match expected") } } - AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) + AwaitTerminationTester.test(expectedBehavior, () => awaitTermFunc()) true // If the control reached here, then everything worked as expected } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index d6e15cfdd2723..ab7c8558321c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -139,7 +139,7 @@ private[streaming] class FileBasedWriteAheadLog( def readFile(file: String): Iterator[ByteBuffer] = { logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) - CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) + CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, () => reader.close()) } if (!closeFileAfterWrite) { logFilesToRead.iterator.map(readFile).flatten.asJava From 1051ebec70bf05971ddc80819d112626b1f1614f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 25 Oct 2017 16:31:58 +0100 Subject: [PATCH 110/712] [SPARK-20783][SQL][FOLLOW-UP] Create ColumnVector to abstract existing compressed column ## What changes were proposed in this pull request? Removed one unused method. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #19508 from viirya/SPARK-20783-followup. --- .../apache/spark/sql/execution/columnar/ColumnAccessor.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 445933d98e9d4..85c36b7da9498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -63,9 +63,6 @@ private[columnar] abstract class BasicColumnAccessor[JvmType]( } protected def underlyingBuffer = buffer - - def getByteBuffer: ByteBuffer = - buffer.duplicate.order(ByteOrder.nativeOrder()) } private[columnar] class NullColumnAccessor(buffer: ByteBuffer) From 3d43a9f939764ec265de945921d1ecf2323ca230 Mon Sep 17 00:00:00 2001 From: liuxian Date: Wed, 25 Oct 2017 21:34:00 +0530 Subject: [PATCH 111/712] [SPARK-22349] In on-heap mode, when allocating memory from pool,we should fill memory with `MEMORY_DEBUG_FILL_CLEAN_VALUE` ## What changes were proposed in this pull request? In on-heap mode, when allocating memory from pool,we should fill memory with `MEMORY_DEBUG_FILL_CLEAN_VALUE` ## How was this patch tested? added unit tests Author: liuxian Closes #19572 from 10110346/MEMORY_DEBUG. 
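The invariant being restored can be sketched independently of Spark's allocator; everything below (class name, constants, the byte-array pool) is a simplified stand-in for illustration only:

```scala
import java.util.Arrays
import scala.collection.mutable

object DebugFillPoolSketch {
  // Placeholder fill bytes; the real constants are MemoryAllocator.MEMORY_DEBUG_FILL_*.
  val CleanValue: Byte = 0xa5.toByte
  val FreedValue: Byte = 0x5a.toByte

  private val pool = mutable.Map.empty[Int, mutable.Stack[Array[Byte]]]

  def allocate(size: Int): Array[Byte] = {
    val reused = pool.get(size).filter(_.nonEmpty).map(_.pop())
    val buf = reused.getOrElse(new Array[Byte](size))
    // Freshly created arrays are zeroed, but a reused buffer still carries the "freed"
    // pattern; filling with the clean value here is the behaviour SPARK-22349 adds.
    Arrays.fill(buf, CleanValue)
    buf
  }

  def free(buf: Array[Byte]): Unit = {
    Arrays.fill(buf, FreedValue)
    pool.getOrElseUpdate(buf.length, mutable.Stack.empty[Array[Byte]]).push(buf)
  }
}
```

A pooled allocator that skips this fill hands out buffers still stamped with the freed pattern, which is what the new PlatformUtilSuite assertions below check for.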
--- .../spark/unsafe/memory/HeapMemoryAllocator.java | 3 +++ .../org/apache/spark/unsafe/PlatformUtilSuite.java | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 355748238540b..cc9cc429643ad 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -56,6 +56,9 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final MemoryBlock memory = blockReference.get(); if (memory != null) { assert (memory.size() == size); + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { + memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + } return memory; } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4ae49d82efa29..4b141339ec816 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -66,10 +66,21 @@ public void overlappingCopyMemory() { public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1); - MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(onheap.getBaseObject(), onheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + MemoryAllocator.HEAP.free(onheap1); + Assert.assertEquals( + Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); + MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Assert.assertEquals( + Platform.getByte(onheap2.getBaseObject(), onheap2.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); From 6ea8a56ca26a7e02e6574f5f763bb91059119a80 Mon Sep 17 00:00:00 2001 From: Andrea zito Date: Wed, 25 Oct 2017 10:10:24 -0700 Subject: [PATCH 112/712] [SPARK-21991][LAUNCHER] Fix race condition in LauncherServer#acceptConnections ## What changes were proposed in this pull request? This patch changes the order in which _acceptConnections_ starts the client thread and schedules the client timeout action ensuring that the latter has been scheduled before the former get a chance to cancel it. ## How was this patch tested? Due to the non-deterministic nature of the patch I wasn't able to add a new test for this issue. Author: Andrea zito Closes #19217 from nivox/SPARK-21991. 
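The ordering constraint can be shown in isolation; the class and method names below are illustrative stand-ins rather than the LauncherServer API:

```scala
import java.util.{Timer, TimerTask}

class AcceptLoopSketch(timeoutMs: Long) {
  private val timeoutTimer = new Timer("connection-timeout", true)

  def accept(handleClient: () => Unit): Unit = {
    val timeout = new TimerTask {
      override def run(): Unit = {
        // On expiry the real server closes the unauthenticated connection; omitted here.
      }
    }

    // Schedule (or, for timeoutMs == 0 as used in tests, fire) the timeout first...
    if (timeoutMs > 0) timeoutTimer.schedule(timeout, timeoutMs) else timeout.run()

    // ...and only then start the client thread, which cancels the task once its
    // handshake completes. This guarantees the timeout has already been scheduled
    // when cancel() runs, which is the ordering the patch enforces.
    val clientThread = new Thread(new Runnable {
      override def run(): Unit = {
        handleClient()
        timeout.cancel()
      }
    })
    clientThread.setDaemon(true)
    clientThread.start()
  }
}
```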
--- .../apache/spark/launcher/LauncherServer.java | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 865d4926da6a9..454bc7a7f924d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -232,20 +232,20 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); - synchronized (timeout) { - clientThread.start(); - synchronized (clients) { - clients.add(clientConnection); - } - long timeoutMs = getConnectionTimeout(); - // 0 is used for testing to avoid issues with clock resolution / thread scheduling, - // and force an immediate timeout. - if (timeoutMs > 0) { - timeoutTimer.schedule(timeout, getConnectionTimeout()); - } else { - timeout.run(); - } + synchronized (clients) { + clients.add(clientConnection); + } + + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout. + if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, timeoutMs); + } else { + timeout.run(); } + + clientThread.start(); } } catch (IOException ioe) { if (running) { From d212ef14be7c2864cc529e48a02a47584e46f7a5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Oct 2017 13:53:01 -0700 Subject: [PATCH 113/712] [SPARK-22341][YARN] Impersonate correct user when preparing resources. The bug was introduced in SPARK-22290, which changed how the app's user is impersonated in the AM. The changed missed an initialization function that needs to be run as the app owner (who has the right credentials to read from HDFS). Author: Marcelo Vanzin Closes #19566 from vanzin/SPARK-22341. --- .../spark/deploy/yarn/ApplicationMaster.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index f6167235f89e4..244d912b9f3aa 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -97,9 +97,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private val client = ugi.doAs(new PrivilegedExceptionAction[YarnRMClient]() { - def run: YarnRMClient = new YarnRMClient() - }) + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -178,7 +176,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. 
- private val localResources = { + private val localResources = doAsUser { logInfo("Preparing Local resources") val resources = HashMap[String, LocalResource]() @@ -240,9 +238,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } final def run(): Int = { - ugi.doAs(new PrivilegedExceptionAction[Unit]() { - def run: Unit = runImpl() - }) + doAsUser { + runImpl() + } exitCode } @@ -790,6 +788,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } + private def doAsUser[T](fn: => T): T = { + ugi.doAs(new PrivilegedExceptionAction[T]() { + override def run: T = fn + }) + } + } object ApplicationMaster extends Logging { From b377ef133cdc38d49b460b2cc6ece0b5892804cc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 25 Oct 2017 22:15:44 +0100 Subject: [PATCH 114/712] [SPARK-22227][CORE] DiskBlockManager.getAllBlocks now tolerates temp files ## What changes were proposed in this pull request? Prior to this commit getAllBlocks implicitly assumed that the directories managed by the DiskBlockManager contain only the files corresponding to valid block IDs. In reality, this assumption was violated during shuffle, which produces temporary files in the same directory as the resulting blocks. As a result, calls to getAllBlocks during shuffle were unreliable. The fix could be made more efficient, but this is probably good enough. ## How was this patch tested? `DiskBlockManagerSuite` Author: Sergei Lebedev Closes #19458 from superbobry/block-id-option. --- .../scala/org/apache/spark/storage/BlockId.scala | 16 +++++++++++++--- .../apache/spark/storage/DiskBlockManager.scala | 11 ++++++++++- .../org/apache/spark/storage/BlockIdSuite.scala | 9 +++------ .../spark/storage/DiskBlockManagerSuite.scala | 7 +++++++ 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index a441baed2800e..7ac2c71c18eb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.UUID +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi /** @@ -95,6 +96,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { override def name: String = "test_" + id } +@DeveloperApi +class UnrecognizedBlockId(name: String) + extends SparkException(s"Failed to parse $name into a block ID") + @DeveloperApi object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r @@ -104,10 +109,11 @@ object BlockId { val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r + val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r val TEST = "test_(.*)".r - /** Converts a BlockId "name" String back into a BlockId. 
*/ - def apply(id: String): BlockId = id match { + def apply(name: String): BlockId = name match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => @@ -122,9 +128,13 @@ object BlockId { TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEMP_LOCAL(uuid) => + TempLocalBlockId(UUID.fromString(uuid)) + case TEMP_SHUFFLE(uuid) => + TempShuffleBlockId(UUID.fromString(uuid)) case TEST(value) => TestBlockId(value) case _ => - throw new IllegalStateException("Unrecognized BlockId: " + id) + throw new UnrecognizedBlockId(name) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3d43e3c367aac..a69bcc9259995 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -100,7 +100,16 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** List all the blocks currently stored on disk by the disk manager. */ def getAllBlocks(): Seq[BlockId] = { - getAllFiles().map(f => BlockId(f.getName)) + getAllFiles().flatMap { f => + try { + Some(BlockId(f.getName)) + } catch { + case _: UnrecognizedBlockId => + // Skip files which do not correspond to blocks, for example temporary + // files created by [[SortShuffleWriter]]. + None + } + } } /** Produces a unique block id and File suitable for storing local intermediate results. */ diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index f0c521b00b583..ff4755833a916 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -35,13 +35,8 @@ class BlockIdSuite extends SparkFunSuite { } test("test-bad-deserialization") { - try { - // Try to deserialize an invalid block id. 
+ intercept[UnrecognizedBlockId] { BlockId("myblock") - fail() - } catch { - case e: IllegalStateException => // OK - case _: Throwable => fail() } } @@ -139,6 +134,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 5) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("temp shuffle") { @@ -151,6 +147,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 1) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("test") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 7859b0bba2b48..0c4f3c48ef802 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import java.util.UUID import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} @@ -79,6 +80,12 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } + test("SPARK-22227: non-block files are skipped") { + val file = diskBlockManager.getFile("unmanaged_file") + writeToFile(file, 10) + assert(diskBlockManager.getAllBlocks().isEmpty) + } + def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) From 841f1d776f420424c20d99cf7110d06c73f9ca20 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 25 Oct 2017 14:31:36 -0700 Subject: [PATCH 115/712] [SPARK-22332][ML][TEST] Fix NaiveBayes unit test occasionly fail (cause by test dataset not deterministic) ## What changes were proposed in this pull request? Fix NaiveBayes unit test occasionly fail: Set seed for `BrzMultinomial.sample`, make `generateNaiveBayesInput` output deterministic dataset. (If we do not set seed, the generated dataset will be random, and the model will be possible to exceed the tolerance in the test, which trigger this failure) ## How was this patch tested? Manually run tests multiple times and check each time output models contains the same values. Author: WeichenXu Closes #19558 from WeichenXu123/fix_nb_test_seed. 
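The fix amounts to pinning Breeze's random source. A standalone sketch of the idea, with arbitrary probabilities and seed:

```scala
import breeze.linalg.DenseVector
import breeze.stats.distributions.{Multinomial, RandBasis}

// Breeze distributions draw from an implicit RandBasis; pinning its seed makes the
// data generator emit the same samples on every run, so the fitted model stays within
// the test tolerances instead of drifting with each randomly generated dataset.
implicit val seededBasis: RandBasis = RandBasis.withSeed(42)

val labelDist = Multinomial(DenseVector(0.2, 0.3, 0.5))
val labels = Seq.fill(10)(labelDist.draw())  // identical sequence across runs
```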
--- .../org/apache/spark/ml/classification/NaiveBayesSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 9730dd68a3b27..0d3adf993383f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import scala.util.Random import breeze.linalg.{DenseVector => BDV, Vector => BV} -import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} @@ -335,6 +335,7 @@ object NaiveBayesSuite { val _pi = pi.map(math.exp) val _theta = theta.map(row => row.map(math.exp)) + implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { From 5433be44caecaeef45ed1fdae10b223c698a9d14 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 25 Oct 2017 14:41:02 -0700 Subject: [PATCH 116/712] [SPARK-21991][LAUNCHER][FOLLOWUP] Fix java lint ## What changes were proposed in this pull request? Fix java lint ## How was this patch tested? Run `./dev/lint-java` Author: Andrew Ash Closes #19574 from ash211/aash/fix-java-lint. --- .../main/java/org/apache/spark/launcher/LauncherServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 454bc7a7f924d..4353e3f263c51 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -235,7 +235,7 @@ public void run() { synchronized (clients) { clients.add(clientConnection); } - + long timeoutMs = getConnectionTimeout(); // 0 is used for testing to avoid issues with clock resolution / thread scheduling, // and force an immediate timeout. @@ -244,7 +244,7 @@ public void run() { } else { timeout.run(); } - + clientThread.start(); } } catch (IOException ioe) { From 592cfeab9caeff955d115a1ca5014ede7d402907 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Thu, 26 Oct 2017 00:29:49 -0700 Subject: [PATCH 117/712] [SPARK-22308] Support alternative unit testing styles in external applications ## What changes were proposed in this pull request? Support unit tests of external code (i.e., applications that use spark) using scalatest that don't want to use FunSuite. SharedSparkContext already supports this, but SharedSQLContext does not. I've introduced SharedSparkSession as a parent to SharedSQLContext, written in a way that it does support all scalatest styles. ## How was this patch tested? There are three new unit test suites added that just test using FunSpec, FlatSpec, and WordSpec. Author: Nathan Kronenfeld Closes #19529 from nkronenfeld/alternative-style-tests-2. 
--- .../org/apache/spark/SharedSparkContext.scala | 17 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../spark/sql/test/GenericFlatSpecSuite.scala | 45 +++++ .../spark/sql/test/GenericFunSpecSuite.scala | 47 +++++ .../spark/sql/test/GenericWordSpecSuite.scala | 51 ++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 173 ++++++++++-------- .../spark/sql/test/SharedSQLContext.scala | 84 +-------- .../spark/sql/test/SharedSparkSession.scala | 119 ++++++++++++ 8 files changed, 381 insertions(+), 165 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 6aedcb1271ff6..1aa1c421d792e 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) + /** + * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in + * test using styles other than FunSuite, there is often code that relies on the session between + * test group constructs and the actual tests, which may need this session. It is purely a + * semantic difference, but semantically, it makes more sense to call 'initializeContext' between + * a 'describe' and an 'it' call than it does to call 'beforeAll'. + */ + protected def initializeContext(): Unit = { + if (null == _sc) { + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + } + override def beforeAll() { super.beforeAll() - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + initializeContext() } override def afterAll() { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 10bdfafd6f933..82c5307d54360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import org.scalatest.Suite + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PlanTestBase + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. 
+ */ +trait PlanTestBase extends PredicateHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala new file mode 100644 index 0000000000000..6179585a0d39a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FlatSpec + +/** + * The purpose of this suite is to make sure that generic FlatSpec-based scala + * tests work with a shared spark session + */ +class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" should "have the specified number of elements" in { + assert(8 === ds.count) + } + it should "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + it should "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it should "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala new file mode 100644 index 0000000000000..15139ee8b3047 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.test + +import org.scalatest.FunSpec + +/** + * The purpose of this suite is to make sure that generic FunSpec-based scala + * tests work with a shared spark session + */ +class GenericFunSpecSuite extends FunSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + describe("Simple Dataset") { + it("should have the specified number of elements") { + assert(8 === ds.count) + } + it("should have the specified number of unique elements") { + assert(8 === ds.distinct.count) + } + it("should have the specified number of elements in each column") { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it("should have the correct number of distinct elements in each column") { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala new file mode 100644 index 0000000000000..b6548bf95fec8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.test + +import org.scalatest.WordSpec + +/** + * The purpose of this suite is to make sure that generic WordSpec-based scala + * tests work with a shared spark session + */ +class GenericWordSpecSuite extends WordSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" when { + "looked at as complete rows" should { + "have the specified number of elements" in { + assert(8 === ds.count) + } + "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + } + "refined to specific columns" should { + "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index a14a1441a4313..b4248b74f50ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.Utils /** - * Helper trait that should be extended by all SQL test suites. + * Helper trait that should be extended by all SQL test suites within the Spark + * code base. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. @@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils - extends SparkFunSuite with Eventually +private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Materialize the test data immediately after the `SQLContext` is set up. + * This is necessary if the data is accessed by name but not through direct reference. 
+ */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to termniate, and + // we rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } +} + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually with BeforeAndAfterAll with SQLTestData - with PlanTest { self => + with PlanTestBase { self: Suite => protected def sparkContext = spark.sparkContext - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils protected override def _sqlContext: SQLContext = self.spark.sqlContext } - /** - * Materialize the test data immediately after the `SQLContext` is set up. - * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils Dataset.ofRows(spark, plan) } - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. 
- */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index cd8d0708d8a32..4d578e21f5494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,86 +17,4 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - -import org.scalatest.BeforeAndAfterEach -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. - */ - protected override def beforeAll(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. 
- */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala new file mode 100644 index 0000000000000..e0568a3c5c99f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSparkSession + extends SQLTestUtilsBase + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. 
It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} From 3073344a2551fb198d63f2114a519ab97904cb55 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Oct 2017 15:50:27 +0800 Subject: [PATCH 118/712] [SPARK-21840][CORE] Add trait that allows conf to be directly set in application. Currently SparkSubmit uses system properties to propagate configuration to applications. This makes it hard to implement features such as SPARK-11035, which would allow multiple applications to be started in the same JVM. The current code would cause the config data from multiple apps to get mixed up. This change introduces a new trait, currently internal to Spark, that allows the app configuration to be passed directly to the application, without having to use system properties. The current "call main() method" behavior is maintained as an implementation of this new trait. This will be useful to allow multiple cluster mode apps to be submitted from the same JVM. As part of this, SparkSubmit was modified to collect all configuration directly into a SparkConf instance. Most of the changes are to tests so they use SparkConf instead of an opaque map. Tested with existing and added unit tests. Author: Marcelo Vanzin Closes #19519 from vanzin/SPARK-21840. --- .../spark/deploy/SparkApplication.scala | 55 +++++ .../org/apache/spark/deploy/SparkSubmit.scala | 160 +++++++------ .../spark/deploy/SparkSubmitSuite.scala | 213 ++++++++++-------- .../rest/StandaloneRestSubmitSuite.scala | 4 +- 4 files changed, 257 insertions(+), 175 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala new file mode 100644 index 0000000000000..118b4605675b0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.lang.reflect.Modifier + +import org.apache.spark.SparkConf + +/** + * Entry point for a Spark application. Implementations must provide a no-argument constructor. + */ +private[spark] trait SparkApplication { + + def start(args: Array[String], conf: SparkConf): Unit + +} + +/** + * Implementation of SparkApplication that wraps a standard Java class with a "main" method. + * + * Configuration is propagated to the application via system properties, so running multiple + * of these in the same JVM may lead to undefined behavior due to configuration leaks. + */ +private[deploy] class JavaMainApplication(klass: Class[_]) extends SparkApplication { + + override def start(args: Array[String], conf: SparkConf): Unit = { + val mainMethod = klass.getMethod("main", new Array[String](0).getClass) + if (!Modifier.isStatic(mainMethod.getModifiers)) { + throw new IllegalStateException("The main method in the given main class must be static") + } + + val sysProps = conf.getAll.toMap + sysProps.foreach { case (k, v) => + sys.props(k) = v + } + + mainMethod.invoke(null, args) + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b7e6d0ea021a4..73b956ef3e470 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,7 +158,7 @@ object SparkSubmit extends CommandLineUtils with Logging { */ @tailrec private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = { - val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args) def doRunMain(): Unit = { if (args.proxyUser != null) { @@ -167,7 +167,7 @@ object SparkSubmit extends CommandLineUtils with Logging { try { proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } }) } catch { @@ -185,7 +185,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } } } else { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } } @@ -235,11 +235,11 @@ object SparkSubmit extends CommandLineUtils with Logging { private[deploy] def prepareSubmitEnvironment( args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], Map[String, String], String) = { + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() - val sysProps = new HashMap[String, String]() + val sparkConf = new SparkConf() var 
childMainClass = "" // Set the cluster manager @@ -337,7 +337,6 @@ object SparkSubmit extends CommandLineUtils with Logging { } } - val sparkConf = new SparkConf(false) args.sparkProperties.foreach { case (k, v) => sparkConf.set(k, v) } val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() @@ -351,8 +350,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // for later use; e.g. in spark sql, the isolated class loader used to talk // to HiveMetastore will use these settings. They will be set as Java system // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) + sparkConf.set(KEYTAB, args.keytab) + sparkConf.set(PRINCIPAL, args.principal) UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) } } @@ -364,23 +363,24 @@ object SparkSubmit extends CommandLineUtils with Logging { args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + // This security manager will not need an auth secret, but set a dummy value in case + // spark.authenticate is enabled, otherwise an exception is thrown. + lazy val downloadConf = sparkConf.clone().set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + lazy val secMgr = new SecurityManager(downloadConf) + // In client mode, download remote files. var localPrimaryResource: String = null var localJars: String = null var localPyFiles: String = null if (deployMode == CLIENT) { - // This security manager will not need an auth secret, but set a dummy value in case - // spark.authenticate is enabled, otherwise an exception is thrown. - sparkConf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") - val secMgr = new SecurityManager(sparkConf) localPrimaryResource = Option(args.primaryResource).map { - downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localJars = Option(args.jars).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localPyFiles = Option(args.pyFiles).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull } @@ -409,7 +409,7 @@ object SparkSubmit extends CommandLineUtils with Logging { if (file.exists()) { file.toURI.toString } else { - downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(resource, targetDir, downloadConf, hadoopConf, secMgr) } case _ => uri.toString } @@ -449,7 +449,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.files = mergeFileLists(args.files, args.pyFiles) } if (localPyFiles != null) { - sysProps("spark.submit.pyFiles") = localPyFiles + sparkConf.set("spark.submit.pyFiles", localPyFiles) } } @@ -515,69 +515,69 @@ object SparkSubmit extends CommandLineUtils with Logging { } // Special flag to avoid deprecation warnings at the client - sysProps("SPARK_SUBMIT") = "true" + sys.props("SPARK_SUBMIT") = "true" // A list of rules to map each argument to system properties or command-line options in // each deploy mode; we iterate through these below val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.master, ALL_CLUSTER_MGRS, 
ALL_DEPLOY_MODES, confKey = "spark.master"), OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.submit.deployMode"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), - OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), + confKey = "spark.submit.deployMode"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraClassPath"), + confKey = "spark.driver.extraClassPath"), OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraJavaOptions"), + confKey = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraLibraryPath"), + confKey = "spark.driver.extraLibraryPath"), // Propagate attributes for dependency resolution at the driver side - OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"), + OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.packages"), OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.jars.repositories"), - OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.ivy"), OptionAssigner(args.packagesExclusions, STANDALONE | MESOS, - CLUSTER, sysProp = "spark.jars.excludes"), + CLUSTER, confKey = "spark.jars.excludes"), // Yarn only - OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), + OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.instances"), - OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.pyFiles"), - OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"), - OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"), - OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"), - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.keytab"), + confKey = "spark.executor.instances"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles"), + OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), + OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), + OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.cores"), + confKey = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE 
| MESOS | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.memory"), + confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.cores.max"), + confKey = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files"), - OptionAssigner(args.jars, LOCAL, CLIENT, sysProp = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + confKey = "spark.files"), + OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.cores"), + confKey = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.driver.supervise"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, confKey = "spark.jars.ivy"), // An internal option used only for spark-shell to add user jars to repl's classloader, // previously it uses "spark.jars" or "spark.yarn.dist.jars" which now may be pointed to // remote jars, so adding a new option to only specify local jars for spark-shell internally. - OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.repl.local.jars") + OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.repl.local.jars") ) // In client mode, launch the application main class directly @@ -610,24 +610,24 @@ object SparkSubmit extends CommandLineUtils with Logging { (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } - if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } + if (opt.confKey != null) { sparkConf.set(opt.confKey, opt.value) } } } // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { - sysProps(UI_SHOW_CONSOLE_PROGRESS.key) = "true" + sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true) } // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python and R files, the primary resource is already distributed as a regular file if (!isYarnCluster && !args.isPython && !args.isR) { - var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) + var jars = sparkConf.getOption("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) } - sysProps.put("spark.jars", jars.mkString(",")) + sparkConf.set("spark.jars", jars.mkString(",")) } // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). @@ -653,12 +653,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Let YARN know it's a pyspark app, so it distributes needed libraries. 
if (clusterManager == YARN) { if (args.isPython) { - sysProps.put("spark.yarn.isPython", "true") + sparkConf.set("spark.yarn.isPython", "true") } } if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { - setRMPrincipal(sysProps) + setRMPrincipal(sparkConf) } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class @@ -689,7 +689,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // Second argument is main class childArgs += (args.primaryResource, "") if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles + sparkConf.set("spark.submit.pyFiles", args.pyFiles) } } else if (args.isR) { // Second argument is main class @@ -704,12 +704,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { - sysProps.getOrElseUpdate(k, v) + sparkConf.setIfMissing(k, v) } // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sysProps -= "spark.driver.host" + sparkConf.remove("spark.driver.host") } // Resolve paths in certain spark properties @@ -721,15 +721,15 @@ object SparkSubmit extends CommandLineUtils with Logging { "spark.yarn.dist.jars") pathConfigs.foreach { config => // Replace old URIs with resolved URIs, if they exist - sysProps.get(config).foreach { oldValue => - sysProps(config) = Utils.resolveURIs(oldValue) + sparkConf.getOption(config).foreach { oldValue => + sparkConf.set(config, Utils.resolveURIs(oldValue)) } } // Resolve and format python file paths properly before adding them to the PYTHONPATH. // The resolving part is redundant in the case of --py-files, but necessary if the user // explicitly sets `spark.submit.pyFiles` in his/her default properties file. - sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + sparkConf.getOption("spark.submit.pyFiles").foreach { pyFiles => val resolvedPyFiles = Utils.resolveURIs(pyFiles) val formattedPyFiles = if (!isYarnCluster && !isMesosCluster) { PythonRunner.formatPaths(resolvedPyFiles).mkString(",") @@ -739,22 +739,22 @@ object SparkSubmit extends CommandLineUtils with Logging { // locally. resolvedPyFiles } - sysProps("spark.submit.pyFiles") = formattedPyFiles + sparkConf.set("spark.submit.pyFiles", formattedPyFiles) } - (childArgs, childClasspath, sysProps, childMainClass) + (childArgs, childClasspath, sparkConf, childMainClass) } // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos mode, we // must trick it into thinking we're YARN. 
- private def setRMPrincipal(sysProps: HashMap[String, String]): Unit = { + private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" // scalastyle:off println printStream.println(s"Setting ${key} to ${shortUserName}") // scalastyle:off println - sysProps.put(key, shortUserName) + sparkConf.set(key, shortUserName) } /** @@ -766,7 +766,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def runMain( childArgs: Seq[String], childClasspath: Seq[String], - sysProps: Map[String, String], + sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { // scalastyle:off println @@ -774,14 +774,14 @@ object SparkSubmit extends CommandLineUtils with Logging { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") + printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } // scalastyle:on println val loader = - if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { new ChildFirstURLClassLoader(new Array[URL](0), Thread.currentThread.getContextClassLoader) } else { @@ -794,10 +794,6 @@ object SparkSubmit extends CommandLineUtils with Logging { addJarToClasspath(jar, loader) } - for ((key, value) <- sysProps) { - System.setProperty(key, value) - } - var mainClass: Class[_] = null try { @@ -823,14 +819,14 @@ object SparkSubmit extends CommandLineUtils with Logging { System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } - // SPARK-4170 - if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") - } - - val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) - if (!Modifier.isStatic(mainMethod.getModifiers)) { - throw new IllegalStateException("The main method in the given main class must be static") + val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { + mainClass.newInstance().asInstanceOf[SparkApplication] + } else { + // SPARK-4170 + if (classOf[scala.App].isAssignableFrom(mainClass)) { + printWarning("Subclasses of scala.App may not work correctly. 
Use a main() method instead.") + } + new JavaMainApplication(mainClass) } @tailrec @@ -844,7 +840,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } try { - mainMethod.invoke(null, childArgs.toArray) + app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => findCause(t) match { @@ -1271,4 +1267,4 @@ private case class OptionAssigner( clusterManager: Int, deployMode: Int, clOption: String = null, - sysProp: String = null) + confKey: String = null) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b52da4c0c8bc3..cfbf56fb8c369 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -176,10 +176,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") - sysProps("spark.submit.deployMode") should be ("client") + conf.get("spark.submit.deployMode") should be ("client") // Both cmd line and configuration are specified, cmdline option takes the priority val clArgs1 = Seq( @@ -190,10 +190,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") - sysProps1("spark.submit.deployMode") should be ("cluster") + conf1.get("spark.submit.deployMode") should be ("cluster") // Neither cmdline nor configuration are specified, client mode is the default choice val clArgs2 = Seq( @@ -204,9 +204,9 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") - sysProps2("spark.submit.deployMode") should be ("client") + conf2.get("spark.submit.deployMode") should be ("client") } test("handles YARN cluster mode") { @@ -227,7 +227,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -240,16 +240,16 @@ class SparkSubmitSuite classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.driver.memory") should be ("4g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.app.name") should be ("beauty") - sysProps("spark.ui.enabled") should be ("false") - sysProps("SPARK_SUBMIT") should be ("true") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.driver.memory") should be 
("4g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.app.name") should be ("beauty") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles YARN client mode") { @@ -270,7 +270,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -278,17 +278,17 @@ class SparkSubmitSuite classpath(1) should endWith ("one.jar") classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.app.name") should be ("trill") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.executor.instances") should be ("6") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.yarn.dist.jars") should include + conf.get("spark.app.name") should be ("trill") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.executor.instances") should be ("6") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") - sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles standalone cluster mode") { @@ -316,7 +316,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -327,17 +327,18 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 9 - sysProps.keys should contain ("SPARK_SUBMIT") - sysProps.keys should contain ("spark.master") - sysProps.keys should contain ("spark.app.name") - sysProps.keys should contain ("spark.jars") - sysProps.keys should contain ("spark.driver.memory") - sysProps.keys should contain ("spark.driver.cores") - sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.ui.enabled") - sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") 
should be ("true") + + val confMap = conf.getAll.toMap + confMap.keys should contain ("spark.master") + confMap.keys should contain ("spark.app.name") + confMap.keys should contain ("spark.jars") + confMap.keys should contain ("spark.driver.memory") + confMap.keys should contain ("spark.driver.cores") + confMap.keys should contain ("spark.driver.supervise") + confMap.keys should contain ("spark.ui.enabled") + confMap.keys should contain ("spark.submit.deployMode") + conf.get("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -352,14 +353,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -374,14 +375,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { @@ -394,23 +395,26 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.master") should be ("yarn") - sysProps("spark.submit.deployMode") should be ("cluster") + val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.master") should be ("yarn") + conf.get("spark.submit.deployMode") should be ("cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") } test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { + // Unset from system properties since this config is defined in the root pom's test config. 
+ sys.props -= UI_SHOW_CONSOLE_PROGRESS.key + val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) - sysProps1(UI_SHOW_CONSOLE_PROGRESS.key) should be ("true") + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) - sysProps2.keys should not contain UI_SHOW_CONSOLE_PROGRESS.key + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } test("launch simple application with spark-submit") { @@ -585,11 +589,11 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) - sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be (Utils.resolveURIs(files)) + conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be (Utils.resolveURIs(files)) // Test files and archives (Yarn) val clArgs2 = Seq( @@ -600,11 +604,11 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) - sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) // Test python files val clArgs3 = Seq( @@ -615,12 +619,12 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) - sysProps3("spark.submit.pyFiles") should be ( + conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) - sysProps3(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") - sysProps3(PYSPARK_PYTHON.key) should be ("python3.5") + conf3.get(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") + conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } test("resolves config paths correctly") { @@ -644,9 +648,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be(Utils.resolveURIs(files)) + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be(Utils.resolveURIs(files)) // Test files and archives (Yarn) val f2 = 
File.createTempFile("test-submit-files-archives", "", tmpDir) @@ -661,9 +665,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 - sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) // Test python files val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) @@ -676,8 +680,8 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 - sysProps3("spark.submit.pyFiles") should be( + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files @@ -693,11 +697,9 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val sysProps4 = SparkSubmit.prepareSubmitEnvironment(appArgs4)._3 + val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode - sysProps4("spark.submit.pyFiles") should be( - Utils.resolveURIs(remotePyFiles) - ) + conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } test("user classpath first in driver") { @@ -771,14 +773,14 @@ class SparkSubmitSuite jar2.toString) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.yarn.dist.jars").split(",").toSet should be + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) - sysProps("spark.yarn.dist.files").split(",").toSet should be + conf.get("spark.yarn.dist.files").split(",").toSet should be (Set(file1.toURI.toString, file2.toURI.toString)) - sysProps("spark.yarn.dist.pyFiles").split(",").toSet should be + conf.get("spark.yarn.dist.pyFiles").split(",").toSet should be (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) - sysProps("spark.yarn.dist.archives").split(",").toSet should be + conf.get("spark.yarn.dist.archives").split(",").toSet should be (Set(archive1.toURI.toString, archive2.toURI.toString)) } @@ -897,18 +899,18 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. - sysProps("spark.yarn.dist.jars") should be (tmpJarPath) - sysProps("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") - sysProps("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.yarn.dist.jars") should be (tmpJarPath) + conf.get("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") + conf.get("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") // Local repl jars should be a local path. 
- sysProps("spark.repl.local.jars") should (startWith("file:")) + conf.get("spark.repl.local.jars") should (startWith("file:")) // local py files should not be a URI format. - sysProps("spark.submit.pyFiles") should (startWith("/")) + conf.get("spark.submit.pyFiles") should (startWith("/")) } test("download remote resource if it is not supported by yarn service") { @@ -955,9 +957,9 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) - val jars = sysProps("spark.yarn.dist.jars").split(",").toSet + val jars = conf.get("spark.yarn.dist.jars").split(",").toSet // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) @@ -996,6 +998,21 @@ class SparkSubmitSuite conf.set("fs.s3a.impl", classOf[TestFileSystem].getCanonicalName) conf.set("fs.s3a.impl.disable.cache", "true") } + + test("start SparkApplication without modifying system properties") { + val args = Array( + "--class", classOf[TestSparkApplication].getName(), + "--master", "local", + "--conf", "spark.test.hello=world", + "spark-internal", + "hello") + + val exception = intercept[SparkException] { + SparkSubmit.main(args) + } + + assert(exception.getMessage() === "hello") + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1115,3 +1132,17 @@ class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { override def open(path: Path): FSDataInputStream = super.open(local(path)) } + +class TestSparkApplication extends SparkApplication with Matchers { + + override def start(args: Array[String], conf: SparkConf): Unit = { + assert(args.size === 1) + assert(args(0) === "hello") + assert(conf.get("spark.test.hello") === "world") + assert(sys.props.get("spark.test.hello") === None) + + // This is how the test verifies the application was actually run. + throw new SparkException(args(0)) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 70887dc5dd97a..490baf040491f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,9 +445,9 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( - mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) + mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } /** Return the response as a submit response, or fail with error otherwise. */ From a83d8d5adcb4e0061e43105767242ba9770dda96 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Oct 2017 20:54:36 +0900 Subject: [PATCH 119/712] [SPARK-17902][R] Revive stringsAsFactors option for collect() in SparkR ## What changes were proposed in this pull request? This PR proposes to revive `stringsAsFactors` option in collect API, which was mistakenly removed in https://github.com/apache/spark/commit/71a138cd0e0a14e8426f97877e3b52a562bbd02c. 
Simply, it casts `character` to `factor` if it meets the condition `stringsAsFactors && is.character(vec)` in primitive type conversion. ## How was this patch tested? Unit test in `R/pkg/tests/fulltests/test_sparkSQL.R`. Author: hyukjinkwon Closes #19551 from HyukjinKwon/SPARK-17902. --- R/pkg/R/DataFrame.R | 3 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 176bb3b8a8d0c..aaa3349d57506 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1191,6 +1191,9 @@ setMethod("collect", vec <- do.call(c, col) stopifnot(class(vec) != "list") class(vec) <- PRIMITIVE_TYPES[[colType]] + if (is.character(vec) && stringsAsFactors) { + vec <- as.factor(vec) + } df[[colIndex]] <- vec } else { df[[colIndex]] <- col diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4382ef2ed4525..0c8118a7c73f3 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -499,6 +499,12 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17902: collect() with stringsAsFactors enabled", { + df <- suppressWarnings(collect(createDataFrame(iris), stringsAsFactors = TRUE)) + expect_equal(class(iris$Species), class(df$Species)) + expect_equal(iris$Species, df$Species) +}) + test_that("SPARK-17811: can create DataFrame containing NA as date and time", { df <- data.frame( id = 1:2, From 0e9a750a8d389b3a17834584d31c204c77c6970d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Oct 2017 11:05:16 -0500 Subject: [PATCH 120/712] [SPARK-20643][CORE] Add listener implementation to collect app state. The initial listener code is based on the existing JobProgressListener (and others), and tries to mimic their behavior as much as possible. The change also includes some minor code movement so that some types and methods from the initial history server code can be reused. The code introduces a few mutable versions of public API types, used internally, to make it easier to update information without ugly copy methods, and also to make certain updates cheaper. Note the code here is not 100% correct. This is meant as a building ground for the UI integration in the next milestones. As different parts of the UI are ported, fixes will be made to the different parts of this code to account for the needed behavior. I also added annotations to API types so that Jackson is able to correctly deserialize options, sequences and maps that store primitive types. Author: Marcelo Vanzin Closes #19383 from vanzin/SPARK-20643.
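For readers unfamiliar with the listener pattern this patch builds on, the following is a minimal, hypothetical sketch of the same idea using only the public `SparkListener` / `SparkContext.addSparkListener` developer API. It is not the `AppStatusListener` added here, and the `JobStateListener` and `LiveJob` names are invented for illustration: cheap mutable state is kept per live entity, updated on every event, and a snapshot is emitted only when the entity reaches a terminal state (the real listener persists snapshots into a `KVStore`; the sketch just prints).

```scala
import scala.collection.mutable

import org.apache.spark.{SparkConf, SparkContext, Success}
import org.apache.spark.scheduler._

// Hypothetical sketch of the "live entity" pattern: mutable per-job state is
// updated cheaply on every event, and a snapshot is emitted only when the job
// reaches a terminal state.
class JobStateListener extends SparkListener {

  private case class LiveJob(jobId: Int, var activeTasks: Int = 0, var failedTasks: Int = 0)

  private val liveJobs = mutable.HashMap.empty[Int, LiveJob]

  override def onJobStart(event: SparkListenerJobStart): Unit = {
    liveJobs(event.jobId) = LiveJob(event.jobId)
  }

  override def onTaskStart(event: SparkListenerTaskStart): Unit = {
    // Simplified: a real listener maps stages to jobs instead of updating every job.
    liveJobs.values.foreach(_.activeTasks += 1)
  }

  override def onTaskEnd(event: SparkListenerTaskEnd): Unit = {
    liveJobs.values.foreach { job =>
      job.activeTasks -= 1
      if (event.reason != Success) job.failedTasks += 1
    }
  }

  override def onJobEnd(event: SparkListenerJobEnd): Unit = {
    liveJobs.remove(event.jobId).foreach { job =>
      // Terminal state: this is where an immutable snapshot would be written out.
      println(s"Job ${job.jobId} ended (${event.jobResult}), failed tasks: ${job.failedTasks}")
    }
  }
}

object JobStateListenerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("listener-demo"))
    sc.addSparkListener(new JobStateListener)  // DeveloperApi
    sc.parallelize(1 to 100).map(_ * 2).count()
    sc.stop()
  }
}
```

Deferring any write-out until the entity completes is what keeps per-event updates cheap, which is the same motivation given above for introducing mutable internal versions of the public API types.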
--- .../apache/spark/util/kvstore/KVTypeInfo.java | 2 + .../apache/spark/util/kvstore/LevelDB.java | 2 +- .../spark/status/api/v1/StageStatus.java | 3 +- .../deploy/history/FsHistoryProvider.scala | 37 +- .../apache/spark/deploy/history/config.scala | 6 - .../spark/status/AppStatusListener.scala | 531 ++++++++++++++ .../org/apache/spark/status/KVUtils.scala | 73 ++ .../org/apache/spark/status/LiveEntity.scala | 526 +++++++++++++ .../status/api/v1/AllStagesResource.scala | 4 +- .../org/apache/spark/status/api/v1/api.scala | 11 +- .../org/apache/spark/status/storeTypes.scala | 98 +++ .../history/FsHistoryProviderSuite.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 690 ++++++++++++++++++ project/MimaExcludes.scala | 2 + 14 files changed, 1942 insertions(+), 45 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusListener.scala create mode 100644 core/src/main/scala/org/apache/spark/status/KVUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/status/LiveEntity.scala create mode 100644 core/src/main/scala/org/apache/spark/status/storeTypes.scala create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java index a2b077e4531ee..870b484f99068 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -46,6 +46,7 @@ public KVTypeInfo(Class type) throws Exception { KVIndex idx = f.getAnnotation(KVIndex.class); if (idx != null) { checkIndex(idx, indices); + f.setAccessible(true); indices.put(idx.value(), idx); f.setAccessible(true); accessors.put(idx.value(), new FieldAccessor(f)); @@ -58,6 +59,7 @@ public KVTypeInfo(Class type) throws Exception { checkIndex(idx, indices); Preconditions.checkArgument(m.getParameterTypes().length == 0, "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + m.setAccessible(true); indices.put(idx.value(), idx); m.setAccessible(true); accessors.put(idx.value(), new MethodAccessor(m)); diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index ff48b155fab31..4f9e10ca20066 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -76,7 +76,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { this.types = new ConcurrentHashMap<>(); Options options = new Options(); - options.createIfMissing(!path.exists()); + options.createIfMissing(true); this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); byte[] versionData = db().get(STORE_VERSION_KEY); diff --git a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java index 9dbb565aab707..40b5f627369d5 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java @@ -23,7 +23,8 @@ public enum StageStatus { ACTIVE, COMPLETE, FAILED, - PENDING; + PENDING, + SKIPPED; public static StageStatus fromString(String str) { return EnumUtil.parseIgnoreCase(StageStatus.class, str); diff --git 
a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 3889dd097ee59..cf97597b484d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ +import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} @@ -129,29 +130,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => val dbPath = new File(path, "listing.ldb") - - def openDB(): LevelDB = new LevelDB(dbPath, new KVStoreScalaSerializer()) + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, logDir.toString()) try { - val db = openDB() - val meta = db.getMetadata(classOf[KVStoreMetadata]) - - if (meta == null) { - db.setMetadata(new KVStoreMetadata(CURRENT_LISTING_VERSION, logDir)) - db - } else if (meta.version != CURRENT_LISTING_VERSION || !logDir.equals(meta.logDir)) { - logInfo("Detected mismatched config in existing DB, deleting...") - db.close() - Utils.deleteRecursively(dbPath) - openDB() - } else { - db - } + open(new File(path, "listing.ldb"), metadata) } catch { - case _: UnsupportedStoreVersionException => + case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") Utils.deleteRecursively(dbPath) - openDB() + open(new File(path, "listing.ldb"), metadata) } }.getOrElse(new InMemoryStore()) @@ -720,19 +707,7 @@ private[history] object FsHistoryProvider { private[history] val CURRENT_LISTING_VERSION = 1L } -/** - * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as - * the API serializer. - */ -private class KVStoreScalaSerializer extends KVStoreSerializer { - - mapper.registerModule(DefaultScalaModule) - mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) - mapper.setDateFormat(v1.JacksonMessageWriter.makeISODateFormat) - -} - -private[history] case class KVStoreMetadata( +private[history] case class FsHistoryProviderMetadata( version: Long, logDir: String) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index fb9e997def0dd..52dedc1a2ed41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -19,16 +19,10 @@ package org.apache.spark.deploy.history import java.util.concurrent.TimeUnit -import scala.annotation.meta.getter - import org.apache.spark.internal.config.ConfigBuilder -import org.apache.spark.util.kvstore.KVIndex private[spark] object config { - /** Use this to annotate constructor params to be used as KVStore indices. 
*/ - type KVIndexParam = KVIndex @getter - val DEFAULT_LOG_DIR = "file:/tmp/spark-events" val EVENT_LOG_DIR = ConfigBuilder("spark.history.fs.logDirectory") diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala new file mode 100644 index 0000000000000..f120685c941df --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.kvstore.KVStore + +/** + * A Spark listener that writes application information to a data store. The types written to the + * store are defined in the `storeTypes.scala` file and are based on the public REST API. + */ +private class AppStatusListener(kvstore: KVStore) extends SparkListener with Logging { + + private var sparkVersion = SPARK_VERSION + private var appInfo: v1.ApplicationInfo = null + private var coresPerTask: Int = 1 + + // Keep track of live entities, so that task metrics can be efficiently updated (without + // causing too many writes to the underlying store, and other expensive operations). 
+ private val liveStages = new HashMap[(Int, Int), LiveStage]() + private val liveJobs = new HashMap[Int, LiveJob]() + private val liveExecutors = new HashMap[String, LiveExecutor]() + private val liveTasks = new HashMap[Long, LiveTask]() + private val liveRDDs = new HashMap[Int, LiveRDD]() + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(version) => sparkVersion = version + case _ => + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + assert(event.appId.isDefined, "Application without IDs are not supported.") + + val attempt = new v1.ApplicationAttemptInfo( + event.appAttemptId, + new Date(event.time), + new Date(-1), + new Date(event.time), + -1L, + event.sparkUser, + false, + sparkVersion) + + appInfo = new v1.ApplicationInfo( + event.appId.get, + event.appName, + None, + None, + None, + None, + Seq(attempt)) + + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { + val old = appInfo.attempts.head + val attempt = new v1.ApplicationAttemptInfo( + old.attemptId, + old.startTime, + new Date(event.time), + new Date(event.time), + event.time - old.startTime.getTime(), + old.sparkUser, + true, + old.appSparkVersion) + + appInfo = new v1.ApplicationInfo( + appInfo.id, + appInfo.name, + None, + None, + None, + None, + Seq(attempt)) + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { + // This needs to be an update in case an executor re-registers after the driver has + // marked it as "dead". + val exec = getOrCreateExecutor(event.executorId) + exec.host = event.executorInfo.executorHost + exec.isActive = true + exec.totalCores = event.executorInfo.totalCores + exec.maxTasks = event.executorInfo.totalCores / coresPerTask + exec.executorLogs = event.executorInfo.logUrlMap + update(exec) + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { + liveExecutors.remove(event.executorId).foreach { exec => + exec.isActive = false + update(exec) + } + } + + override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = { + updateBlackListStatus(event.executorId, true) + } + + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { + updateBlackListStatus(event.executorId, false) + } + + override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = { + updateNodeBlackList(event.hostId, true) + } + + override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = { + updateNodeBlackList(event.hostId, false) + } + + private def updateBlackListStatus(execId: String, blacklisted: Boolean): Unit = { + liveExecutors.get(execId).foreach { exec => + exec.isBlacklisted = blacklisted + update(exec) + } + } + + private def updateNodeBlackList(host: String, blacklisted: Boolean): Unit = { + // Implicitly (un)blacklist every executor associated with the node. + liveExecutors.values.foreach { exec => + if (exec.hostname == host) { + exec.isBlacklisted = blacklisted + update(exec) + } + } + } + + override def onJobStart(event: SparkListenerJobStart): Unit = { + // Compute (a potential over-estimate of) the number of tasks that will be run by this job. 
+ // This may be an over-estimate because the job start event references all of the result + // stages' transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + val numTasks = { + val missingStages = event.stageInfos.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } + + val lastStageInfo = event.stageInfos.lastOption + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + + val jobGroup = Option(event.properties) + .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } + + val job = new LiveJob( + event.jobId, + lastStageName, + Some(new Date(event.time)), + event.stageIds, + jobGroup, + numTasks) + liveJobs.put(event.jobId, job) + update(job) + + event.stageInfos.foreach { stageInfo => + // A new job submission may re-use an existing stage, so this code needs to do an update + // instead of just a write. + val stage = getOrCreateStage(stageInfo) + stage.jobs :+= job + stage.jobIds += event.jobId + update(stage) + } + } + + override def onJobEnd(event: SparkListenerJobEnd): Unit = { + liveJobs.remove(event.jobId).foreach { job => + job.status = event.jobResult match { + case JobSucceeded => JobExecutionStatus.SUCCEEDED + case JobFailed(_) => JobExecutionStatus.FAILED + } + + job.completionTime = Some(new Date(event.time)) + update(job) + } + } + + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + val stage = getOrCreateStage(event.stageInfo) + stage.status = v1.StageStatus.ACTIVE + stage.schedulingPool = Option(event.properties).flatMap { p => + Option(p.getProperty("spark.scheduler.pool")) + }.getOrElse(SparkUI.DEFAULT_POOL_NAME) + + // Look at all active jobs to find the ones that mention this stage. + stage.jobs = liveJobs.values + .filter(_.stageIds.contains(event.stageInfo.stageId)) + .toSeq + stage.jobIds = stage.jobs.map(_.jobId).toSet + + stage.jobs.foreach { job => + job.completedStages = job.completedStages - event.stageInfo.stageId + job.activeStages += 1 + update(job) + } + + event.stageInfo.rddInfos.foreach { info => + if (info.storageLevel.isValid) { + update(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info))) + } + } + + update(stage) + } + + override def onTaskStart(event: SparkListenerTaskStart): Unit = { + val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) + liveTasks.put(event.taskInfo.taskId, task) + update(task) + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + stage.activeTasks += 1 + stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + update(stage) + + stage.jobs.foreach { job => + job.activeTasks += 1 + update(job) + } + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + exec.activeTasks += 1 + exec.totalTasks += 1 + update(exec) + } + } + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = { + // Call update on the task so that the "getting result" time is written to the store; the + // value is part of the mutable TaskInfo state that the live entity already references. + liveTasks.get(event.taskInfo.taskId).foreach { task => + update(task) + } + } + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + // TODO: can this really happen? 
+ if (event.taskInfo == null) { + return + } + + val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => + val errorMessage = event.reason match { + case Success => + None + case k: TaskKilled => + Some(k.reason) + case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates + Some(e.toErrorString) + case e: TaskFailedReason => // All other failure cases + Some(e.toErrorString) + case other => + logInfo(s"Unhandled task end reason: $other") + None + } + task.errorMessage = errorMessage + val delta = task.updateMetrics(event.taskMetrics) + update(task) + delta + }.orNull + + val (completedDelta, failedDelta) = event.reason match { + case Success => + (1, 0) + case _ => + (0, 1) + } + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + if (metricsDelta != null) { + stage.metrics.update(metricsDelta) + } + stage.activeTasks -= 1 + stage.completedTasks += completedDelta + stage.failedTasks += failedDelta + update(stage) + + stage.jobs.foreach { job => + job.activeTasks -= 1 + job.completedTasks += completedDelta + job.failedTasks += failedDelta + update(job) + } + + val esummary = stage.executorSummary(event.taskInfo.executorId) + esummary.taskTime += event.taskInfo.duration + esummary.succeededTasks += completedDelta + esummary.failedTasks += failedDelta + if (metricsDelta != null) { + esummary.metrics.update(metricsDelta) + } + update(esummary) + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + if (event.taskMetrics != null) { + val readMetrics = event.taskMetrics.shuffleReadMetrics + exec.totalGcTime += event.taskMetrics.jvmGCTime + exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead + exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead + exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten + } + + exec.activeTasks -= 1 + exec.completedTasks += completedDelta + exec.failedTasks += failedDelta + exec.totalDuration += event.taskInfo.duration + update(exec) + } + } + + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { + liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + stage.info = event.stageInfo + + // Because of SPARK-20205, old event logs may contain valid stages without a submission time + // in their start event. In those cases, we can only detect whether a stage was skipped by + // waiting until the completion event, at which point the field would have been set. + stage.status = event.stageInfo.failureReason match { + case Some(_) => v1.StageStatus.FAILED + case _ if event.stageInfo.submissionTime.isDefined => v1.StageStatus.COMPLETE + case _ => v1.StageStatus.SKIPPED + } + update(stage) + + stage.jobs.foreach { job => + stage.status match { + case v1.StageStatus.COMPLETE => + job.completedStages += event.stageInfo.stageId + case v1.StageStatus.SKIPPED => + job.skippedStages += event.stageInfo.stageId + job.skippedTasks += event.stageInfo.numTasks + case _ => + job.failedStages += 1 + } + job.activeStages -= 1 + update(job) + } + + stage.executorSummaries.values.foreach(update) + update(stage) + } + } + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { + // This needs to set fields that are already set by onExecutorAdded because the driver is + // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. 
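+    // For the driver, the getOrCreateExecutor() call below is therefore the first (and only)
+    // place where its LiveExecutor entry gets created.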
+ val exec = getOrCreateExecutor(event.blockManagerId.executorId) + exec.hostPort = event.blockManagerId.hostPort + event.maxOnHeapMem.foreach { _ => + exec.totalOnHeap = event.maxOnHeapMem.get + exec.totalOffHeap = event.maxOffHeapMem.get + } + exec.isActive = true + exec.maxMemory = event.maxMem + update(exec) + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { + // Nothing to do here. Covered by onExecutorRemoved. + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { + liveRDDs.remove(event.rddId) + kvstore.delete(classOf[RDDStorageInfoWrapper], event.rddId) + } + + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + event.accumUpdates.foreach { case (taskId, sid, sAttempt, accumUpdates) => + liveTasks.get(taskId).foreach { task => + val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) + val delta = task.updateMetrics(metrics) + update(task) + + liveStages.get((sid, sAttempt)).foreach { stage => + stage.metrics.update(delta) + update(stage) + + val esummary = stage.executorSummary(event.execId) + esummary.metrics.update(delta) + update(esummary) + } + } + } + } + + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + event.blockUpdatedInfo.blockId match { + case block: RDDBlockId => updateRDDBlock(event, block) + case _ => // TODO: API only covers RDD storage. + } + } + + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { + val executorId = event.blockUpdatedInfo.blockManagerId.executorId + + // Whether values are being added to or removed from the existing accounting. + val storageLevel = event.blockUpdatedInfo.storageLevel + val diskDelta = event.blockUpdatedInfo.diskSize * (if (storageLevel.useDisk) 1 else -1) + val memoryDelta = event.blockUpdatedInfo.memSize * (if (storageLevel.useMemory) 1 else -1) + + // Function to apply a delta to a value, but ensure that it doesn't go negative. + def newValue(old: Long, delta: Long): Long = math.max(0, old + delta) + + val updatedStorageLevel = if (storageLevel.isValid) { + Some(storageLevel.description) + } else { + None + } + + // We need information about the executor to update some memory accounting values in the + // RDD info, so read that beforehand. + val maybeExec = liveExecutors.get(executorId) + var rddBlocksDelta = 0 + + // Update the block entry in the RDD info, keeping track of the deltas above so that we + // can update the executor information too. + liveRDDs.get(block.rddId).foreach { rdd => + val partition = rdd.partition(block.name) + + val executors = if (updatedStorageLevel.isDefined) { + if (!partition.executors.contains(executorId)) { + rddBlocksDelta = 1 + } + partition.executors + executorId + } else { + rddBlocksDelta = -1 + partition.executors - executorId + } + + // Only update the partition if it's still stored in some executor, otherwise get rid of it. 
+ if (executors.nonEmpty) { + if (updatedStorageLevel.isDefined) { + partition.storageLevel = updatedStorageLevel.get + } + partition.memoryUsed = newValue(partition.memoryUsed, memoryDelta) + partition.diskUsed = newValue(partition.diskUsed, diskDelta) + partition.executors = executors + } else { + rdd.removePartition(block.name) + } + + maybeExec.foreach { exec => + if (exec.rddBlocks + rddBlocksDelta > 0) { + val dist = rdd.distribution(exec) + dist.memoryRemaining = newValue(dist.memoryRemaining, -memoryDelta) + dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) + dist.diskUsed = newValue(dist.diskUsed, diskDelta) + + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) + dist.offHeapRemaining = newValue(dist.offHeapRemaining, -memoryDelta) + } else { + dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) + dist.onHeapRemaining = newValue(dist.onHeapRemaining, -memoryDelta) + } + } + } else { + rdd.removeDistribution(exec) + } + } + + if (updatedStorageLevel.isDefined) { + rdd.storageLevel = updatedStorageLevel.get + } + rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) + rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) + update(rdd) + } + + maybeExec.foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + } else { + exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + } + } + exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = newValue(exec.diskUsed, diskDelta) + exec.rddBlocks += rddBlocksDelta + if (exec.hasMemoryInfo || rddBlocksDelta != 0) { + update(exec) + } + } + } + + private def getOrCreateExecutor(executorId: String): LiveExecutor = { + liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId)) + } + + private def getOrCreateStage(info: StageInfo): LiveStage = { + val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) + stage.info = info + stage + } + + private def update(entity: LiveEntity): Unit = { + entity.write(kvstore) + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala new file mode 100644 index 0000000000000..4638511944c61 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.status + +import java.io.File + +import scala.annotation.meta.getter +import scala.language.implicitConversions +import scala.reflect.{classTag, ClassTag} + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +import org.apache.spark.internal.Logging +import org.apache.spark.util.kvstore._ + +private[spark] object KVUtils extends Logging { + + /** Use this to annotate constructor params to be used as KVStore indices. */ + type KVIndexParam = KVIndex @getter + + /** + * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as + * the API serializer. + */ + private[spark] class KVStoreScalaSerializer extends KVStoreSerializer { + + mapper.registerModule(DefaultScalaModule) + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + + } + + /** + * Open or create a LevelDB store. + * + * @param path Location of the store. + * @param metadata Metadata value to compare to the data in the store. If the store does not + * contain any metadata (e.g. it's a new store), this value is written as + * the store's metadata. + */ + def open[M: ClassTag](path: File, metadata: M): LevelDB = { + require(metadata != null, "Metadata is required.") + + val db = new LevelDB(path, new KVStoreScalaSerializer()) + val dbMeta = db.getMetadata(classTag[M].runtimeClass) + if (dbMeta == null) { + db.setMetadata(metadata) + } else if (dbMeta != metadata) { + db.close() + throw new MetadataMismatchException() + } + + db + } + + private[spark] class MetadataMismatchException extends Exception + +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala new file mode 100644 index 0000000000000..63fa36580bc7d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -0,0 +1,526 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} +import org.apache.spark.status.api.v1 +import org.apache.spark.storage.RDDInfo +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.kvstore.KVStore + +/** + * A mutable representation of a live entity in Spark (jobs, stages, tasks, et al). Every live + * entity uses one of these instances to keep track of their evolving state, and periodically + * flush an immutable view of the entity to the app state store. 
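+ *
+ * Subclasses implement doUpdate() to build that immutable view; write() persists whatever
+ * doUpdate() returns to the given KVStore.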
+ */ +private[spark] abstract class LiveEntity { + + def write(store: KVStore): Unit = { + store.write(doUpdate()) + } + + /** + * Returns an updated view of entity data, to be stored in the status store, reflecting the + * latest information collected by the listener. + */ + protected def doUpdate(): Any + +} + +private class LiveJob( + val jobId: Int, + name: String, + submissionTime: Option[Date], + val stageIds: Seq[Int], + jobGroup: Option[String], + numTasks: Int) extends LiveEntity { + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var skippedTasks = 0 + var skippedStages = Set[Int]() + + var status = JobExecutionStatus.RUNNING + var completionTime: Option[Date] = None + + var completedStages: Set[Int] = Set() + var activeStages = 0 + var failedStages = 0 + + override protected def doUpdate(): Any = { + val info = new v1.JobData( + jobId, + name, + None, // description is always None? + submissionTime, + completionTime, + stageIds, + jobGroup, + status, + numTasks, + activeTasks, + completedTasks, + skippedTasks, + failedTasks, + activeStages, + completedStages.size, + skippedStages.size, + failedStages) + new JobDataWrapper(info, skippedStages) + } + +} + +private class LiveTask( + info: TaskInfo, + stageId: Int, + stageAttemptId: Int) extends LiveEntity { + + import LiveEntityHelpers._ + + private var recordedMetrics: v1.TaskMetrics = null + + var errorMessage: Option[String] = None + + /** + * Update the metrics for the task and return the difference between the previous and new + * values. + */ + def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { + if (metrics != null) { + val old = recordedMetrics + recordedMetrics = new v1.TaskMetrics( + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGCTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + new v1.InputMetrics( + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead), + new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten), + new v1.ShuffleReadMetrics( + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead), + new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten)) + if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + } else { + null + } + } + + /** + * Return a new TaskMetrics object containing the delta of the various fields of the given + * metrics objects. This is currently targeted at updating stage data, so it does not + * necessarily calculate deltas for all the fields. 
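+   *
+   * For example, executor run / CPU time, spill sizes, and the input, output and shuffle
+   * byte / record counts are returned as (new - old), while fields that stage aggregation
+   * does not use (such as result size, GC time or deserialization time) are reported as zero.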
+ */ + private def calculateMetricsDelta( + metrics: v1.TaskMetrics, + old: v1.TaskMetrics): v1.TaskMetrics = { + val shuffleWriteDelta = new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, + 0L, + metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) + + val shuffleReadDelta = new v1.ShuffleReadMetrics( + 0L, 0L, 0L, + metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk - + old.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) + + val inputDelta = new v1.InputMetrics( + metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) + + val outputDelta = new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) + + new v1.TaskMetrics( + 0L, 0L, + metrics.executorRunTime - old.executorRunTime, + metrics.executorCpuTime - old.executorCpuTime, + 0L, 0L, 0L, + metrics.memoryBytesSpilled - old.memoryBytesSpilled, + metrics.diskBytesSpilled - old.diskBytesSpilled, + inputDelta, + outputDelta, + shuffleReadDelta, + shuffleWriteDelta) + } + + override protected def doUpdate(): Any = { + val task = new v1.TaskData( + info.taskId, + info.index, + info.attemptNumber, + new Date(info.launchTime), + if (info.finished) Some(info.duration) else None, + info.executorId, + info.host, + info.status, + info.taskLocality.toString(), + info.speculative, + newAccumulatorInfos(info.accumulables), + errorMessage, + Option(recordedMetrics)) + new TaskDataWrapper(task) + } + +} + +private class LiveExecutor(val executorId: String) extends LiveEntity { + + var hostPort: String = null + var host: String = null + var isActive = true + var totalCores = 0 + + var rddBlocks = 0 + var memoryUsed = 0L + var diskUsed = 0L + var maxTasks = 0 + var maxMemory = 0L + + var totalTasks = 0 + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + var totalDuration = 0L + var totalGcTime = 0L + var totalInputBytes = 0L + var totalShuffleRead = 0L + var totalShuffleWrite = 0L + var isBlacklisted = false + + var executorLogs = Map[String, String]() + + // Memory metrics. They may not be recorded (e.g. old event logs) so if totalOnHeap is not + // initialized, the store will not contain this information. + var totalOnHeap = -1L + var totalOffHeap = 0L + var usedOnHeap = 0L + var usedOffHeap = 0L + + def hasMemoryInfo: Boolean = totalOnHeap >= 0L + + def hostname: String = if (host != null) host else hostPort.split(":")(0) + + override protected def doUpdate(): Any = { + val memoryMetrics = if (totalOnHeap >= 0) { + Some(new v1.MemoryMetrics(usedOnHeap, usedOffHeap, totalOnHeap, totalOffHeap)) + } else { + None + } + + val info = new v1.ExecutorSummary( + executorId, + if (hostPort != null) hostPort else host, + isActive, + rddBlocks, + memoryUsed, + diskUsed, + totalCores, + maxTasks, + activeTasks, + failedTasks, + completedTasks, + totalTasks, + totalDuration, + totalGcTime, + totalInputBytes, + totalShuffleRead, + totalShuffleWrite, + isBlacklisted, + maxMemory, + executorLogs, + memoryMetrics) + new ExecutorSummaryWrapper(info) + } + +} + +/** Metrics tracked per stage (both total and per executor). 
*/ +private class MetricsTracker { + var executorRunTime = 0L + var executorCpuTime = 0L + var inputBytes = 0L + var inputRecords = 0L + var outputBytes = 0L + var outputRecords = 0L + var shuffleReadBytes = 0L + var shuffleReadRecords = 0L + var shuffleWriteBytes = 0L + var shuffleWriteRecords = 0L + var memoryBytesSpilled = 0L + var diskBytesSpilled = 0L + + def update(delta: v1.TaskMetrics): Unit = { + executorRunTime += delta.executorRunTime + executorCpuTime += delta.executorCpuTime + inputBytes += delta.inputMetrics.bytesRead + inputRecords += delta.inputMetrics.recordsRead + outputBytes += delta.outputMetrics.bytesWritten + outputRecords += delta.outputMetrics.recordsWritten + shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + + delta.shuffleReadMetrics.remoteBytesRead + shuffleReadRecords += delta.shuffleReadMetrics.recordsRead + shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten + shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten + memoryBytesSpilled += delta.memoryBytesSpilled + diskBytesSpilled += delta.diskBytesSpilled + } + +} + +private class LiveExecutorStageSummary( + stageId: Int, + attemptId: Int, + executorId: String) extends LiveEntity { + + var taskTime = 0L + var succeededTasks = 0 + var failedTasks = 0 + var killedTasks = 0 + + val metrics = new MetricsTracker() + + override protected def doUpdate(): Any = { + val info = new v1.ExecutorStageSummary( + taskTime, + failedTasks, + succeededTasks, + metrics.inputBytes, + metrics.outputBytes, + metrics.shuffleReadBytes, + metrics.shuffleWriteBytes, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled) + new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) + } + +} + +private class LiveStage extends LiveEntity { + + import LiveEntityHelpers._ + + var jobs = Seq[LiveJob]() + var jobIds = Set[Int]() + + var info: StageInfo = null + var status = v1.StageStatus.PENDING + + var schedulingPool: String = SparkUI.DEFAULT_POOL_NAME + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var firstLaunchTime = Long.MaxValue + + val metrics = new MetricsTracker() + + val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + + def executorSummary(executorId: String): LiveExecutorStageSummary = { + executorSummaries.getOrElseUpdate(executorId, + new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + } + + override protected def doUpdate(): Any = { + val update = new v1.StageData( + status, + info.stageId, + info.attemptId, + + activeTasks, + completedTasks, + failedTasks, + + metrics.executorRunTime, + metrics.executorCpuTime, + info.submissionTime.map(new Date(_)), + if (firstLaunchTime < Long.MaxValue) Some(new Date(firstLaunchTime)) else None, + info.completionTime.map(new Date(_)), + + metrics.inputBytes, + metrics.inputRecords, + metrics.outputBytes, + metrics.outputRecords, + metrics.shuffleReadBytes, + metrics.shuffleReadRecords, + metrics.shuffleWriteBytes, + metrics.shuffleWriteRecords, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + + info.name, + info.details, + schedulingPool, + + newAccumulatorInfos(info.accumulables.values), + None, + None) + + new StageDataWrapper(update, jobIds) + } + +} + +private class LiveRDDPartition(val blockName: String) { + + var executors = Set[String]() + var storageLevel: String = null + var memoryUsed = 0L + var diskUsed = 0L + + def toApi(): v1.RDDPartitionInfo = { + new v1.RDDPartitionInfo( + blockName, + storageLevel, + memoryUsed, + diskUsed, + executors.toSeq.sorted) + 
} + +} + +private class LiveRDDDistribution(val exec: LiveExecutor) { + + var memoryRemaining = exec.maxMemory + var memoryUsed = 0L + var diskUsed = 0L + + var onHeapUsed = 0L + var offHeapUsed = 0L + var onHeapRemaining = 0L + var offHeapRemaining = 0L + + def toApi(): v1.RDDDataDistribution = { + new v1.RDDDataDistribution( + exec.hostPort, + memoryUsed, + memoryRemaining, + diskUsed, + if (exec.hasMemoryInfo) Some(onHeapUsed) else None, + if (exec.hasMemoryInfo) Some(offHeapUsed) else None, + if (exec.hasMemoryInfo) Some(onHeapRemaining) else None, + if (exec.hasMemoryInfo) Some(offHeapRemaining) else None) + } + +} + +private class LiveRDD(info: RDDInfo) extends LiveEntity { + + var storageLevel: String = info.storageLevel.description + var memoryUsed = 0L + var diskUsed = 0L + + private val partitions = new HashMap[String, LiveRDDPartition]() + private val distributions = new HashMap[String, LiveRDDDistribution]() + + def partition(blockName: String): LiveRDDPartition = { + partitions.getOrElseUpdate(blockName, new LiveRDDPartition(blockName)) + } + + def removePartition(blockName: String): Unit = partitions.remove(blockName) + + def distribution(exec: LiveExecutor): LiveRDDDistribution = { + distributions.getOrElseUpdate(exec.hostPort, new LiveRDDDistribution(exec)) + } + + def removeDistribution(exec: LiveExecutor): Unit = { + distributions.remove(exec.hostPort) + } + + override protected def doUpdate(): Any = { + val parts = if (partitions.nonEmpty) { + Some(partitions.values.toList.sortBy(_.blockName).map(_.toApi())) + } else { + None + } + + val dists = if (distributions.nonEmpty) { + Some(distributions.values.toList.sortBy(_.exec.executorId).map(_.toApi())) + } else { + None + } + + val rdd = new v1.RDDStorageInfo( + info.id, + info.name, + info.numPartitions, + partitions.size, + storageLevel, + memoryUsed, + diskUsed, + dists, + parts) + + new RDDStorageInfoWrapper(rdd) + } + +} + +private object LiveEntityHelpers { + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { + accums + .filter { acc => + // We don't need to store internal or SQL accumulables as their values will be shown in + // other places, so drop them to reduce the memory usage. 
+ !acc.internal && (!acc.metadata.isDefined || + acc.metadata.get != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) + } + .map { acc => + new v1.AccumulableInfo( + acc.id, + acc.name.map(_.intern()).orNull, + acc.update.map(_.toString()), + acc.value.map(_.toString()).orNull) + } + .toSeq + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 4a4ed954d689e..5f69949c618fd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -71,7 +71,7 @@ private[v1] object AllStagesResource { val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => - k -> convertTaskData(v, stageUiData.lastUpdateTime) }) + k -> convertTaskData(v, stageUiData.lastUpdateTime) }.toMap) } else { None } @@ -88,7 +88,7 @@ private[v1] object AllStagesResource { memoryBytesSpilled = summary.memoryBytesSpilled, diskBytesSpilled = summary.diskBytesSpilled ) - }) + }.toMap) } else { None } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 31659b25db318..bff6f90823f40 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -16,11 +16,11 @@ */ package org.apache.spark.status.api.v1 +import java.lang.{Long => JLong} import java.util.Date -import scala.collection.Map - import com.fasterxml.jackson.annotation.JsonIgnoreProperties +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus @@ -129,9 +129,13 @@ class RDDDataDistribution private[spark]( val memoryUsed: Long, val memoryRemaining: Long, val diskUsed: Long, + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryRemaining: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( @@ -179,7 +183,8 @@ class TaskData private[spark]( val index: Int, val attempt: Int, val launchTime: Date, - val duration: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[JLong]) + val duration: Option[Long], val executorId: String, val host: String, val status: String, diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala new file mode 100644 index 0000000000000..9579accd2cba7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import com.fasterxml.jackson.annotation.JsonIgnore + +import org.apache.spark.status.KVUtils._ +import org.apache.spark.status.api.v1._ +import org.apache.spark.util.kvstore.KVIndex + +private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { + + @JsonIgnore @KVIndex + def id: String = info.id + +} + +private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { + + @JsonIgnore @KVIndex + private[this] val id: String = info.id + + @JsonIgnore @KVIndex("active") + private[this] val active: Boolean = info.isActive + + @JsonIgnore @KVIndex("host") + val host: String = info.hostPort.split(":")(0) + +} + +/** + * Keep track of the existing stages when the job was submitted, and those that were + * completed during the job's execution. This allows a more accurate acounting of how + * many tasks were skipped for the job. + */ +private[spark] class JobDataWrapper( + val info: JobData, + val skippedStages: Set[Int]) { + + @JsonIgnore @KVIndex + private[this] val id: Int = info.jobId + +} + +private[spark] class StageDataWrapper( + val info: StageData, + val jobIds: Set[Int]) { + + @JsonIgnore @KVIndex + def id: Array[Int] = Array(info.stageId, info.attemptId) + +} + +private[spark] class TaskDataWrapper(val info: TaskData) { + + @JsonIgnore @KVIndex + def id: Long = info.taskId + +} + +private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { + + @JsonIgnore @KVIndex + def id: Int = info.id + + @JsonIgnore @KVIndex("cached") + def cached: Boolean = info.numCachedPartitions > 0 + +} + +private[spark] class ExecutorStageSummaryWrapper( + val stageId: Int, + val stageAttemptId: Int, + val executorId: String, + val info: ExecutorStageSummary) { + + @JsonIgnore @KVIndex + val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + + @JsonIgnore @KVIndex("stage") + private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 2141934c92640..03bd3eaf579f3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -611,7 +611,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Manually overwrite the version in the listing db; this should cause the new provider to // discard all data because the versions don't match. - val meta = new KVStoreMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, + val meta = new FsHistoryProviderMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, conf.get(LOCAL_STORE_DIR).get) oldProvider.listing.setMetadata(meta) oldProvider.stop() diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala new file mode 100644 index 0000000000000..6f7a0c14dd684 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -0,0 +1,690 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File +import java.util.{Date, Properties} + +import scala.collection.JavaConverters._ +import scala.reflect.{classTag, ClassTag} + +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore._ + +class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { + + private var time: Long = _ + private var testDir: File = _ + private var store: KVStore = _ + + before { + time = 0L + testDir = Utils.createTempDir() + store = KVUtils.open(testDir, getClass().getName()) + } + + after { + store.close() + Utils.deleteRecursively(testDir) + } + + test("scheduler events") { + val listener = new AppStatusListener(store) + + // Start the application. + time += 1 + listener.onApplicationStart(SparkListenerApplicationStart( + "name", + Some("id"), + time, + "user", + Some("attempt"), + None)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(time)) + assert(attempt.lastUpdated === new Date(time)) + assert(attempt.endTime.getTime() === -1L) + assert(attempt.sparkUser === "user") + assert(!attempt.completed) + } + + // Start a couple of executors. 
+ time += 1 + val execIds = Array("1", "2") + + execIds.foreach { id => + listener.onExecutorAdded(SparkListenerExecutorAdded(time, id, + new ExecutorInfo(s"$id.example.com", 1, Map()))) + } + + execIds.foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(exec.info.hostPort === s"$id.example.com") + assert(exec.info.isActive) + } + } + + // Start a job with 2 stages / 4 tasks each + time += 1 + val stages = Seq( + new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(2, 0, "stage2", 4, Nil, Seq(1), "details2")) + + val jobProps = new Properties() + jobProps.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "jobGroup") + jobProps.setProperty("spark.scheduler.pool", "schedPool") + + listener.onJobStart(SparkListenerJobStart(1, time, stages, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.jobId === 1) + assert(job.info.name === stages.last.name) + assert(job.info.description === None) + assert(job.info.status === JobExecutionStatus.RUNNING) + assert(job.info.submissionTime === Some(new Date(time))) + assert(job.info.jobGroup === Some("jobGroup")) + } + + stages.foreach { info => + check[StageDataWrapper](key(info)) { stage => + assert(stage.info.status === v1.StageStatus.PENDING) + assert(stage.jobIds === Set(1)) + } + } + + // Submit stage 1 + time += 1 + stages.head.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.head, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.head.submissionTime.get))) + assert(stage.info.schedulingPool === "schedPool") + } + + // Start tasks from stage 1 + time += 1 + var _taskIdTracker = -1L + def nextTaskId(): Long = { + _taskIdTracker += 1 + _taskIdTracker + } + + def createTasks(count: Int, time: Long): Seq[TaskInfo] = { + (1 to count).map { id => + val exec = execIds(id.toInt % execIds.length) + val taskId = nextTaskId() + new TaskInfo(taskId, taskId.toInt, 1, time, exec, s"$exec.example.com", + TaskLocality.PROCESS_LOCAL, id % 2 == 0) + } + } + + val s1Tasks = createTasks(4, time) + s1Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + } + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numActiveTasks === s1Tasks.size) + assert(stage.info.firstTaskLaunchedTime === Some(new Date(s1Tasks.head.launchTime))) + } + + s1Tasks.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.taskId === task.taskId) + assert(wrapper.info.index === task.index) + assert(wrapper.info.attempt === task.attemptNumber) + assert(wrapper.info.launchTime === new Date(task.launchTime)) + assert(wrapper.info.executorId === task.executorId) + assert(wrapper.info.host === task.host) + assert(wrapper.info.status === task.status) + assert(wrapper.info.taskLocality === task.taskLocality.toString()) + assert(wrapper.info.speculative === task.speculative) + } + } + + // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. 
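+    // Each of the four tasks reports memoryBytesSpilled = 1, so the stage-level total checked
+    // below equals the task count and each of the two executors accounts for half of it.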
+ s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(1L), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size) + } + + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + } + + // Fail one of the tasks, re-start it. + time += 1 + s1Tasks.head.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", TaskResultLost, s1Tasks.head, null)) + + time += 1 + val reattempt = { + val orig = s1Tasks.head + // Task reattempts have a different ID, but the same index as the original. + new TaskInfo(nextTaskId(), orig.index, orig.attemptNumber + 1, time, orig.executorId, + s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) + } + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + reattempt)) + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === s1Tasks.size) + } + + check[TaskDataWrapper](s1Tasks.head.taskId) { task => + assert(task.info.status === s1Tasks.head.status) + assert(task.info.duration === Some(s1Tasks.head.duration)) + assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + } + + check[TaskDataWrapper](reattempt.taskId) { task => + assert(task.info.index === s1Tasks.head.index) + assert(task.info.attempt === reattempt.attemptNumber) + } + + // Succeed all tasks in stage 1. + val pending = s1Tasks.drop(1) ++ Seq(reattempt) + + val s1Metrics = TaskMetrics.empty + s1Metrics.setExecutorCpuTime(2L) + s1Metrics.setExecutorRunTime(4L) + + time += 1 + pending.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", Success, task, s1Metrics)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === 0) + assert(job.info.numCompletedTasks === pending.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + pending.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.errorMessage === None) + assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) + assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) + } + } + + assert(store.count(classOf[TaskDataWrapper]) === pending.size + 1) + + // End stage 1. 
+ time += 1 + stages.head.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stages.head)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numCompletedStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + // Submit stage 2. + time += 1 + stages.last.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.last, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) + } + + // Start and fail all tasks of stage 2. + time += 1 + val s2Tasks = createTasks(4, time) + s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + } + + time += 1 + s2Tasks.foreach { task => + task.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + "taskType", TaskResultLost, task, null)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1 + s2Tasks.size) + assert(job.info.numActiveTasks === 0) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + } + + // Fail stage 2. + time += 1 + stages.last.completionTime = Some(time) + stages.last.failureReason = Some("uh oh") + listener.onStageCompleted(SparkListenerStageCompleted(stages.last)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 1) + assert(job.info.numFailedStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.FAILED) + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === 0) + } + + // - Re-submit stage 2, all tasks, and succeed them and the stage. 
+ val oldS2 = stages.last + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) + + time += 1 + newS2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(newS2, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 3) + + val newS2Tasks = createTasks(4, time) + + newS2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + } + + time += 1 + newS2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, + task, null)) + } + + time += 1 + newS2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(newS2)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numFailedStages === 1) + assert(job.info.numCompletedStages === 2) + } + + check[StageDataWrapper](key(newS2)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === newS2Tasks.size) + } + + // End job. + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + } + + // Submit a second job that re-uses stage 1 and stage 2. Stage 1 won't be re-run, but + // stage 2 will. In any case, the DAGScheduler creates new info structures that are copies + // of the old stages, so mimic that behavior here. The "new" stage 1 is submitted without + // a submission time, which means it is "skipped", and the stage 2 re-execution should not + // change the stats of the already finished job. + time += 1 + val j2Stages = Seq( + new StageInfo(3, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(4, 0, "stage2", 4, Nil, Seq(3), "details2")) + j2Stages.last.submissionTime = Some(time) + listener.onJobStart(SparkListenerJobStart(2, time, j2Stages, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.head, jobProps)) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.head)) + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.last, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 5) + + time += 1 + val j2s2Tasks = createTasks(4, time) + + j2s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + task)) + } + + time += 1 + j2s2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + "taskType", Success, task, null)) + } + + time += 1 + j2Stages.last.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.last)) + + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 2) + assert(job.info.numCompletedTasks === s1Tasks.size + s2Tasks.size) + } + + check[JobDataWrapper](2) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + assert(job.info.numCompletedStages === 1) + assert(job.info.numCompletedTasks === j2s2Tasks.size) + assert(job.info.numSkippedStages === 1) + assert(job.info.numSkippedTasks === s1Tasks.size) + } + + // Blacklist an executor. 
+ time += 1 + listener.onExecutorBlacklisted(SparkListenerExecutorBlacklisted(time, "1", 42)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onExecutorUnblacklisted(SparkListenerExecutorUnblacklisted(time, "1")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Blacklist a node. + time += 1 + listener.onNodeBlacklisted(SparkListenerNodeBlacklisted(time, "1.example.com", 2)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onNodeUnblacklisted(SparkListenerNodeUnblacklisted(time, "1.example.com")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Stop executors. + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "1", "Test")) + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "2", "Test")) + + Seq("1", "2").foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(!exec.info.isActive) + } + } + + // End the application. + listener.onApplicationEnd(SparkListenerApplicationEnd(42L)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(1L)) + assert(attempt.lastUpdated === new Date(42L)) + assert(attempt.endTime === new Date(42L)) + assert(attempt.duration === 41L) + assert(attempt.sparkUser === "user") + assert(attempt.completed) + } + } + + test("storage events") { + val listener = new AppStatusListener(store) + val maxMemory = 42L + + // Register a couple of block managers. + val bm1 = BlockManagerId("1", "1.example.com", 42) + val bm2 = BlockManagerId("2", "2.example.com", 84) + Seq(bm1, bm2).foreach { bm => + listener.onExecutorAdded(SparkListenerExecutorAdded(1L, bm.executorId, + new ExecutorInfo(bm.host, 1, Map()))) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm, maxMemory)) + check[ExecutorSummaryWrapper](bm.executorId) { exec => + assert(exec.info.maxMemory === maxMemory) + } + } + + val rdd1b1 = RDDBlockId(1, 1) + val level = StorageLevel.MEMORY_AND_DISK + + // Submit a stage and make sure the RDD is recorded. + val rddInfo = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rddInfo), Nil, "details1") + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.name === rddInfo.name) + assert(wrapper.info.numPartitions === rddInfo.numPartitions) + assert(wrapper.info.storageLevel === rddInfo.storageLevel.description) + } + + // Add partition 1 replicated on two block managers. 
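+    // Each replica reports 1 byte of memory and 1 byte of disk, so the RDD- and partition-level
+    // totals double once the second block manager checks in, while each executor still accounts
+    // for a single byte.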
+ listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 1L) + assert(wrapper.info.diskUsed === 1L) + + assert(wrapper.info.dataDistribution.isDefined) + assert(wrapper.info.dataDistribution.get.size === 1) + + val dist = wrapper.info.dataDistribution.get.head + assert(dist.address === bm1.hostPort) + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + assert(wrapper.info.partitions.isDefined) + assert(wrapper.info.partitions.get.size === 1) + + val part = wrapper.info.partitions.get.head + assert(part.blockName === rdd1b1.name) + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 1L) + assert(part.diskUsed === 1L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 2L) + assert(wrapper.info.diskUsed === 2L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 1L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm2.hostPort).get + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get(0) + assert(part.memoryUsed === 2L) + assert(part.diskUsed === 2L) + assert(part.executors === Seq(bm1.executorId, bm2.executorId)) + } + + check[ExecutorSummaryWrapper](bm2.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + // Add a second partition only to bm 1. + val rdd1b2 = RDDBlockId(1, 2) + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b2, level, + 3L, 3L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 5L) + assert(wrapper.info.diskUsed === 5L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 2L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get + assert(dist.memoryUsed === 4L) + assert(dist.diskUsed === 4L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.name).get + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 3L) + assert(part.diskUsed === 3L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 2L) + assert(exec.info.memoryUsed === 4L) + assert(exec.info.diskUsed === 4L) + } + + // Remove block 1 from bm 1. 
+ listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, + StorageLevel.NONE, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 4L) + assert(wrapper.info.diskUsed === 4L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 2L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get + assert(dist.memoryUsed === 3L) + assert(dist.diskUsed === 3L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.name).get + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 1L) + assert(part.diskUsed === 1L) + assert(part.executors === Seq(bm2.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 3L) + assert(exec.info.diskUsed === 3L) + } + + // Remove block 2 from bm 2. This should leave only block 2 info in the store. + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, + StorageLevel.NONE, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 3L) + assert(wrapper.info.diskUsed === 3L) + assert(wrapper.info.dataDistribution.get.size === 1L) + assert(wrapper.info.partitions.get.size === 1L) + assert(wrapper.info.partitions.get(0).blockName === rdd1b2.name) + } + + check[ExecutorSummaryWrapper](bm2.executorId) { exec => + assert(exec.info.rddBlocks === 0L) + assert(exec.info.memoryUsed === 0L) + assert(exec.info.diskUsed === 0L) + } + + // Unpersist RDD1. + listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId)) + intercept[NoSuchElementException] { + check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () } + } + + } + + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + + private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { + val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] + fn(value) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd299e074535e..45b8870f3b62f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,8 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // SPARK-18085: Better History Server scalability for many / large applications + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 4f8dc6b01ea787243a38678ea8199fbb0814cffc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 26 Oct 2017 21:41:45 +0100 Subject: [PATCH 121/712] [SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields ## What changes were proposed in this pull request? When the given closure uses some fields defined in super class, `ClosureCleaner` can't figure them and don't set it properly. Those fields will be in null values. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19556 from viirya/SPARK-22328. 
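To make the failure mode concrete, here is a minimal sketch of the pattern the new regression tests exercise. It is illustrative only: `HasSuperFields` and `SuperclassFieldRepro` are made-up names rather than part of this patch, and the snippet condenses the anonymous-subclass shape used in the added `ClosureCleanerSuite` cases.

```scala
import scala.language.reflectiveCalls  // obj's members are accessed through a refinement type

import org.apache.spark.{SparkConf, SparkContext}

// The fields n1/s1 live on the superclass; the closure that reads them is defined in an
// anonymous subclass, which is the case ClosureCleaner previously mishandled.
abstract class HasSuperFields extends Serializable {
  val n1 = 111
  val s1 = "aaa"
}

object SuperclassFieldRepro {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("SPARK-22328"))
    try {
      val obj = new HasSuperFields {
        val n2 = 222
        // The returned closure captures the enclosing anonymous object and reads n1 and s1,
        // which are declared on HasSuperFields rather than on the anonymous class itself.
        def getData: Int => (Int, Int, String) = _ => (n1, n2, s1)
      }
      // Before this change the cleaner only tracked fields declared on the outer classes
      // themselves, so the superclass fields could come back as 0 / null; with the fix this
      // prints (111,222,aaa).
      println(sc.parallelize(1 to 1).map(obj.getData).collect().mkString(","))
    } finally {
      sc.stop()
    }
  }
}
```

The fix walks the superclass chain in three places: when initializing the accessed-fields map (`initAccessedFields`), when copying accessed fields into the cloned outer object (`cloneAndSetFields`), and when `FieldAccessFinder` visits methods.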
--- .../apache/spark/util/ClosureCleaner.scala | 73 ++++++++++++++++--- .../spark/util/ClosureCleanerSuite.scala | 72 ++++++++++++++++++ 2 files changed, 133 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 48a1d7b84b61b..dfece5dd0670b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging { (seen - obj.getClass).toList } + /** Initializes the accessed fields for outer classes and their super classes. */ + private def initAccessedFields( + accessedFields: Map[Class[_], Set[String]], + outerClasses: Seq[Class[_]]): Unit = { + for (cls <- outerClasses) { + var currentClass = cls + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + accessedFields(currentClass) = Set.empty[String] + currentClass = currentClass.getSuperclass() + } + } + } + + /** Sets accessed fields for given class in clone object based on given object. */ + private def setAccessedFields( + outerClass: Class[_], + clone: AnyRef, + obj: AnyRef, + accessedFields: Map[Class[_], Set[String]]): Unit = { + for (fieldName <- accessedFields(outerClass)) { + val field = outerClass.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + field.set(clone, value) + } + } + + /** Clones a given object and sets accessed fields in cloned object. */ + private def cloneAndSetFields( + parent: AnyRef, + obj: AnyRef, + outerClass: Class[_], + accessedFields: Map[Class[_], Set[String]]): AnyRef = { + val clone = instantiateClass(outerClass, parent) + + var currentClass = outerClass + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + setAccessedFields(currentClass, clone, obj, accessedFields) + currentClass = currentClass.getSuperclass() + } + + clone + } + /** * Clean the given closure in place. * @@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" + populating accessed fields because this is the starting closure") // Initialize accessed fields with the outer classes first // This step is needed to associate the fields to the correct classes later - for (cls <- outerClasses) { - accessedFields(cls) = Set.empty[String] - } + initAccessedFields(accessedFields, outerClasses) + // Populate accessed fields by visiting all fields and methods accessed by this and // all of its inner closures. If transitive cleaning is enabled, this may recursively // visits methods that belong to other classes in search of transitively referenced fields. @@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging { // required fields from the original object. We need the parent here because the Java // language specification requires the first constructor parameter of any closure to be // its enclosing object. 
- val clone = instantiateClass(cls, parent) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - field.set(clone, value) - } + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + // If transitive cleaning is enabled, we recursively clean any enclosing closure using // the already populated accessed fields map of the starting closure if (cleanTransitively && isClosure(clone.getClass)) { @@ -395,8 +437,15 @@ private[util] class FieldAccessFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - ClosureCleaner.getClassReader(cl).accept( - new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + + var currentClass = cl + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + ClosureCleaner.getClassReader(currentClass).accept( + new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + currentClass = currentClass.getSuperclass() + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 4920b7ee8bfb4..9a19baee9569e 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite { test("createNullValue") { new TestCreateNullValue().run() } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + val concreteObject = new TestAbstractClass { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] = { + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1) + body(rdd) + } + } + + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ => + (n1, n2, s1, s2, d1, d2) + }.collect() + } + assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + + test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + val concreteObject = new TestAbstractClass2 { + val n2 = 222 + val s2 = "bbb" + val d2 = 2.0d + def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2) + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.getData) + assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d))) + } + } + + test("SPARK-22328: multiple outer classes have the same parent class") { + val concreteObject = new TestAbstractClass2 { + + val innerObject = new TestAbstractClass2 { + override val n1 = 222 + override val s1 = "bbb" + } + + val innerObject2 = new TestAbstractClass2 { + override val n1 = 444 + val n3 = 333 + val s3 = "ccc" + val d3 = 3.0d + + def getData: Int => (Int, Int, String, String, Double, Double, Int, String) = + _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1) + } + } + withSpark(new SparkContext("local", "test")) { sc => + val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData) + assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb"))) + } + } } // A non-serializable class we create in closures to make sure that we aren't @@ -377,3 +434,18 @@ class TestCreateNullValue { nestedClosure() } } + +abstract class TestAbstractClass 
extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d + + def run(): Seq[(Int, Int, String, String, Double, Double)] + def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] +} + +abstract class TestAbstractClass2 extends Serializable { + val n1 = 111 + val s1 = "aaa" + protected val d1 = 1.0d +} From 5415963d2caaf95604211419ffc4e29fff38e1d7 Mon Sep 17 00:00:00 2001 From: "Susan X. Huynh" Date: Thu, 26 Oct 2017 16:13:48 -0700 Subject: [PATCH 122/712] [SPARK-22131][MESOS] Mesos driver secrets ## Background In #18837 , ArtRand added Mesos secrets support to the dispatcher. **This PR is to add the same secrets support to the drivers.** This means if the secret configs are set, the driver will launch executors that have access to either env or file-based secrets. One use case for this is to support TLS in the driver <=> executor communication. ## What changes were proposed in this pull request? Most of the changes are a refactor of the dispatcher secrets support (#18837) - moving it to a common place that can be used by both the dispatcher and drivers. The same goes for the unit tests. ## How was this patch tested? There are four config combinations: [env or file-based] x [value or reference secret]. For each combination: - Added a unit test. - Tested in DC/OS. Author: Susan X. Huynh Closes #19437 from susanxhuynh/sh-mesos-driver-secret. --- docs/running-on-mesos.md | 111 ++++++++++--- .../apache/spark/deploy/mesos/config.scala | 64 ++++---- .../cluster/mesos/MesosClusterScheduler.scala | 138 +++------------- .../MesosCoarseGrainedSchedulerBackend.scala | 31 +++- .../MesosFineGrainedSchedulerBackend.scala | 4 +- .../mesos/MesosSchedulerBackendUtil.scala | 92 ++++++++++- .../mesos/MesosClusterSchedulerSuite.scala | 150 +++--------------- ...osCoarseGrainedSchedulerBackendSuite.scala | 34 +++- .../MesosSchedulerBackendUtilSuite.scala | 7 +- .../spark/scheduler/cluster/mesos/Utils.scala | 107 +++++++++++++ 10 files changed, 434 insertions(+), 304 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index e0944bc9f5f86..b7e3e6473c338 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -485,39 +485,106 @@ See the [configuration page](configuration.html) for information on Spark config - spark.mesos.driver.secret.envkeys - (none) - A comma-separated list that, if set, the contents of the secret referenced - by spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be - set to the provided environment variable in the driver's process. + spark.mesos.driver.secret.values, + spark.mesos.driver.secret.names, + spark.mesos.executor.secret.values, + spark.mesos.executor.secret.names, - - -spark.mesos.driver.secret.filenames (none) - A comma-separated list that, if set, the contents of the secret referenced by - spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be - written to the provided file. Paths are relative to the container's work - directory. Absolute paths must already exist. Consult the Mesos Secret - protobuf for more information. +

+    A secret is specified by its contents and destination. These properties
+    specify a secret's contents. To specify a secret's destination, see the cell below.
+
+    You can specify a secret's contents either (1) by value or (2) by reference.
+
+    (1) To specify a secret by value, set the
+    spark.mesos.[driver|executor].secret.values
+    property, to make the secret available in the driver or executors.
+    For example, to make a secret password "guessme" available to the driver process, set:
+
+    spark.mesos.driver.secret.values=guessme
+
+    (2) To specify a secret that has been placed in a secret store
+    by reference, specify its name within the secret store
+    by setting the spark.mesos.[driver|executor].secret.names
+    property. For example, to make a secret password named "password" in a secret store
+    available to the driver process, set:
+
+    spark.mesos.driver.secret.names=password
+
+    Note: To use a secret store, make sure one has been integrated with Mesos via a custom
+    SecretResolver module.
+
+    To specify multiple secrets, provide a comma-separated list:
+
+    spark.mesos.driver.secret.values=guessme,passwd123
+
+    or
+
+    spark.mesos.driver.secret.names=password1,password2
+
-  spark.mesos.driver.secret.names
-  (none)
-    A comma-separated list of secret references. Consult the Mesos Secret
-    protobuf for more information.
+  spark.mesos.driver.secret.envkeys,
+  spark.mesos.driver.secret.filenames,
+  spark.mesos.executor.secret.envkeys,
+  spark.mesos.executor.secret.filenames,
-
-
-  spark.mesos.driver.secret.values
-  (none)
-    A comma-separated list of secret values. Consult the Mesos Secret
-    protobuf for more information.
+
+    A secret is specified by its contents and destination. These properties
+    specify a secret's destination. To specify a secret's contents, see the cell above.
+
+    You can specify a secret's destination in the driver or
+    executors as either (1) an environment variable or (2) as a file.
+
+    (1) To make an environment-based secret, set the
+    spark.mesos.[driver|executor].secret.envkeys property.
+    The secret will appear as an environment variable with the
+    given name in the driver or executors. For example, to make a secret password available
+    to the driver process as $PASSWORD, set:
+
+    spark.mesos.driver.secret.envkeys=PASSWORD
+
+    (2) To make a file-based secret, set the
+    spark.mesos.[driver|executor].secret.filenames property.
+    The secret will appear in the contents of a file with the given file name in
+    the driver or executors. For example, to make a secret password available in a
+    file named "pwdfile" in the driver process, set:
+
+    spark.mesos.driver.secret.filenames=pwdfile
+
+    Paths are relative to the container's work directory. Absolute paths must
+    already exist. Note: File-based secrets require a custom
+    SecretResolver module.
+
+    To specify env vars or file names corresponding to multiple secrets,
+    provide a comma-separated list:
+
+    spark.mesos.driver.secret.envkeys=PASSWORD1,PASSWORD2
+
+    or
+
+    spark.mesos.driver.secret.filenames=pwdfile1,pwdfile2
+

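As a combined usage sketch (the dispatcher URL and secret values below are placeholders, not part of this patch), the driver- and executor-side properties documented above can be set together on a `SparkConf`, or equivalently passed as `--conf` options to `spark-submit`:

```scala
import org.apache.spark.SparkConf

// Placeholder values for illustration only.
val conf = new SparkConf()
  .setMaster("mesos://dispatcher-host:7077")              // hypothetical dispatcher URL
  // Driver: file-based reference secret, written to "pwdfile" in the driver's sandbox.
  .set("spark.mesos.driver.secret.names", "/path/to/secret")
  .set("spark.mesos.driver.secret.filenames", "pwdfile")
  // Executors: env-based value secret, exposed to each executor as $PASSWORD.
  .set("spark.mesos.executor.secret.values", "guessme")
  .set("spark.mesos.executor.secret.envkeys", "PASSWORD")
```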
    diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 7e85de91c5d36..821534eb4fc38 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -23,6 +23,39 @@ import org.apache.spark.internal.config.ConfigBuilder package object config { + private[spark] class MesosSecretConfig private[config](taskType: String) { + private[spark] val SECRET_NAMES = + ConfigBuilder(s"spark.mesos.$taskType.secret.names") + .doc("A comma-separated list of secret reference names. Consult the Mesos Secret " + + "protobuf for more information.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_VALUES = + ConfigBuilder(s"spark.mesos.$taskType.secret.values") + .doc("A comma-separated list of secret values.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_ENVKEYS = + ConfigBuilder(s"spark.mesos.$taskType.secret.envkeys") + .doc("A comma-separated list of the environment variables to contain the secrets." + + "The environment variable will be set on the driver.") + .stringConf + .toSequence + .createOptional + + private[spark] val SECRET_FILENAMES = + ConfigBuilder(s"spark.mesos.$taskType.secret.filenames") + .doc("A comma-separated list of file paths secret will be written to. Consult the Mesos " + + "Secret protobuf for more information.") + .stringConf + .toSequence + .createOptional + } + /* Common app configuration. */ private[spark] val SHUFFLE_CLEANER_INTERVAL_S = @@ -64,36 +97,9 @@ package object config { .stringConf .createOptional - private[spark] val SECRET_NAME = - ConfigBuilder("spark.mesos.driver.secret.names") - .doc("A comma-separated list of secret reference names. Consult the Mesos Secret protobuf " + - "for more information.") - .stringConf - .toSequence - .createOptional - - private[spark] val SECRET_VALUE = - ConfigBuilder("spark.mesos.driver.secret.values") - .doc("A comma-separated list of secret values.") - .stringConf - .toSequence - .createOptional + private[spark] val driverSecretConfig = new MesosSecretConfig("driver") - private[spark] val SECRET_ENVKEY = - ConfigBuilder("spark.mesos.driver.secret.envkeys") - .doc("A comma-separated list of the environment variables to contain the secrets." + - "The environment variable will be set on the driver.") - .stringConf - .toSequence - .createOptional - - private[spark] val SECRET_FILENAME = - ConfigBuilder("spark.mesos.driver.secret.filenames") - .doc("A comma-seperated list of file paths secret will be written to. 
Consult the Mesos " + - "Secret protobuf for more information.") - .stringConf - .toSequence - .createOptional + private[spark] val executorSecretConfig = new MesosSecretConfig("executor") private[spark] val DRIVER_FAILOVER_TIMEOUT = ConfigBuilder("spark.mesos.driver.failoverTimeout") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index ec533f91474f2..82470264f2a4a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -28,7 +28,6 @@ import org.apache.mesos.{Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason -import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.MesosDriverDescription @@ -394,39 +393,20 @@ private[spark] class MesosClusterScheduler( } // add secret environment variables - getSecretEnvVar(desc).foreach { variable => - if (variable.getSecret.getReference.isInitialized) { - logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName}" + - s"on file ${variable.getName}") - } else { - logInfo(s"Setting secret on environment variable name=${variable.getName}") - } - envBuilder.addVariables(variable) + MesosSchedulerBackendUtil.getSecretEnvVar(desc.conf, config.driverSecretConfig) + .foreach { variable => + if (variable.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName} " + + s"on file ${variable.getName}") + } else { + logInfo(s"Setting secret on environment variable name=${variable.getName}") + } + envBuilder.addVariables(variable) } envBuilder.build() } - private def getSecretEnvVar(desc: MesosDriverDescription): List[Variable] = { - val secrets = getSecrets(desc) - val secretEnvKeys = desc.conf.get(config.SECRET_ENVKEY).getOrElse(Nil) - if (illegalSecretInput(secretEnvKeys, secrets)) { - throw new SparkException( - s"Need to give equal numbers of secrets and environment keys " + - s"for environment-based reference secrets got secrets $secrets, " + - s"and keys $secretEnvKeys") - } - - secrets.zip(secretEnvKeys).map { - case (s, k) => - Variable.newBuilder() - .setName(k) - .setType(Variable.Type.SECRET) - .setSecret(s) - .build - }.toList - } - private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -440,6 +420,23 @@ private[spark] class MesosClusterScheduler( CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) } + private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo(desc.conf) + + MesosSchedulerBackendUtil.getSecretVolume(desc.conf, config.driverSecretConfig) + .foreach { volume => + if (volume.getSource.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName} " + + s"on file ${volume.getContainerPath}") + } else { + logInfo(s"Setting secret on file 
name=${volume.getContainerPath}") + } + containerInfo.addVolumes(volume) + } + + containerInfo + } + private def getDriverCommandValue(desc: MesosDriverDescription): String = { val dockerDefined = desc.conf.contains("spark.mesos.executor.docker.image") val executorUri = getDriverExecutorURI(desc) @@ -579,89 +576,6 @@ private[spark] class MesosClusterScheduler( .build } - private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { - val containerInfo = MesosSchedulerBackendUtil.containerInfo(desc.conf) - - getSecretVolume(desc).foreach { volume => - if (volume.getSource.getSecret.getReference.isInitialized) { - logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName}" + - s"on file ${volume.getContainerPath}") - } else { - logInfo(s"Setting secret on file name=${volume.getContainerPath}") - } - containerInfo.addVolumes(volume) - } - - containerInfo - } - - - private def getSecrets(desc: MesosDriverDescription): Seq[Secret] = { - def createValueSecret(data: String): Secret = { - Secret.newBuilder() - .setType(Secret.Type.VALUE) - .setValue(Secret.Value.newBuilder().setData(ByteString.copyFrom(data.getBytes))) - .build() - } - - def createReferenceSecret(name: String): Secret = { - Secret.newBuilder() - .setReference(Secret.Reference.newBuilder().setName(name)) - .setType(Secret.Type.REFERENCE) - .build() - } - - val referenceSecrets: Seq[Secret] = - desc.conf.get(config.SECRET_NAME).getOrElse(Nil).map(s => createReferenceSecret(s)) - - val valueSecrets: Seq[Secret] = { - desc.conf.get(config.SECRET_VALUE).getOrElse(Nil).map(s => createValueSecret(s)) - } - - if (valueSecrets.nonEmpty && referenceSecrets.nonEmpty) { - throw new SparkException("Cannot specify VALUE type secrets and REFERENCE types ones") - } - - if (referenceSecrets.nonEmpty) referenceSecrets else valueSecrets - } - - private def illegalSecretInput(dest: Seq[String], s: Seq[Secret]): Boolean = { - if (dest.isEmpty) { // no destination set (ie not using secrets of this type - return false - } - if (dest.nonEmpty && s.nonEmpty) { - // make sure there is a destination for each secret of this type - if (dest.length != s.length) { - return true - } - } - false - } - - private def getSecretVolume(desc: MesosDriverDescription): List[Volume] = { - val secrets = getSecrets(desc) - val secretPaths: Seq[String] = - desc.conf.get(config.SECRET_FILENAME).getOrElse(Nil) - - if (illegalSecretInput(secretPaths, secrets)) { - throw new SparkException( - s"Need to give equal numbers of secrets and file paths for file-based " + - s"reference secrets got secrets $secrets, and paths $secretPaths") - } - - secrets.zip(secretPaths).map { - case (s, p) => - val source = Volume.Source.newBuilder() - .setType(Volume.Source.Type.SECRET) - .setSecret(s) - Volume.newBuilder() - .setContainerPath(p) - .setSource(source) - .setMode(Volume.Mode.RO) - .build - }.toList - } - /** * This method takes all the possible candidates and attempt to schedule them with Mesos offers. 
* Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 603c980cb268d..104ed01d293ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -28,7 +28,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future -import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config @@ -244,6 +244,17 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setValue(value) .build()) } + + MesosSchedulerBackendUtil.getSecretEnvVar(conf, executorSecretConfig).foreach { variable => + if (variable.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName} " + + s"on file ${variable.getName}") + } else { + logInfo(s"Setting secret on environment variable name=${variable.getName}") + } + environment.addVariables(variable) + } + val command = CommandInfo.newBuilder() .setEnvironment(environment) @@ -424,6 +435,22 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } } + private def getContainerInfo(conf: SparkConf): ContainerInfo.Builder = { + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo(conf) + + MesosSchedulerBackendUtil.getSecretVolume(conf, executorSecretConfig).foreach { volume => + if (volume.getSource.getSecret.getReference.isInitialized) { + logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName} " + + s"on file ${volume.getContainerPath}") + } else { + logInfo(s"Setting secret on file name=${volume.getContainerPath}") + } + containerInfo.addVolumes(volume) + } + + containerInfo + } + /** * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize * per-task memory and IO, tasks are round-robin assigned to offers. 
@@ -475,7 +502,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setName(s"${sc.appName} $taskId") .setLabels(MesosProtoUtils.mesosLabels(taskLabels)) .addAllResources(resourcesToUse.asJava) - .setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + .setContainer(getContainerInfo(sc.conf)) tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 66b8e0a640121..d6d939d246109 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -28,6 +28,7 @@ import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} +import org.apache.spark.deploy.mesos.config import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -159,7 +160,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - executorInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + executorInfo.setContainer( + MesosSchedulerBackendUtil.buildContainerInfo(sc.conf)) (executorInfo.build(), resourcesAfterMem.asJava) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index f29e541addf23..bfb73611f0530 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,11 +17,15 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Volume} +import org.apache.mesos.Protos.{ContainerInfo, Environment, Image, NetworkInfo, Parameter, Secret, Volume} import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.protobuf.ByteString -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkConf +import org.apache.spark.SparkException import org.apache.spark.deploy.mesos.config.{NETWORK_LABELS, NETWORK_NAME} +import org.apache.spark.deploy.mesos.config.MesosSecretConfig import org.apache.spark.internal.Logging /** @@ -122,7 +126,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .toList } - def containerInfo(conf: SparkConf): ContainerInfo.Builder = { + def buildContainerInfo(conf: SparkConf): ContainerInfo.Builder = { val containerType = if (conf.contains("spark.mesos.executor.docker.image") && conf.get("spark.mesos.containerizer", "docker") == "docker") { ContainerInfo.Type.DOCKER @@ -173,6 +177,88 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { containerInfo } + private def getSecrets(conf: SparkConf, secretConfig: MesosSecretConfig): Seq[Secret] = { + def createValueSecret(data: 
String): Secret = { + Secret.newBuilder() + .setType(Secret.Type.VALUE) + .setValue(Secret.Value.newBuilder().setData(ByteString.copyFrom(data.getBytes))) + .build() + } + + def createReferenceSecret(name: String): Secret = { + Secret.newBuilder() + .setReference(Secret.Reference.newBuilder().setName(name)) + .setType(Secret.Type.REFERENCE) + .build() + } + + val referenceSecrets: Seq[Secret] = + conf.get(secretConfig.SECRET_NAMES).getOrElse(Nil).map { s => createReferenceSecret(s) } + + val valueSecrets: Seq[Secret] = { + conf.get(secretConfig.SECRET_VALUES).getOrElse(Nil).map { s => createValueSecret(s) } + } + + if (valueSecrets.nonEmpty && referenceSecrets.nonEmpty) { + throw new SparkException("Cannot specify both value-type and reference-type secrets.") + } + + if (referenceSecrets.nonEmpty) referenceSecrets else valueSecrets + } + + private def illegalSecretInput(dest: Seq[String], secrets: Seq[Secret]): Boolean = { + if (dest.nonEmpty) { + // make sure there is a one-to-one correspondence between destinations and secrets + if (dest.length != secrets.length) { + return true + } + } + false + } + + def getSecretVolume(conf: SparkConf, secretConfig: MesosSecretConfig): List[Volume] = { + val secrets = getSecrets(conf, secretConfig) + val secretPaths: Seq[String] = + conf.get(secretConfig.SECRET_FILENAMES).getOrElse(Nil) + + if (illegalSecretInput(secretPaths, secrets)) { + throw new SparkException( + s"Need to give equal numbers of secrets and file paths for file-based " + + s"reference secrets got secrets $secrets, and paths $secretPaths") + } + + secrets.zip(secretPaths).map { case (s, p) => + val source = Volume.Source.newBuilder() + .setType(Volume.Source.Type.SECRET) + .setSecret(s) + Volume.newBuilder() + .setContainerPath(p) + .setSource(source) + .setMode(Volume.Mode.RO) + .build + }.toList + } + + def getSecretEnvVar(conf: SparkConf, secretConfig: MesosSecretConfig): + List[Variable] = { + val secrets = getSecrets(conf, secretConfig) + val secretEnvKeys = conf.get(secretConfig.SECRET_ENVKEYS).getOrElse(Nil) + if (illegalSecretInput(secretEnvKeys, secrets)) { + throw new SparkException( + s"Need to give equal numbers of secrets and environment keys " + + s"for environment-based reference secrets got secrets $secrets, " + + s"and keys $secretEnvKeys") + } + + secrets.zip(secretEnvKeys).map { case (s, k) => + Variable.newBuilder() + .setName(k) + .setType(Variable.Type.SECRET) + .setSecret(s) + .build + }.toList + } + private def dockerInfo( image: String, forcePullImage: Boolean, diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index ff63e3f4ccfc3..77acee608f25f 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import org.apache.mesos.Protos.{Environment, Secret, TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Value.{Scalar, Type} import org.apache.mesos.SchedulerDriver -import org.apache.mesos.protobuf.ByteString import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar @@ -32,6 +31,7 @@ import org.scalatest.mockito.MockitoSugar import 
org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.mesos.config class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { @@ -341,132 +341,33 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi } test("Creates an env-based reference secrets.") { - setScheduler() - - val mem = 1000 - val cpu = 1 - val secretName = "/path/to/secret,/anothersecret" - val envKey = "SECRET_ENV_KEY,PASSWORD" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.names" -> secretName, - "spark.mesos.driver.secret.envkeys" -> envKey), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - assert(launchedTasks.head - .getCommand - .getEnvironment - .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the secret - val variableOne = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "SECRET_ENV_KEY").head - assert(variableOne.getSecret.isInitialized) - assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) - assert(variableOne.getSecret.getReference.getName == "/path/to/secret") - assert(variableOne.getType == Environment.Variable.Type.SECRET) - val variableTwo = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "PASSWORD").head - assert(variableTwo.getSecret.isInitialized) - assert(variableTwo.getSecret.getType == Secret.Type.REFERENCE) - assert(variableTwo.getSecret.getReference.getName == "/anothersecret") - assert(variableTwo.getType == Environment.Variable.Type.SECRET) + val launchedTasks = launchDriverTask( + Utils.configEnvBasedRefSecrets(config.driverSecretConfig)) + Utils.verifyEnvBasedRefSecrets(launchedTasks) } test("Creates an env-based value secrets.") { - setScheduler() - val mem = 1000 - val cpu = 1 - val secretValues = "user,password" - val envKeys = "USER,PASSWORD" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.values" -> secretValues, - "spark.mesos.driver.secret.envkeys" -> envKeys), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - assert(launchedTasks.head - .getCommand - .getEnvironment - .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the secret - val variableOne = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == "USER").head - assert(variableOne.getSecret.isInitialized) - assert(variableOne.getSecret.getType == Secret.Type.VALUE) - assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes)) - assert(variableOne.getType == Environment.Variable.Type.SECRET) - val variableTwo = launchedTasks.head.getCommand.getEnvironment - .getVariablesList.asScala.filter(_.getName == 
"PASSWORD").head - assert(variableTwo.getSecret.isInitialized) - assert(variableTwo.getSecret.getType == Secret.Type.VALUE) - assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) - assert(variableTwo.getType == Environment.Variable.Type.SECRET) + val launchedTasks = launchDriverTask( + Utils.configEnvBasedValueSecrets(config.driverSecretConfig)) + Utils.verifyEnvBasedValueSecrets(launchedTasks) } test("Creates file-based reference secrets.") { - setScheduler() - val mem = 1000 - val cpu = 1 - val secretName = "/path/to/secret,/anothersecret" - val secretPath = "/topsecret,/mypassword" - val driverDesc = new MesosDriverDescription( - "d1", - "jar", - mem, - cpu, - true, - command, - Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.names" -> secretName, - "spark.mesos.driver.secret.filenames" -> secretPath), - "s1", - new Date()) - val response = scheduler.submitDriver(driverDesc) - assert(response.success) - val offer = Utils.createOffer("o1", "s1", mem, cpu) - scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - val volumes = launchedTasks.head.getContainer.getVolumesList - assert(volumes.size() == 2) - val secretVolOne = volumes.get(0) - assert(secretVolOne.getContainerPath == "/topsecret") - assert(secretVolOne.getSource.getSecret.getType == Secret.Type.REFERENCE) - assert(secretVolOne.getSource.getSecret.getReference.getName == "/path/to/secret") - val secretVolTwo = volumes.get(1) - assert(secretVolTwo.getContainerPath == "/mypassword") - assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.REFERENCE) - assert(secretVolTwo.getSource.getSecret.getReference.getName == "/anothersecret") + val launchedTasks = launchDriverTask( + Utils.configFileBasedRefSecrets(config.driverSecretConfig)) + Utils.verifyFileBasedRefSecrets(launchedTasks) } test("Creates a file-based value secrets.") { + val launchedTasks = launchDriverTask( + Utils.configFileBasedValueSecrets(config.driverSecretConfig)) + Utils.verifyFileBasedValueSecrets(launchedTasks) + } + + private def launchDriverTask(addlSparkConfVars: Map[String, String]): List[TaskInfo] = { setScheduler() val mem = 1000 val cpu = 1 - val secretValues = "user,password" - val secretPath = "/whoami,/mypassword" val driverDesc = new MesosDriverDescription( "d1", "jar", @@ -475,27 +376,14 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi true, command, Map("spark.mesos.executor.home" -> "test", - "spark.app.name" -> "test", - "spark.mesos.driver.secret.values" -> secretValues, - "spark.mesos.driver.secret.filenames" -> secretPath), + "spark.app.name" -> "test") ++ + addlSparkConfVars, "s1", new Date()) val response = scheduler.submitDriver(driverDesc) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem, cpu) scheduler.resourceOffers(driver, Collections.singletonList(offer)) - val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") - val volumes = launchedTasks.head.getContainer.getVolumesList - assert(volumes.size() == 2) - val secretVolOne = volumes.get(0) - assert(secretVolOne.getContainerPath == "/whoami") - assert(secretVolOne.getSource.getSecret.getType == Secret.Type.VALUE) - assert(secretVolOne.getSource.getSecret.getValue.getData == - ByteString.copyFrom("user".getBytes)) - val secretVolTwo = volumes.get(1) - assert(secretVolTwo.getContainerPath == "/mypassword") - assert(secretVolTwo.getSource.getSecret.getType == 
Secret.Type.VALUE) - assert(secretVolTwo.getSource.getSecret.getValue.getData == - ByteString.copyFrom("password".getBytes)) + Utils.verifyTaskLaunched(driver, "o1") } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 6c40792112f49..f4bd1ee9da6f7 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.concurrent.duration._ -import scala.reflect.ClassTag import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ @@ -38,7 +37,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor} import org.apache.spark.scheduler.cluster.mesos.Utils._ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite @@ -653,6 +652,37 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite offerResourcesAndVerify(2, true) } + test("Creates an env-based reference secrets.") { + val launchedTasks = launchExecutorTasks(configEnvBasedRefSecrets(executorSecretConfig)) + verifyEnvBasedRefSecrets(launchedTasks) + } + + test("Creates an env-based value secrets.") { + val launchedTasks = launchExecutorTasks(configEnvBasedValueSecrets(executorSecretConfig)) + verifyEnvBasedValueSecrets(launchedTasks) + } + + test("Creates file-based reference secrets.") { + val launchedTasks = launchExecutorTasks(configFileBasedRefSecrets(executorSecretConfig)) + verifyFileBasedRefSecrets(launchedTasks) + } + + test("Creates a file-based value secrets.") { + val launchedTasks = launchExecutorTasks(configFileBasedValueSecrets(executorSecretConfig)) + verifyFileBasedValueSecrets(launchedTasks) + } + + private def launchExecutorTasks(sparkConfVars: Map[String, String]): List[TaskInfo] = { + setBackend(sparkConfVars) + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + + verifyTaskLaunched(driver, "o1") + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala index f49d7c29eda49..442c43960ec1f 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster.mesos import 
org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.mesos.config class MesosSchedulerBackendUtilSuite extends SparkFunSuite { @@ -26,7 +27,8 @@ class MesosSchedulerBackendUtilSuite extends SparkFunSuite { conf.set("spark.mesos.executor.docker.parameters", "a,b") conf.set("spark.mesos.executor.docker.image", "test") - val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo( + conf) val params = containerInfo.getDocker.getParametersList assert(params.size() == 0) @@ -37,7 +39,8 @@ class MesosSchedulerBackendUtilSuite extends SparkFunSuite { conf.set("spark.mesos.executor.docker.parameters", "a=1,b=2,c=3") conf.set("spark.mesos.executor.docker.image", "test") - val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val containerInfo = MesosSchedulerBackendUtil.buildContainerInfo( + conf) val params = containerInfo.getDocker.getParametersList assert(params.size() == 3) assert(params.get(0).getKey == "a") diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 833db0c1ff334..5636ac52bd4a7 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -24,9 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.{Range => MesosRange, Ranges, Scalar} import org.apache.mesos.SchedulerDriver +import org.apache.mesos.protobuf.ByteString import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ +import org.apache.spark.deploy.mesos.config.MesosSecretConfig + object Utils { val TEST_FRAMEWORK_ID = FrameworkID.newBuilder() @@ -105,4 +108,108 @@ object Utils { def createTaskId(taskId: String): TaskID = { TaskID.newBuilder().setValue(taskId).build() } + + def configEnvBasedRefSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretName = "/path/to/secret,/anothersecret" + val envKey = "SECRET_ENV_KEY,PASSWORD" + Map( + secretConfig.SECRET_NAMES.key -> secretName, + secretConfig.SECRET_ENVKEYS.key -> envKey + ) + } + + def verifyEnvBasedRefSecrets(launchedTasks: List[TaskInfo]): Unit = { + val envVars = launchedTasks.head + .getCommand + .getEnvironment + .getVariablesList + .asScala + assert(envVars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + val variableOne = envVars.filter(_.getName == "SECRET_ENV_KEY").head + assert(variableOne.getSecret.isInitialized) + assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) + assert(variableOne.getSecret.getReference.getName == "/path/to/secret") + assert(variableOne.getType == Environment.Variable.Type.SECRET) + val variableTwo = envVars.filter(_.getName == "PASSWORD").head + assert(variableTwo.getSecret.isInitialized) + assert(variableTwo.getSecret.getType == Secret.Type.REFERENCE) + assert(variableTwo.getSecret.getReference.getName == "/anothersecret") + assert(variableTwo.getType == Environment.Variable.Type.SECRET) + } + + def configEnvBasedValueSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretValues = "user,password" + val envKeys = "USER,PASSWORD" + Map( + secretConfig.SECRET_VALUES.key -> secretValues, + secretConfig.SECRET_ENVKEYS.key -> envKeys + ) + } + + def 
verifyEnvBasedValueSecrets(launchedTasks: List[TaskInfo]): Unit = { + val envVars = launchedTasks.head + .getCommand + .getEnvironment + .getVariablesList + .asScala + assert(envVars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + val variableOne = envVars.filter(_.getName == "USER").head + assert(variableOne.getSecret.isInitialized) + assert(variableOne.getSecret.getType == Secret.Type.VALUE) + assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes)) + assert(variableOne.getType == Environment.Variable.Type.SECRET) + val variableTwo = envVars.filter(_.getName == "PASSWORD").head + assert(variableTwo.getSecret.isInitialized) + assert(variableTwo.getSecret.getType == Secret.Type.VALUE) + assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) + assert(variableTwo.getType == Environment.Variable.Type.SECRET) + } + + def configFileBasedRefSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretName = "/path/to/secret,/anothersecret" + val secretPath = "/topsecret,/mypassword" + Map( + secretConfig.SECRET_NAMES.key -> secretName, + secretConfig.SECRET_FILENAMES.key -> secretPath + ) + } + + def verifyFileBasedRefSecrets(launchedTasks: List[TaskInfo]): Unit = { + val volumes = launchedTasks.head.getContainer.getVolumesList + assert(volumes.size() == 2) + val secretVolOne = volumes.get(0) + assert(secretVolOne.getContainerPath == "/topsecret") + assert(secretVolOne.getSource.getSecret.getType == Secret.Type.REFERENCE) + assert(secretVolOne.getSource.getSecret.getReference.getName == "/path/to/secret") + val secretVolTwo = volumes.get(1) + assert(secretVolTwo.getContainerPath == "/mypassword") + assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.REFERENCE) + assert(secretVolTwo.getSource.getSecret.getReference.getName == "/anothersecret") + } + + def configFileBasedValueSecrets(secretConfig: MesosSecretConfig): Map[String, String] = { + val secretValues = "user,password" + val secretPath = "/whoami,/mypassword" + Map( + secretConfig.SECRET_VALUES.key -> secretValues, + secretConfig.SECRET_FILENAMES.key -> secretPath + ) + } + + def verifyFileBasedValueSecrets(launchedTasks: List[TaskInfo]): Unit = { + val volumes = launchedTasks.head.getContainer.getVolumesList + assert(volumes.size() == 2) + val secretVolOne = volumes.get(0) + assert(secretVolOne.getContainerPath == "/whoami") + assert(secretVolOne.getSource.getSecret.getType == Secret.Type.VALUE) + assert(secretVolOne.getSource.getSecret.getValue.getData == + ByteString.copyFrom("user".getBytes)) + val secretVolTwo = volumes.get(1) + assert(secretVolTwo.getContainerPath == "/mypassword") + assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.VALUE) + assert(secretVolTwo.getSource.getSecret.getValue.getData == + ByteString.copyFrom("password".getBytes)) + } } From 8e9863531bebbd4d83eafcbc2b359b8bd0ac5734 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 26 Oct 2017 16:55:30 -0700 Subject: [PATCH 123/712] [SPARK-22366] Support ignoring missing files ## What changes were proposed in this pull request? Add a flag "spark.sql.files.ignoreMissingFiles" to parallel the existing flag "spark.sql.files.ignoreCorruptFiles". ## How was this patch tested? new unit test Author: Jose Torres Closes #19581 from joseph-torres/SPARK-22366. 
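A usage sketch of the new flag (paths are placeholders; assumes a `spark` session as in `spark-shell`), mirroring the added `ParquetQuerySuite` test:

```scala
// Allow scans to skip files that were listed at planning time but deleted before execution.
spark.conf.set("spark.sql.files.ignoreMissingFiles", "true")

val df = spark.read.parquet("/data/first", "/data/second", "/data/third")
// If /data/third is removed at this point, the scan logs a warning, skips the missing
// files, and returns whatever was successfully read instead of failing the query.
df.collect()
```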
--- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../execution/datasources/FileScanRDD.scala | 13 +++++--- .../parquet/ParquetQuerySuite.scala | 33 +++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4cfe53b2c115b..21e4685fcc456 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -614,6 +614,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val IGNORE_MISSING_FILES = buildConf("spark.sql.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + val MAX_RECORDS_PER_FILE = buildConf("spark.sql.files.maxRecordsPerFile") .doc("Maximum number of records to write out to a single file. " + "If this value is zero or negative, there is no limit.") @@ -1014,6 +1020,8 @@ class SQLConf extends Serializable with Logging { def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) + def ignoreMissingFiles: Boolean = getConf(IGNORE_MISSING_FILES) + def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE) def useCompression: Boolean = getConf(COMPRESS_CACHED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 9df20731c71d5..8731ee88f87f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -66,6 +66,7 @@ class FileScanRDD( extends RDD[InternalRow](sparkSession.sparkContext, Nil) { private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + private val ignoreMissingFiles = sparkSession.sessionState.conf.ignoreMissingFiles override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { @@ -142,7 +143,7 @@ class FileScanRDD( // Sets InputFileBlockHolder for the file block's information InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - if (ignoreCorruptFiles) { + if (ignoreMissingFiles || ignoreCorruptFiles) { currentIterator = new NextIterator[Object] { // The readFunction may read some bytes before consuming the iterator, e.g., // vectorized Parquet reader. 
Here we use lazy val to delay the creation of @@ -158,9 +159,13 @@ class FileScanRDD( null } } catch { - // Throw FileNotFoundException even `ignoreCorruptFiles` is true - case e: FileNotFoundException => throw e - case e @ (_: RuntimeException | _: IOException) => + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: $currentFile", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => logWarning( s"Skipped the rest of the content in the corrupted file: $currentFile", e) finished = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2efff3f57d7d3..e822e40b146ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + testQuietly("Enabling/disabling ignoreMissingFiles") { + def testIgnoreMissingFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + spark.range(2, 3).toDF("a").write.parquet(thirdPath.toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + fs.delete(thirdPath, true) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { + testIgnoreMissingFiles() + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreMissingFiles() + } + assert(exception.getMessage().contains("does not exist")) + } + } + /** * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop * to increase the chance of failure From 9b262f6a08c0c1b474d920d49b9fdd574c401d39 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:39:53 -0700 Subject: [PATCH 124/712] [SPARK-22356][SQL] data source table should support overlapped columns between data and partition schema ## What changes were proposed in this pull request? This is a regression introduced by #14207. After Spark 2.1, we store the inferred schema when creating the table, to avoid inferring schema again at read path. However, there is one special case: overlapped columns between data and partition. For this case, it breaks the assumption of table schema that there is on ovelap between data and partition schema, and partition columns should be at the end. The result is, for Spark 2.1, the table scan has incorrect schema that puts partition columns at the end. For Spark 2.2, we add a check in CatalogTable to validate table schema, which fails at this case. To fix this issue, a simple and safe approach is to fallback to old behavior when overlapeed columns detected, i.e. store empty schema in metastore. ## How was this patch tested? 
new regression test Author: Wenchen Fan Closes #19579 from cloud-fan/bug2. --- .../command/createDataSourceTables.scala | 35 +++++++++++++---- .../datasources/HadoopFsRelation.scala | 25 ++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++ .../HiveExternalCatalogVersionsSuite.scala | 38 ++++++++++++++----- 4 files changed, 89 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 9e3907996995c..306f43dc4214a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType /** * A command used to create a data source table. @@ -85,14 +86,32 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo } } - val newTable = table.copy( - schema = dataSource.schema, - partitionColumnNames = partitionColumnNames, - // If metastore partition management for file source tables is enabled, we start off with - // partition provider hive, but no partitions in the metastore. The user has to call - // `msck repair table` to populate the table partitions. - tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && - sessionState.conf.manageFilesourcePartitions) + val newTable = dataSource match { + // Since Spark 2.1, we store the inferred schema of data source in metastore, to avoid + // inferring the schema again at read path. However if the data source has overlapped columns + // between data and partition schema, we can't store it in metastore as it breaks the + // assumption of table schema. Here we fallback to the behavior of Spark prior to 2.1, store + // empty schema in metastore and infer it at runtime. Note that this also means the new + // scalable partitioning handling feature(introduced at Spark 2.1) is disabled in this case. + case r: HadoopFsRelation if r.overlappedPartCols.nonEmpty => + logWarning("It is not recommended to create a table with overlapped data and partition " + + "columns, as Spark cannot store a valid table schema and has to infer it at runtime, " + + "which hurts performance. Please check your data files and remove the partition " + + "columns in it.") + table.copy(schema = new StructType(), partitionColumnNames = Nil) + + case _ => + table.copy( + schema = dataSource.schema, + partitionColumnNames = partitionColumnNames, + // If metastore partition management for file source tables is enabled, we start off with + // partition provider hive, but no partitions in the metastore. The user has to call + // `msck repair table` to populate the table partitions. + tracksPartitionsInCatalog = partitionColumnNames.nonEmpty && + sessionState.conf.manageFilesourcePartitions) + + } + // We will return Nil or throw exception at the beginning if the table already exists, so when // we reach here, the table should not exist and we should set `ignoreIfExists` to false. 
sessionState.catalog.createTable(newTable, ignoreIfExists = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 9a08524476baa..89d8a85a9cbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import scala.collection.mutable import org.apache.spark.sql.{SparkSession, SQLContext} @@ -50,15 +52,22 @@ case class HadoopFsRelation( override def sqlContext: SQLContext = sparkSession.sqlContext - val schema: StructType = { - val getColName: (StructField => String) = - if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase - val overlappedPartCols = mutable.Map.empty[String, StructField] - partitionSchema.foreach { partitionField => - if (dataSchema.exists(getColName(_) == getColName(partitionField))) { - overlappedPartCols += getColName(partitionField) -> partitionField - } + private def getColName(f: StructField): String = { + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { + f.name + } else { + f.name.toLowerCase(Locale.ROOT) + } + } + + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + if (dataSchema.exists(getColName(_) == getColName(partitionField))) { + overlappedPartCols += getColName(partitionField) -> partitionField } + } + + val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index caf332d050d7b..5d0bba69daca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2741,4 +2741,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert (aggregateExpressions.isDefined) assert (aggregateExpressions.get.size == 2) } + + test("SPARK-22356: overlapped columns between data and partition schema in data source tables") { + withTempPath { path => + Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j") + .write.mode("overwrite").parquet(new File(path, "p=1").getCanonicalPath) + withTable("t") { + sql(s"create table t using parquet options(path='${path.getCanonicalPath}')") + // We should respect the column order in data schema. + assert(spark.table("t").columns === Array("i", "p", "j")) + checkAnswer(spark.table("t"), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + // The DESC TABLE should report same schema as table scan. 
+ assert(sql("desc t").select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 5f8c9d5799662..6859432c406a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -40,7 +40,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to // avoid downloading Spark of different versions in each run. - private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") + private val sparkTestingDir = new File("/tmp/test-spark") private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { @@ -77,35 +77,38 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { super.beforeAll() val tempPyFile = File.createTempFile("test", ".py") + // scalastyle:off line.size.limit Files.write(tempPyFile.toPath, s""" |from pyspark.sql import SparkSession + |import os | |spark = SparkSession.builder.enableHiveSupport().getOrCreate() |version_index = spark.conf.get("spark.sql.test.version.index", None) | |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index)) | - |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\ - | " using parquet as select 1 i") + |spark.sql("create table hive_compatible_data_source_tbl_{} using parquet as select 1 i".format(version_index)) | |json_file = "${genDataDir("json_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file) - |spark.sql("create table external_data_source_tbl_" + version_index + \\ - | "(i int) using json options (path '{}')".format(json_file)) + |spark.sql("create table external_data_source_tbl_{}(i int) using json options (path '{}')".format(version_index, json_file)) | |parquet_file = "${genDataDir("parquet_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file) - |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\ - | "(i int) using parquet options (path '{}')".format(parquet_file)) + |spark.sql("create table hive_compatible_external_data_source_tbl_{}(i int) using parquet options (path '{}')".format(version_index, parquet_file)) | |json_file2 = "${genDataDir("json2_")}" + str(version_index) |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2) - |spark.sql("create table external_table_without_schema_" + version_index + \\ - | " using json options (path '{}')".format(json_file2)) + |spark.sql("create table external_table_without_schema_{} using json options (path '{}')".format(version_index, json_file2)) + | + |parquet_file2 = "${genDataDir("parquet2_")}" + str(version_index) + |spark.range(1, 3).selectExpr("1 as i", "cast(id as int) as p", "1 as j").write.parquet(os.path.join(parquet_file2, "p=1")) + |spark.sql("create table tbl_with_col_overlap_{} using parquet options(path '{}')".format(version_index, parquet_file2)) | |spark.sql("create view v_{} as select 1 i".format(version_index)) """.stripMargin.getBytes("utf8")) + // scalastyle:on 
line.size.limit PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") @@ -153,6 +156,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .enableHiveSupport() .getOrCreate() spark = session + import session.implicits._ testingVersions.indices.foreach { index => Seq( @@ -194,6 +198,22 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // test permanent view checkAnswer(sql(s"select i from v_$index"), Row(1)) + + // SPARK-22356: overlapped columns between data and partition schema in data source tables + val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { + spark.sql("msck repair table " + tbl_with_col_overlap) + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,j,p")) + } else { + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) + } } } } From 5c3a1f3fad695317c2fff1243cdb9b3ceb25c317 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Oct 2017 17:51:16 -0700 Subject: [PATCH 125/712] [SPARK-22355][SQL] Dataset.collect is not threadsafe ## What changes were proposed in this pull request? It's possible that users create a `Dataset`, and call `collect` of this `Dataset` in many threads at the same time. Currently `Dataset#collect` just call `encoder.fromRow` to convert spark rows to objects of type T, and this encoder is per-dataset. This means `Dataset#collect` is not thread-safe, because the encoder uses a projection to output the object to a re-usable row. This PR fixes this problem, by creating a new projection when calling `Dataset#collect`, so that we have the re-usable row for each method call, instead of each Dataset. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19577 from cloud-fan/encoder. --- .../scala/org/apache/spark/sql/Dataset.scala | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b70dfc05330f8..0e23983786b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} @@ -198,15 +199,10 @@ class Dataset[T] private[sql]( */ private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) - /** - * Encoder is used mostly as a container of serde expressions in Dataset. 
We build logical - * plans by these serde expressions and execute it within the query framework. However, for - * performance reasons we may want to use encoder as a function to deserialize internal rows to - * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its - * `fromRow` method later. - */ - private val boundEnc = - exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) + // The deserializer expression which can be used to build a projection and turn rows to objects + // of type T, after collecting rows to the driver side. + private val deserializer = + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag @@ -2661,7 +2657,15 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = { withAction("toLocalIterator", queryExecution) { plan => - plan.executeToIterator().map(boundEnc.fromRow).asJava + // This projection writes output to a `InternalRow`, which means applying this projection is + // not thread-safe. Here we create the projection inside this method to make `Dataset` + // thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeToIterator().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + }.asJava } } @@ -3102,7 +3106,14 @@ class Dataset[T] private[sql]( * Collect all elements from a spark plan. */ private def collectFromPlan(plan: SparkPlan): Array[T] = { - plan.executeCollect().map(boundEnc.fromRow) + // This projection writes output to a `InternalRow`, which means applying this projection is not + // thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe. + val objProj = GenerateSafeProjection.generate(deserializer :: Nil) + plan.executeCollect().map { row => + // The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type + // parameter of its `get` method, so it's safe to use null here. + objProj(row).get(0, null).asInstanceOf[T] + } } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { From 17af727e38c3faaeab5b91a8cdab5f2181cf3fc4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 26 Oct 2017 23:02:46 -0700 Subject: [PATCH 126/712] [SPARK-21375][PYSPARK][SQL] Add Date and Timestamp support to ArrowConverters for toPandas() Conversion ## What changes were proposed in this pull request? Adding date and timestamp support with Arrow for `toPandas()` and `pandas_udf`s. Timestamps are stored in Arrow as UTC and manifested to the user as timezone-naive localized to the Python system timezone. ## How was this patch tested? Added Scala tests for date and timestamp types under ArrowConverters, ArrowUtils, and ArrowWriter suites. Added Python tests for `toPandas()` and `pandas_udf`s with date and timestamp types. Author: Bryan Cutler Author: Takuya UESHIN Closes #18664 from BryanCutler/arrow-date-timestamp-SPARK-21375. 
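
As a rough illustration of the timestamp handling described above (a standalone sketch, not code from this patch; the helper name and sample data are hypothetical), Arrow hands timestamp columns back as timezone-aware UTC values, and the conversion step makes them timezone-naive in the local timezone before returning them to the user:

```
# Minimal sketch of the localization step described above; assumes pandas is installed.
# The real logic lives in pyspark.sql.types in this patch; names here are illustrative.
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype

def localize_timestamps(pdf):
    # Convert every timezone-aware column from UTC to local time, then drop the tz info,
    # so callers see naive timestamps in their system timezone.
    for column, series in pdf.items():
        if is_datetime64tz_dtype(series.dtype):
            pdf[column] = series.dt.tz_convert("tzlocal()").dt.tz_localize(None)
    return pdf

# Example: a UTC-aware timestamp becomes a naive local-time value.
pdf = pd.DataFrame({"ts": pd.to_datetime(["2012-02-02 02:02:02"]).tz_localize("UTC")})
print(localize_timestamps(pdf)["ts"].dtype)  # datetime64[ns] (no timezone)
```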
--- python/pyspark/serializers.py | 24 +++- python/pyspark/sql/dataframe.py | 7 +- python/pyspark/sql/tests.py | 106 ++++++++++++++-- python/pyspark/sql/types.py | 36 ++++++ .../vectorized/ArrowColumnVector.java | 34 +++++ .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../sql/execution/arrow/ArrowConverters.scala | 3 +- .../sql/execution/arrow/ArrowUtils.scala | 30 +++-- .../sql/execution/arrow/ArrowWriter.scala | 39 +++++- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 5 +- .../python/FlatMapGroupsInPandasExec.scala | 4 +- .../arrow/ArrowConvertersSuite.scala | 120 ++++++++++++++++-- .../sql/execution/arrow/ArrowUtilsSuite.scala | 26 +++- .../execution/arrow/ArrowWriterSuite.scala | 24 ++-- .../vectorized/ArrowColumnVectorSuite.scala | 22 ++-- .../vectorized/ColumnarBatchSuite.scala | 4 +- 17 files changed, 417 insertions(+), 73 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a0adeed994456..d7979f095da76 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -214,6 +214,7 @@ def __repr__(self): def _create_batch(series): + from pyspark.sql.types import _check_series_convert_timestamps_internal import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ @@ -224,12 +225,25 @@ def _create_batch(series): # If a nullable integer series has been promoted to floating point with NaNs, need to cast # NOTE: this is not necessary with Arrow >= 0.7 def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): + if type(t) == pa.TimestampType: + # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 + return _check_series_convert_timestamps_internal(s.fillna(0))\ + .values.astype('datetime64[us]', copy=False) + elif t == pa.date32(): + # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8 + return s.dt.date + elif t is None or s.dtype == t.to_pandas_dtype(): return s else: return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + # Some object types don't support masks in Arrow, see ARROW-1721 + def create_array(s, t): + casted = cast_series(s, t) + mask = None if casted.dtype == 'object' else s.isnull() + return pa.Array.from_pandas(casted, mask=mask, type=t) + + arrs = [create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) @@ -260,11 +274,13 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
""" + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) for batch in reader: - table = pa.Table.from_batches([batch]) - yield [c.to_pandas() for c in table.itercolumns()] + # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 + pdf = _check_dataframe_localize_timestamps(batch.to_pandas()) + yield [c for _, c in pdf.iteritems()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c0b574e2b93a1..406686e6df724 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1883,11 +1883,13 @@ def toPandas(self): import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) - return table.to_pandas() + pdf = table.to_pandas() + return _check_dataframe_localize_timestamps(pdf) else: return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: @@ -1955,6 +1957,7 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. + NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -1965,6 +1968,8 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 + elif type(dt) == DateType: + return 'datetime64[ns]' else: return None diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 685eebcafefba..98afae662b42d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase): @classmethod def setUpClass(cls): + from datetime import datetime ReusedPySparkTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] + StructField("5_double_t", DoubleType(), True), + StructField("6_date_t", DateType(), True), + StructField("7_timestamp_t", TimestampType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -3106,8 +3126,8 @@ 
def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + schema = StructType([StructField("decimal", DecimalType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) @@ -3385,13 +3405,77 @@ def test_vectorized_udf_varargs(self): def test_vectorized_udf_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) - f = pandas_udf(lambda x: x, DateType()) + schema = StructType([StructField("dt", DecimalType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + f = pandas_udf(lambda x: x, DecimalType()) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.select(f(col('dt'))).collect() + def test_vectorized_udf_null_date(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import date + schema = StructType().add("date", DateType()) + data = [(date(1969, 1, 1),), + (date(2012, 2, 2),), + (None,), + (date(2100, 4, 4),)] + df = self.spark.createDataFrame(data, schema=schema) + date_f = pandas_udf(lambda t: t, returnType=DateType()) + res = df.select(date_f(col("date"))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_timestamps(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import datetime + schema = StructType([ + StructField("idx", LongType(), True), + StructField("timestamp", TimestampType(), True)]) + data = [(0, datetime(1969, 1, 1, 1, 1, 1)), + (1, datetime(2012, 2, 2, 2, 2, 2)), + (2, None), + (3, datetime(2100, 4, 4, 4, 4, 4))] + df = self.spark.createDataFrame(data, schema=schema) + + # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc + f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) + df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) + + @pandas_udf(returnType=BooleanType()) + def check_data(idx, timestamp, timestamp_copy): + is_equal = timestamp.isnull() # use this array to check values are equal + for i in range(len(idx)): + # Check that timestamps are as expected in the UDF + is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \ + timestamp[i].to_pydatetime() == data[idx[i]][1] + return is_equal + + result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"), + col("timestamp_copy"))).collect() + # Check that collection values are correct + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected + + def test_vectorized_udf_return_timestamp_tz(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + + @pandas_udf(returnType=TimestampType()) + def gen_timestamps(id): + ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id] + return pd.Series(ts) + + result = df.withColumn("ts", gen_timestamps(col("id"))).collect() + spark_ts_t = TimestampType() + for r in result: + i, ts = r + ts_tz = pd.Timestamp(i, unit='D', 
tz='America/Los_Angeles').to_pydatetime() + expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) + self.assertEquals(expected, ts) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3550,8 +3634,8 @@ def test_wrong_args(self): def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col schema = StructType( - [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) - df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) + df = self.spark.createDataFrame([(1, None,)], schema=schema) f = pandas_udf(lambda x: x, df.schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f65273d5f0b6c..7dd8fa04160e0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1619,11 +1619,47 @@ def to_arrow_type(dt): arrow_type = pa.decimal(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == DateType: + arrow_type = pa.date32() + elif type(dt) == TimestampType: + # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read + arrow_type = pa.timestamp('us', tz='UTC') else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type +def _check_dataframe_localize_timestamps(pdf): + """ + Convert timezone aware timestamps to timezone-naive in local time + + :param pdf: pandas.DataFrame + :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive + """ + from pandas.api.types import is_datetime64tz_dtype + for column, series in pdf.iteritems(): + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(series.dtype): + pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None) + return pdf + + +def _check_series_convert_timestamps_internal(s): + """ + Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage + :param s: a pandas.Series + :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone + """ + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? 
+ if is_datetime64_dtype(s.dtype): + return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC') + elif is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert('UTC') + else: + return s + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 1f171049820b2..51ea719f8c4a6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -320,6 +320,10 @@ public ArrowColumnVector(ValueVector vector) { accessor = new StringAccessor((NullableVarCharVector) vector); } else if (vector instanceof NullableVarBinaryVector) { accessor = new BinaryAccessor((NullableVarBinaryVector) vector); + } else if (vector instanceof NullableDateDayVector) { + accessor = new DateAccessor((NullableDateDayVector) vector); + } else if (vector instanceof NullableTimeStampMicroTZVector) { + accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); @@ -575,6 +579,36 @@ final byte[] getBinary(int rowId) { } } + private static class DateAccessor extends ArrowVectorAccessor { + + private final NullableDateDayVector.Accessor accessor; + + DateAccessor(NullableDateDayVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class TimestampAccessor extends ArrowVectorAccessor { + + private final NullableTimeStampMicroTZVector.Accessor accessor; + + TimestampAccessor(NullableTimeStampMicroTZVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + private static class ArrayAccessor extends ArrowVectorAccessor { private final UInt4Vector.Accessor accessor; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0e23983786b08..fe4e192e43dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3154,9 +3154,11 @@ class Dataset[T] private[sql]( private[sql] def toArrowPayload: RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone queryExecution.toRdd.mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context) + ArrowConverters.toPayloadIterator( + iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 561a067a2f81f..05ea1517fcac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -74,9 +74,10 @@ private[sql] object ArrowConverters { rowIter: Iterator[InternalRow], schema: StructType, 
maxRecordsPerBatch: Int, + timeZoneId: String, context: TaskContext): Iterator[ArrowPayload] = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 2caf1ef02909a..6ad11bda84bf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.types._ @@ -31,7 +31,8 @@ object ArrowUtils { // todo: support more types. - def toArrowType(dt: DataType): ArrowType = dt match { + /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) @@ -42,6 +43,13 @@ object ArrowUtils { case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType => + if (timeZoneId == null) { + throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + } else { + new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + } case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } @@ -58,22 +66,27 @@ object ArrowUtils { case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } - def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = { + /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ + def toArrowField( + name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava) + new Field(name, fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) new Field(name, fieldType, fields.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.toSeq.asJava) case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType), null) + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) new Field(name, fieldType, Seq.empty[Field].asJava) } } @@ -94,9 +107,10 @@ object ArrowUtils { } } - def toArrowSchema(schema: StructType): Schema = { + /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { new Schema(schema.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.asJava) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 0b740735ffe19..e4af4f65da127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ -import org.apache.arrow.vector.util.DecimalUtility +import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -29,8 +29,8 @@ import org.apache.spark.sql.types._ object ArrowWriter { - def create(schema: StructType): ArrowWriter = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + def create(schema: StructType, timeZoneId: String): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) create(root) } @@ -55,6 +55,8 @@ object ArrowWriter { case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) + case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) + case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) @@ -69,9 +71,7 @@ object ArrowWriter { } } -class ArrowWriter( - val root: VectorSchemaRoot, - fields: Array[ArrowFieldWriter]) { +class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { def schema: StructType = StructType(fields.map { f => StructField(f.name, f.dataType, f.nullable) @@ -255,6 +255,33 @@ private[arrow] class BinaryWriter( } } +private[arrow] 
class DateWriter(val valueVector: NullableDateDayVector) extends ArrowFieldWriter { + + override def valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class TimestampWriter( + val valueVector: NullableTimeStampMicroTZVector) extends ArrowFieldWriter { + + override def valueMutator: NullableTimeStampMicroTZVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getLong(ordinal)) + } +} + private[arrow] class ArrayWriter( val valueVector: ListVector, val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 81896187ecc46..0db463a5fbd89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -79,7 +79,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index f6c03c415dc66..94c05b9b5e49f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -43,7 +43,8 @@ class ArrowPythonRunner( reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], - schema: StructType) + schema: StructType, + timeZoneId: String) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -60,7 +61,7 @@ class ArrowPythonRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 5ed88ada428cb..cc93fda9f81da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,8 +94,8 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) - .compute(grouped, context.partitionId(), context) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, 
argOffsets, schema, conf.sessionLocalTimeZone) + .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 30422b657742c..ba2903babbba8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -32,6 +32,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -793,6 +795,103 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "binaryData.json") } + test("date type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "date", + | "type" : { + | "name" : "date", + | "unit" : "DAY" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "date", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1, 0, 16533, 382607 ] + | } ] + | } ] + |} + """.stripMargin + + val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" + val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" + val d3 = Date.valueOf("2015-04-08") + val d4 = Date.valueOf("3017-07-18") + + val df = Seq(d1, d2, d3, d4).toDF("date") + + collectAndValidate(df, json, "dateData.json") + } + + test("timestamp type conversion") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "timestamp", + | "type" : { + | "name" : "timestamp", + | "unit" : "MICROSECOND", + | "timezone" : "America/Los_Angeles" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "timestamp", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1234, 0, 1365383415567000, 33057298500000000 ] + | } ] + | } ] + |} + """.stripMargin + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = DateTimeUtils.toJavaTimestamp(-1234L) + val ts2 = DateTimeUtils.toJavaTimestamp(0L) + val ts3 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts4 = new Timestamp(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) + val data = Seq(ts1, ts2, ts3, ts4) + + val df = data.toDF("timestamp") + + collectAndValidate(df, json, "timestampData.json", "America/Los_Angeles") + } + } + test("floating-point NaN") { val json = s""" @@ -1486,15 +1585,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { runUnsupported { 
decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } - - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } } test("test Arrow Validator") { @@ -1638,7 +1728,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx) + val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) assert(schema.equals(outputRowIter.schema)) @@ -1657,22 +1747,24 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + private def collectAndValidate( + df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val arrowPayload = df.coalesce(1).toArrowPayload.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile) + validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, arrowPayload: ArrowPayload, - jsonFile: File): Unit = { + jsonFile: File, + timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) - val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index 638619fd39d06..d801f62b62323 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.arrow +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class ArrowUtilsSuite extends SparkFunSuite { @@ -25,7 +28,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtrip(dt: DataType): Unit = { dt match { case schema: StructType => - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema)) === schema) + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null)) === schema) case _ => roundtrip(new 
StructType().add("value", dt)) } @@ -42,6 +45,27 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(StringType) roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) + roundtrip(DateType) + val tsExMsg = intercept[UnsupportedOperationException] { + roundtrip(TimestampType) + } + assert(tsExMsg.getMessage.contains("timeZoneId")) + } + + test("timestamp") { + + def roundtripWithTz(timeZoneId: String): Unit = { + val schema = new StructType().add("value", TimestampType) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] + assert(fieldType.getTimezone() === timeZoneId) + assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) + } + + roundtripWithTz(DateTimeUtils.defaultTimeZone().getID) + roundtripWithTz("Asia/Tokyo") + roundtripWithTz("UTC") + roundtripWithTz("America/Los_Angeles") } test("array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index e9a629315f5f4..a71e30aa3ca96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { test("simple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = true) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -51,6 +51,8 @@ class ArrowWriterSuite extends SparkFunSuite { case DoubleType => reader.getDouble(rowId) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) + case DateType => reader.getInt(rowId) + case TimestampType => reader.getLong(rowId) } assert(value === datum) } @@ -66,12 +68,14 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + check(DateType, Seq(0, 1, 2, null, 4)) + check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") } test("get multiple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = false) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -88,6 +92,8 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLongs(0, data.size) case FloatType => reader.getFloats(0, data.size) case DoubleType => reader.getDoubles(0, data.size) + case DateType => reader.getInts(0, data.size) + case TimestampType => reader.getLongs(0, data.size) } assert(values === data) @@ -100,12 +106,14 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, (0 until 10).map(_.toLong)) check(FloatType, (0 until 10).map(_.toFloat)) check(DoubleType, (0 until 10).map(_.toDouble)) + check(DateType, (0 until 10)) + check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), 
"America/Los_Angeles") } test("array") { val schema = new StructType() .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) @@ -144,7 +152,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested array") { val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array( @@ -195,7 +203,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("struct") { val schema = new StructType() .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) @@ -231,7 +239,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested struct") { val schema = new StructType().add("struct", new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index d24a9e1f4bd16..068a17bf772e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -29,7 +29,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("boolean") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true) + val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBitVector] vector.allocateNew() val mutator = vector.getMutator() @@ -58,7 +58,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("byte") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true) + val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableTinyIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -87,7 +87,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("short") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true) + val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableSmallIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -116,7 +116,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("int") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true) + val 
vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -145,7 +145,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("long") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("long", LongType, nullable = true) + val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBigIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -174,7 +174,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("float") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true) + val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat4Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -203,7 +203,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("double") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true) + val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat8Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -232,7 +232,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("string") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("string", StringType, nullable = true) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarCharVector] vector.allocateNew() val mutator = vector.getMutator() @@ -260,7 +260,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarBinaryVector] vector.allocateNew() val mutator = vector.getMutator() @@ -288,7 +288,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("array") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true) + val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) .createVector(allocator).asInstanceOf[ListVector] vector.allocateNew() val mutator = vector.getMutator() @@ -345,7 +345,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) - val vector = ArrowUtils.toArrowField("struct", schema, nullable = true) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) .createVector(allocator).asInstanceOf[NullableMapVector] vector.allocateNew() val mutator = vector.getMutator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0b179aa97c479..4cfc776e51db1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1249,11 +1249,11 @@ class ColumnarBatchSuite extends SparkFunSuite { test("create columnar batch from Arrow column vectors") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true) + val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector1.allocateNew() val mutator1 = vector1.getMutator() - val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true) + val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector2.allocateNew() val mutator2 = vector2.getMutator() From 36b826f5d17ae7be89135cb2c43ff797f9e7fe48 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 27 Oct 2017 07:52:10 -0700 Subject: [PATCH 127/712] [TRIVIAL][SQL] Code cleaning in ResolveReferences ## What changes were proposed in this pull request? This PR is to clean the related codes majorly based on the today's code review on https://github.com/apache/spark/pull/19559 ## How was this patch tested? N/A Author: gatorsmile Closes #19585 from gatorsmile/trivialFixes. --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++-------- .../scala/org/apache/spark/sql/Column.scala | 10 ++++----- .../spark/sql/RelationalGroupedDataset.scala | 4 ++-- .../sql/execution/WholeStageCodegenExec.scala | 5 ++--- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d6a962a14dc9c..6384a141e83b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -783,6 +783,17 @@ class Analyzer( } } + private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case u @ UnresolvedAttribute(nameParts) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + logDebug(s"Resolving $u to $result") + result + case UnresolvedExtractValue(child, fieldExpr) if child.resolved => + ExtractValue(child, fieldExpr, resolver) + case _ => e.mapChildren(resolve(_, q)) + } + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -841,15 +852,7 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q.transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. 
- val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } - logDebug(s"Resolving $u to $result") - result - case UnresolvedExtractValue(child, fieldExpr) if child.resolved => - ExtractValue(child, fieldExpr, resolver) - } + q.mapExpressions(resolve(_, q)) } def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8468a8a96349a..92988680871a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -44,7 +44,7 @@ private[sql] object Column { e match { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => a.aggregateFunction.toString - case expr => usePrettyExpression(expr).sql + case expr => toPrettySQL(expr) } } } @@ -137,7 +137,7 @@ class Column(val expr: Expression) extends Logging { case _ => UnresolvedAttribute.quotedString(name) }) - override def toString: String = usePrettyExpression(expr).sql + override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) @@ -175,7 +175,7 @@ class Column(val expr: Expression) extends Logging { case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c) } match { case ne: NamedExpression => ne - case other => Alias(expr, usePrettyExpression(expr).sql)() + case _ => Alias(expr, toPrettySQL(expr))() } case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => @@ -184,7 +184,7 @@ class Column(val expr: Expression) extends Logging { // Wait until the struct is resolved. This will generate a nicer looking alias. 
case struct: CreateNamedStructLike => UnresolvedAlias(struct) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6b45790d5ff6e..21e94fa8bb0b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf @@ -85,7 +85,7 @@ class RelationalGroupedDataset protected[sql]( case expr: NamedExpression => expr case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index e37d133ff336a..286cb3bb0767c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -521,10 +521,9 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case p if !supportCodegen(p) => // collapse them recursively InputAdapter(insertWholeStageCodegen(p)) - case j @ SortMergeJoinExec(_, _, _, _, left, right) => + case j: SortMergeJoinExec => // The children of SortMergeJoin should do codegen separately. - j.copy(left = InputAdapter(insertWholeStageCodegen(left)), - right = InputAdapter(insertWholeStageCodegen(right))) + j.withNewChildren(j.children.map(child => InputAdapter(insertWholeStageCodegen(child)))) case p => p.withNewChildren(p.children.map(insertInputAdapter)) } From b3d8fc3dc458d42cf11d961762ce99f551f68548 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Oct 2017 13:43:09 -0700 Subject: [PATCH 128/712] [SPARK-22226][SQL] splitExpression can create too many method calls in the outer class ## What changes were proposed in this pull request? SPARK-18016 introduced `NestedClass` to avoid that the many methods generated by `splitExpressions` contribute to the outer class' constant pool, making it growing too much. Unfortunately, despite their definition is stored in the `NestedClass`, they all are invoked in the outer class and for each method invocation, there are two entries added to the constant pool: a `Methodref` and a `Utf8` entry (you can easily check this compiling a simple sample class with `janinoc` and looking at its Constant Pool). This limits the scalability of the solution with very large methods which are split in a lot of small ones. 
This means that currently we are generating classes like this one: ``` class SpecificUnsafeProjection extends org.apache.spark.sql.catalyst.expressions.UnsafeProjection { ... public UnsafeRow apply(InternalRow i) { rowWriter.zeroOutNullBytes(); apply_0(i); apply_1(i); ... nestedClassInstance.apply_862(i); nestedClassInstance.apply_863(i); ... nestedClassInstance1.apply_1612(i); nestedClassInstance1.apply_1613(i); ... } ... private class NestedClass { private void apply_862(InternalRow i) { ... } private void apply_863(InternalRow i) { ... } ... } private class NestedClass1 { private void apply_1612(InternalRow i) { ... } private void apply_1613(InternalRow i) { ... } ... } } ``` This PR reduce the Constant Pool size of the outer class by adding a new method to each nested class: in this method we invoke all the small methods generated by `splitExpression` in that nested class. In this way, in the outer class there is only one method invocation per nested class, reducing by orders of magnitude the entries in its constant pool because of method invocations. This means that after the patch the generated code becomes: ``` class SpecificUnsafeProjection extends org.apache.spark.sql.catalyst.expressions.UnsafeProjection { ... public UnsafeRow apply(InternalRow i) { rowWriter.zeroOutNullBytes(); apply_0(i); apply_1(i); ... nestedClassInstance.apply(i); nestedClassInstance1.apply(i); ... } ... private class NestedClass { private void apply_862(InternalRow i) { ... } private void apply_863(InternalRow i) { ... } ... private void apply(InternalRow i) { apply_862(i); apply_863(i); ... } } private class NestedClass1 { private void apply_1612(InternalRow i) { ... } private void apply_1613(InternalRow i) { ... } ... private void apply(InternalRow i) { apply_1612(i); apply_1613(i); ... } } } ``` ## How was this patch tested? Added UT and existing UTs Author: Marco Gaido Author: Marco Gaido Closes #19480 from mgaido91/SPARK-22226. --- .../expressions/codegen/CodeGenerator.scala | 156 ++++++++++++++++-- .../expressions/CodeGenerationSuite.scala | 17 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 12 ++ 3 files changed, 167 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2cb66599076a9..58738b52b299f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -77,6 +77,22 @@ case class SubExprEliminationState(isNull: String, value: String) */ case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState]) +/** + * The main information about a new added function. + * + * @param functionName String representing the name of the function + * @param innerClassName Optional value which is empty if the function is added to + * the outer class, otherwise it contains the name of the + * inner class in which the function has been added. + * @param innerClassInstance Optional value which is empty if the function is added to + * the outer class, otherwise it contains the name of the + * instance of the inner class in the outer class. 
+ */ +private[codegen] case class NewFunctionSpec( + functionName: String, + innerClassName: Option[String], + innerClassInstance: Option[String]) + /** * A context for codegen, tracking a list of objects that could be passed into generated Java * function. @@ -228,8 +244,8 @@ class CodegenContext { /** * Holds the class and instance names to be generated, where `OuterClass` is a placeholder * standing for whichever class is generated as the outermost class and which will contain any - * nested sub-classes. All other classes and instance names in this list will represent private, - * nested sub-classes. + * inner sub-classes. All other classes and instance names in this list will represent private, + * inner sub-classes. */ private val classes: mutable.ListBuffer[(String, String)] = mutable.ListBuffer[(String, String)](outerClassName -> null) @@ -260,8 +276,8 @@ class CodegenContext { /** * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the - * function will be inlined into a new private, nested class, and a class-qualified name for the - * function will be returned. Otherwise, the function will be inined to the `OuterClass` the + * function will be inlined into a new private, inner class, and a class-qualified name for the + * function will be returned. Otherwise, the function will be inlined to the `OuterClass` the * simple `funcName` will be returned. * * @param funcName the class-unqualified name of the function @@ -271,19 +287,27 @@ class CodegenContext { * it is eventually referenced and a returned qualified function name * cannot otherwise be accessed. * @return the name of the function, qualified by class if it will be inlined to a private, - * nested sub-class + * inner class */ def addNewFunction( funcName: String, funcCode: String, inlineToOuterClass: Boolean = false): String = { - // The number of named constants that can exist in the class is limited by the Constant Pool - // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a - // threshold of 1600k bytes to determine when a function should be inlined to a private, nested - // sub-class. + val newFunction = addNewFunctionInternal(funcName, funcCode, inlineToOuterClass) + newFunction match { + case NewFunctionSpec(functionName, None, None) => functionName + case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) => + innerClassInstance + "." + functionName + } + } + + private[this] def addNewFunctionInternal( + funcName: String, + funcCode: String, + inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > 1600000) { + } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -294,17 +318,23 @@ class CodegenContext { currClass() } - classSize(className) += funcCode.length - classFunctions(className) += funcName -> funcCode + addNewFunctionToClass(funcName, funcCode, className) if (className == outerClassName) { - funcName + NewFunctionSpec(funcName, None, None) } else { - - s"$classInstance.$funcName" + NewFunctionSpec(funcName, Some(className), Some(classInstance)) } } + private[this] def addNewFunctionToClass( + funcName: String, + funcCode: String, + className: String) = { + classSize(className) += funcCode.length + classFunctions(className) += funcName -> funcCode + } + /** * Declares all function code. 
If the added functions are too many, split them into nested * sub-classes to avoid hitting Java compiler constant pool limitation. @@ -738,7 +768,7 @@ class CodegenContext { /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow - * beyond 1600kb, we declare a private, nested sub-class, and the function is inlined to it + * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it * instead, because classes have a constant pool limit of 65,536 named values. * * @param row the variable name of row that is used by expressions @@ -801,10 +831,90 @@ class CodegenContext { | ${makeSplitFunction(body)} |} """.stripMargin - addNewFunction(name, code) + addNewFunctionInternal(name, code, inlineToOuterClass = false) } - foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) + val (outerClassFunctions, innerClassFunctions) = functions.partition(_.innerClassName.isEmpty) + + val argsString = arguments.map(_._2).mkString(", ") + val outerClassFunctionCalls = outerClassFunctions.map(f => s"${f.functionName}($argsString)") + + val innerClassFunctionCalls = generateInnerClassesFunctionCalls( + innerClassFunctions, + func, + arguments, + returnType, + makeSplitFunction, + foldFunctions) + + foldFunctions(outerClassFunctionCalls ++ innerClassFunctionCalls) + } + } + + /** + * Here we handle all the methods which have been added to the inner classes and + * not to the outer class. + * Since they can be many, their direct invocation in the outer class adds many entries + * to the outer class' constant pool. This can cause the constant pool to past JVM limit. + * Moreover, this can cause also the outer class method where all the invocations are + * performed to grow beyond the 64k limit. + * To avoid these problems, we group them and we call only the grouping methods in the + * outer class. + * + * @param functions a [[Seq]] of [[NewFunctionSpec]] defined in the inner classes + * @param funcName the split function name base. + * @param arguments the list of (type, name) of the arguments of the split function. + * @param returnType the return type of the split function. + * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup. + * @param foldFunctions folds the split function calls. 
+ * @return an [[Iterable]] containing the methods' invocations + */ + private def generateInnerClassesFunctionCalls( + functions: Seq[NewFunctionSpec], + funcName: String, + arguments: Seq[(String, String)], + returnType: String, + makeSplitFunction: String => String, + foldFunctions: Seq[String] => String): Iterable[String] = { + val innerClassToFunctions = mutable.LinkedHashMap.empty[(String, String), Seq[String]] + functions.foreach(f => { + val key = (f.innerClassName.get, f.innerClassInstance.get) + val value = f.functionName +: innerClassToFunctions.getOrElse(key, Seq.empty[String]) + innerClassToFunctions.put(key, value) + }) + + val argDefinitionString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") + val argInvocationString = arguments.map(_._2).mkString(", ") + + innerClassToFunctions.flatMap { + case ((innerClassName, innerClassInstance), innerClassFunctions) => + // for performance reasons, the functions are prepended, instead of appended, + // thus here they are in reversed order + val orderedFunctions = innerClassFunctions.reverse + if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + // Adding a new function to each inner class which contains the invocation of all the + // ones which have been added to that inner class. For example, + // private class NestedClass { + // private void apply_862(InternalRow i) { ... } + // private void apply_863(InternalRow i) { ... } + // ... + // private void apply(InternalRow i) { + // apply_862(i); + // apply_863(i); + // ... + // } + // } + val body = foldFunctions(orderedFunctions.map(name => s"$name($argInvocationString)")) + val code = s""" + |private $returnType $funcName($argDefinitionString) { + | ${makeSplitFunction(body)} + |} + """.stripMargin + addNewFunctionToClass(funcName, code, innerClassName) + Seq(s"$innerClassInstance.$funcName($argInvocationString)") + } else { + orderedFunctions.map(f => s"$innerClassInstance.$f($argInvocationString)") + } } } @@ -1013,6 +1123,16 @@ object CodeGenerator extends Logging { // This is the value of HugeMethodLimit in the OpenJDK JVM settings val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + // This is the threshold over which the methods in an inner class are grouped in a single + // method which is going to be called by the outer class instead of the many small ones + val MERGE_SPLIT_METHODS_THRESHOLD = 3 + + // The number of named constants that can exist in the class is limited by the Constant Pool + // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a + // threshold of 1000k bytes to determine when a function should be inlined to a private, inner + // class. + val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + /** * Compile the Java source code into a Java class, using Janino. 
* diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7ea0bec145481..1e6f7b65e7e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -201,6 +201,23 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-22226: group splitted expressions into one method per nested class") { + val length = 10000 + val expressions = Seq.fill(length) { + ToUTCTimestamp( + Literal.create(Timestamp.valueOf("2017-10-10 00:00:00"), TimestampType), + Literal.create("PST", StringType)) + } + val plan = GenerateMutableProjection.generate(expressions) + val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) + val expected = Seq.fill(length)( + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2017-10-10 07:00:00"))) + + if (actual != expected) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 473c355cf3c7f..17c88b0690800 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2106,6 +2106,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") { + val colNumber = 10000 + val input = spark.range(2).rdd.map(_ => Row(1 to colNumber: _*)) + val df = sqlContext.createDataFrame(input, StructType( + (1 to colNumber).map(colIndex => StructField(s"_$colIndex", IntegerType, false)))) + val newCols = (1 to colNumber).flatMap { colIndex => + Seq(expr(s"if(1000 < _$colIndex, 1000, _$colIndex)"), + expr(s"sqrt(_$colIndex)")) + } + df.select(newCols: _*).collect() + } + test("SPARK-22271: mean overflows and returns null for some decimal variables") { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") From 20eb95e5e9c562261b44e4e47cad67a31390fa59 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 27 Oct 2017 15:19:27 -0700 Subject: [PATCH 129/712] [SPARK-21911][ML][PYSPARK] Parallel Model Evaluation for ML Tuning in PySpark ## What changes were proposed in this pull request? Add parallelism support for ML tuning in pyspark. ## How was this patch tested? Test updated. Author: WeichenXu Closes #19122 from WeichenXu123/par-ml-tuning-py. 
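For context, a minimal usage sketch of the new `parallelism` parameter from PySpark (assuming an active `spark` session; the toy dataset and grid below are illustrative only, mirroring the doctest added in this patch):

```
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Toy dataset assumed for illustration.
dataset = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.5]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 10]).build()
evaluator = BinaryClassificationEvaluator()

# parallelism=2 fits and evaluates up to two candidate models per fold
# concurrently via the ThreadPool introduced by this patch; the resulting
# avgMetrics are expected to match a serial (parallelism=1) run.
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
                    numFolds=3, parallelism=2)
cvModel = cv.fit(dataset)
print(cvModel.avgMetrics)
```

The same parameter applies to `TrainValidationSplit`, which evaluates its single train/validation split with the same thread-pool mechanism.
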
--- .../spark/ml/tuning/CrossValidatorSuite.scala | 4 +- .../ml/tuning/TrainValidationSplitSuite.scala | 4 +- python/pyspark/ml/tests.py | 39 +++++++++ python/pyspark/ml/tuning.py | 86 ++++++++++++------- 4 files changed, 96 insertions(+), 37 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index a01744f7b67fd..853eeb39bf8df 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -137,8 +137,8 @@ class CrossValidatorSuite cv.setParallelism(2) val cvParallelModel = cv.fit(dataset) - val serialMetrics = cvSerialModel.avgMetrics.sorted - val parallelMetrics = cvParallelModel.avgMetrics.sorted + val serialMetrics = cvSerialModel.avgMetrics + val parallelMetrics = cvParallelModel.avgMetrics assert(serialMetrics === parallelMetrics) val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 2ed4fbb601b61..f8d9c66be2c40 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -138,8 +138,8 @@ class TrainValidationSplitSuite cv.setParallelism(2) val cvParallelModel = cv.fit(dataset) - val serialMetrics = cvSerialModel.validationMetrics.sorted - val parallelMetrics = cvParallelModel.validationMetrics.sorted + val serialMetrics = cvSerialModel.validationMetrics + val parallelMetrics = cvParallelModel.validationMetrics assert(serialMetrics === parallelMetrics) val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 8b8bcc7b13a38..2f1f3af957e4d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -836,6 +836,27 @@ def test_save_load_simple_estimator(self): loadedModel = CrossValidatorModel.load(cvModelPath) self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cv.setParallelism(1) + cvSerialModel = cv.fit(dataset) + cv.setParallelism(2) + cvParallelModel = cv.fit(dataset) + self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -986,6 +1007,24 @@ def test_save_load_simple_estimator(self): loadedModel = TrainValidationSplitModel.load(tvsModelPath) self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + 
["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvs.setParallelism(1) + tvsSerialModel = tvs.fit(dataset) + tvs.setParallelism(2) + tvsParallelModel = tvs.fit(dataset) + self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 00c348aa9f7de..47351133524e7 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -14,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import itertools import numpy as np +from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasSeed +from pyspark.ml.param.shared import HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -170,7 +170,7 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -193,7 +193,8 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): >>> lr = LogisticRegression() >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() - >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + ... parallelism=2) >>> cvModel = cv.fit(dataset) >>> cvModel.avgMetrics[0] 0.5 @@ -208,23 +209,23 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None): + seed=None, parallelism=1): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None) + seed=None, parallelism=1) """ super(CrossValidator, self).__init__() - self._setDefault(numFolds=3) + self._setDefault(numFolds=3, parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None): + seed=None, parallelism=1): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None): + seed=None, parallelism=1): Sets params for cross validator. 
""" kwargs = self._input_kwargs @@ -255,18 +256,27 @@ def _fit(self, dataset): randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) metrics = [0.0] * numModels + + pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + for i in range(nFolds): validateLB = i * h validateUB = (i + 1) * h condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) - validation = df.filter(condition) - train = df.filter(~condition) - models = est.fit(train, epm) - for j in range(numModels): - model = models[j] + validation = df.filter(condition).cache() + train = df.filter(~condition).cache() + + def singleTrain(paramMap): + model = est.fit(train, paramMap) # TODO: duplicate evaluator to take extra params from input - metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric/nFolds + metric = eva.evaluate(model.transform(validation, paramMap)) + return metric + + currentFoldMetrics = pool.map(singleTrain, epm) + for j in range(numModels): + metrics[j] += (currentFoldMetrics[j] / nFolds) + validation.unpersist() + train.unpersist() if eva.isLargerBetter(): bestIndex = np.argmax(metrics) @@ -316,9 +326,10 @@ def _from_java(cls, java_stage): estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() + parallelism = java_stage.getParallelism() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed) + numFolds=numFolds, seed=seed, parallelism=parallelism) py_stage._resetUid(java_stage.uid()) return py_stage @@ -337,6 +348,7 @@ def _to_java(self): _java_obj.setEstimator(estimator) _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) + _java_obj.setParallelism(self.getParallelism()) return _java_obj @@ -427,7 +439,7 @@ def _to_java(self): return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): """ .. note:: Experimental @@ -448,7 +460,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): >>> lr = LogisticRegression() >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() - >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + ... parallelism=2) >>> tvsModel = tvs.fit(dataset) >>> evaluator.evaluate(tvsModel.transform(dataset)) 0.8333... 
@@ -461,23 +474,23 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - seed=None): + parallelism=1, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - seed=None) + parallelism=1, seed=None) """ super(TrainValidationSplit, self).__init__() - self._setDefault(trainRatio=0.75) + self._setDefault(trainRatio=0.75, parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - seed=None): + parallelism=1, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - seed=None): + parallelism=1, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -506,15 +519,20 @@ def _fit(self, dataset): seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) - metrics = [0.0] * numModels condition = (df[randCol] >= tRatio) - validation = df.filter(condition) - train = df.filter(~condition) - models = est.fit(train, epm) - for j in range(numModels): - model = models[j] - metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric + validation = df.filter(condition).cache() + train = df.filter(~condition).cache() + + def singleTrain(paramMap): + model = est.fit(train, paramMap) + metric = eva.evaluate(model.transform(validation, paramMap)) + return metric + + pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + metrics = pool.map(singleTrain, epm) + train.unpersist() + validation.unpersist() + if eva.isLargerBetter(): bestIndex = np.argmax(metrics) else: @@ -563,9 +581,10 @@ def _from_java(cls, java_stage): estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() + parallelism = java_stage.getParallelism() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed) + trainRatio=trainRatio, seed=seed, parallelism=parallelism) py_stage._resetUid(java_stage.uid()) return py_stage @@ -584,6 +603,7 @@ def _to_java(self): _java_obj.setEstimator(estimator) _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) + _java_obj.setParallelism(self.getParallelism()) return _java_obj From 01f6ba0e7a12ef818d56e7d5b1bd889b79f2b57c Mon Sep 17 00:00:00 2001 From: Sathiya Date: Fri, 27 Oct 2017 18:57:08 -0700 Subject: [PATCH 130/712] [SPARK-22181][SQL] Adds ReplaceExceptWithFilter rule ## What changes were proposed in this pull request? Adds a new optimisation rule 'ReplaceExceptWithNotFilter' that replaces Except logical with Filter operator and schedule it before applying 'ReplaceExceptWithAntiJoin' rule. This way we can avoid expensive join operation if one or both of the datasets of the Except operation are fully derived out of Filters from a same parent. ## How was this patch tested? The patch is tested locally using spark-shell + unit test. Author: Sathiya Closes #19451 from sathiyapk/SPARK-22181-optimize-exceptWithFilter. 
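A quick way to see the new rule in action from PySpark is sketched below (an editorial illustration, not part of the patch; the temporary view name and data are assumptions). Both EXCEPT branches are filters over the same relation, and the new `spark.sql.optimizer.replaceExceptWithFilter` flag added by this patch is toggled to compare plans:

```
# Both EXCEPT branches filter the same relation, so the rule is eligible.
spark.range(0, 100).selectExpr("id % 10 AS a1", "id % 20 AS a2") \
    .createOrReplaceTempView("tab1")

query = """
    SELECT a1, a2 FROM tab1 WHERE a2 = 12
    EXCEPT
    SELECT a1, a2 FROM tab1 WHERE a1 = 5
"""

# With the flag enabled (the default), the optimized plan should show an
# Aggregate (for the implicit DISTINCT) over a Filter whose condition is the
# negation of the right branch, instead of a left-anti join.
spark.sql(query).explain(True)

# Disabling the flag falls back to the ReplaceExceptWithAntiJoin rewrite.
spark.conf.set("spark.sql.optimizer.replaceExceptWithFilter", "false")
spark.sql(query).explain(True)
```
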
--- .../sql/catalyst/expressions/subquery.scala | 10 ++ .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../optimizer/ReplaceExceptWithFilter.scala | 101 +++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 15 +++ .../optimizer/ReplaceOperatorSuite.scala | 106 +++++++++++++++++- .../resources/sql-tests/inputs/except.sql | 57 ++++++++++ .../sql-tests/results/except.sql.out | 105 +++++++++++++++++ 7 files changed, 394 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/except.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/except.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index c6146042ef1a6..6acc87a3e7367 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -89,6 +89,16 @@ object SubqueryExpression { case _ => false }.isDefined } + + /** + * Returns true when an expression contains a subquery + */ + def hasSubquery(e: Expression): Boolean = { + e.find { + case _: SubqueryExpression => true + case _ => false + }.isDefined + } } object SubExprUtils extends PredicateHelper { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d829e01441dcc..3273a61dc7b35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -76,6 +76,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, ReplaceIntersectWithSemiJoin, + ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, ReplaceDistinctWithAggregate) :: Batch("Aggregate", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala new file mode 100644 index 0000000000000..89bfcee078fba --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * If one or both of the datasets in the logical [[Except]] operator are purely transformed using + * [[Filter]], this rule will replace logical [[Except]] operator with a [[Filter]] operator by + * flipping the filter condition of the right child. + * {{{ + * SELECT a1, a2 FROM Tab1 WHERE a2 = 12 EXCEPT SELECT a1, a2 FROM Tab1 WHERE a1 = 5 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 WHERE a2 = 12 AND (a1 is null OR a1 <> 5) + * }}} + * + * Note: + * Before flipping the filter condition of the right node, we should: + * 1. Combine all it's [[Filter]]. + * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition). + */ +object ReplaceExceptWithFilter extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.conf.replaceExceptWithFilter) { + return plan + } + + plan.transform { + case Except(left, right) if isEligible(left, right) => + Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + } + } + + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + val filterCondition = + InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition + + val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap + + filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + } + + // TODO: This can be further extended in the future. + private def isEligible(left: LogicalPlan, right: LogicalPlan): Boolean = (left, right) match { + case (_, right @ (Project(_, _: Filter) | Filter(_, _))) => verifyConditions(left, right) + case _ => false + } + + private def verifyConditions(left: LogicalPlan, right: LogicalPlan): Boolean = { + val leftProjectList = projectList(left) + val rightProjectList = projectList(right) + + left.output.size == left.output.map(_.name).distinct.size && + left.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && + right.find(_.expressions.exists(SubqueryExpression.hasSubquery)).isEmpty && + Project(leftProjectList, nonFilterChild(skipProject(left))).sameResult( + Project(rightProjectList, nonFilterChild(skipProject(right)))) + } + + private def projectList(node: LogicalPlan): Seq[NamedExpression] = node match { + case p: Project => p.projectList + case x => x.output + } + + private def skipProject(node: LogicalPlan): LogicalPlan = node match { + case p: Project => p.child + case x => x + } + + private def nonFilterChild(plan: LogicalPlan) = plan.find(!_.isInstanceOf[Filter]).getOrElse { + throw new IllegalStateException("Leaf node is expected") + } + + private def combineFilters(plan: LogicalPlan): LogicalPlan = { + @tailrec + def iterate(plan: LogicalPlan, acc: LogicalPlan): LogicalPlan = { + if (acc.fastEquals(plan)) acc else iterate(acc, CombineFilters(acc)) + } + iterate(plan, CombineFilters(plan)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 21e4685fcc456..5203e8833fbbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -948,6 +948,19 @@ object SQLConf { .intConf 
.createWithDefault(10000) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") + .internal() + .doc("When true, the apply function of the rule verifies whether the right node of the" + + " except operation is of type Filter or Project followed by Filter. If yes, the rule" + + " further verifies 1) Excluding the filter operations from the right (as well as the" + + " left node, if any) on the top, whether both the nodes evaluates to a same result." + + " 2) The left and right nodes don't contain any SubqueryExpressions. 3) The output" + + " column names of the left node are distinct. If all the conditions are met, the" + + " rule will replace the except operation with a Filter by flipping the filter" + + " condition(s) of the right node.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1233,6 +1246,8 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 85988d2fb948c..0fa1aaeb9e164 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -31,6 +32,7 @@ class ReplaceOperatorSuite extends PlanTest { val batches = Batch("Replace Operators", FixedPoint(100), ReplaceDistinctWithAggregate, + ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, ReplaceDeduplicateWithAggregate) :: Nil @@ -50,6 +52,108 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter while both the nodes are of type Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + val table3 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while only right node is of type Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, 
attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) + + val query = Except(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), table1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while both the nodes are of type Project") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Project(Seq(attributeA, attributeB), table1) + val table3 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Project(Seq(attributeA, attributeB), table1))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while only right node is of type Project") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + val table3 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA >= 2 && attributeB < 1)), + Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Except with Filter while left node is Project and right node is Filter") { + val attributeA = 'a.int + val attributeB = 'b.int + + val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) + val table2 = Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))) + val table3 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) + + val query = Except(table2, table3) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && + (attributeA === 1 && attributeB === 2)), + Project(Seq(attributeA, attributeB), + Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Except with Left-anti Join") { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) diff --git a/sql/core/src/test/resources/sql-tests/inputs/except.sql b/sql/core/src/test/resources/sql-tests/inputs/except.sql new file mode 100644 index 0000000000000..1d579e65f3473 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/except.sql @@ -0,0 +1,57 @@ +-- Tests different scenarios of except operation +create temporary view t1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", NULL) + as t1(k, v); + +create temporary view t2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5), + ("one", NULL), + (NULL, 5) + as 
t2(k, v); + + +-- Except operation that will be replaced by left anti join +SELECT * FROM t1 EXCEPT SELECT * FROM t2; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT * FROM t1 EXCEPT SELECT * FROM t1 where v <> 1 and v <> 2; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT * FROM t1 where v <> 1 and v <> 22 EXCEPT SELECT * FROM t1 where v <> 2 and v >= 3; + + +-- Except operation that will be replaced by Filter: SPARK-22181 +SELECT t1.* FROM t1, t2 where t1.k = t2.k +EXCEPT +SELECT t1.* FROM t1, t2 where t1.k = t2.k and t1.k != 'one'; + + +-- Except operation that will be replaced by left anti join +SELECT * FROM t2 where v >= 1 and v <> 22 EXCEPT SELECT * FROM t1; + + +-- Except operation that will be replaced by left anti join +SELECT (SELECT min(k) FROM t2 WHERE t2.k = t1.k) min_t2 FROM t1 +MINUS +SELECT (SELECT min(k) FROM t2) abs_min_t2 FROM t1 WHERE t1.k = 'one'; + + +-- Except operation that will be replaced by left anti join +SELECT t1.k +FROM t1 +WHERE t1.v <= (SELECT max(t2.v) + FROM t2 + WHERE t2.k = t1.k) +MINUS +SELECT t1.k +FROM t1 +WHERE t1.v >= (SELECT min(t2.v) + FROM t2 + WHERE t2.k = t1.k); diff --git a/sql/core/src/test/resources/sql-tests/results/except.sql.out b/sql/core/src/test/resources/sql-tests/results/except.sql.out new file mode 100644 index 0000000000000..c9b712d4d2949 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/except.sql.out @@ -0,0 +1,105 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +create temporary view t1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", NULL) + as t1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view t2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5), + ("one", NULL), + (NULL, 5) + as t2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM t1 EXCEPT SELECT * FROM t2 +-- !query 2 schema +struct +-- !query 2 output +three 3 +two 2 + + +-- !query 3 +SELECT * FROM t1 EXCEPT SELECT * FROM t1 where v <> 1 and v <> 2 +-- !query 3 schema +struct +-- !query 3 output +one 1 +one NULL +two 2 + + +-- !query 4 +SELECT * FROM t1 where v <> 1 and v <> 22 EXCEPT SELECT * FROM t1 where v <> 2 and v >= 3 +-- !query 4 schema +struct +-- !query 4 output +two 2 + + +-- !query 5 +SELECT t1.* FROM t1, t2 where t1.k = t2.k +EXCEPT +SELECT t1.* FROM t1, t2 where t1.k = t2.k and t1.k != 'one' +-- !query 5 schema +struct +-- !query 5 output +one 1 +one NULL + + +-- !query 6 +SELECT * FROM t2 where v >= 1 and v <> 22 EXCEPT SELECT * FROM t1 +-- !query 6 schema +struct +-- !query 6 output +NULL 5 +one 5 + + +-- !query 7 +SELECT (SELECT min(k) FROM t2 WHERE t2.k = t1.k) min_t2 FROM t1 +MINUS +SELECT (SELECT min(k) FROM t2) abs_min_t2 FROM t1 WHERE t1.k = 'one' +-- !query 7 schema +struct +-- !query 7 output +NULL +two + + +-- !query 8 +SELECT t1.k +FROM t1 +WHERE t1.v <= (SELECT max(t2.v) + FROM t2 + WHERE t2.k = t1.k) +MINUS +SELECT t1.k +FROM t1 +WHERE t1.v >= (SELECT min(t2.v) + FROM t2 + WHERE t2.k = t1.k) +-- !query 8 schema +struct +-- !query 8 output +two From c42d208e197ff061cc5f49c75568c047d8d1126c Mon Sep 17 00:00:00 2001 From: donnyzone Date: Fri, 27 Oct 2017 23:40:59 -0700 Subject: [PATCH 131/712] [SPARK-22333][SQL] timeFunctionCall(CURRENT_DATE, CURRENT_TIMESTAMP) has conflicts with columnReference ## What changes were proposed in this pull request? 
https://issues.apache.org/jira/browse/SPARK-22333 In current version, users can use CURRENT_DATE() and CURRENT_TIMESTAMP() without specifying braces. However, when a table has columns named as "current_date" or "current_timestamp", it will still be parsed as function call. There are many such cases in our production cluster. We get the wrong answer due to this inappropriate behevior. In general, ColumnReference should get higher priority than timeFunctionCall. ## How was this patch tested? unit test manul test Author: donnyzone Closes #19559 from DonnyZone/master. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 +-- .../sql/catalyst/analysis/Analyzer.scala | 37 +++++++++++++- .../sql/catalyst/parser/AstBuilder.scala | 13 ----- .../parser/ExpressionParserSuite.scala | 5 -- .../resources/sql-tests/inputs/datetime.sql | 17 +++++++ .../sql-tests/results/datetime.sql.out | 49 ++++++++++++++++++- 6 files changed, 102 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 17c8404f8a79c..6fe995f650d55 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -564,8 +564,7 @@ valueExpression ; primaryExpression - : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + : CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct @@ -747,7 +746,7 @@ nonReserved | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN | UNBOUNDED | WHEN - | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP + | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | DIRECTORY | BOTH | LEADING | TRAILING ; @@ -983,8 +982,6 @@ OPTION: 'OPTION'; ANTI: 'ANTI'; LOCAL: 'LOCAL'; INPATH: 'INPATH'; -CURRENT_DATE: 'CURRENT_DATE'; -CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6384a141e83b3..e5c93b5f0e059 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -786,7 +786,12 @@ class Analyzer( private def resolve(e: Expression, q: LogicalPlan): Expression = e match { case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. 
- val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + val result = + withPosition(u) { + q.resolveChildren(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, q)) + .getOrElse(u) + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -925,6 +930,30 @@ class Analyzer( exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined) } + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction( + nameParts: Seq[String], + attribute: UnresolvedAttribute, + plan: LogicalPlan): Option[Expression] = { + if (nameParts.length != 1) return None + val isNamedExpression = plan match { + case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute) + case Project(projectList, _) => projectList.contains(attribute) + case Window(windowExpressions, _, _, _) => windowExpressions.contains(attribute) + case _ => false + } + val wrapper: Expression => Expression = + if (isNamedExpression) f => Alias(f, toPrettySQL(f))() else identity + // support CURRENT_DATE and CURRENT_TIMESTAMP + val literalFunctions = Seq(CurrentDate(), CurrentTimestamp()) + val name = nameParts.head + val func = literalFunctions.find(e => resolver(e.prettyName, name)) + func.map(wrapper) + } + protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, @@ -937,7 +966,11 @@ class Analyzer( expr transformUp { case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } + withPosition(u) { + plan.resolve(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, plan)) + .getOrElse(u) + } case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ce367145bc637..7651d11ee65a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1234,19 +1234,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } - /** - * Create a current timestamp/date expression. These are different from regular function because - * they do not require the user to specify braces when calling them. - */ - override def visitTimeFunctionCall(ctx: TimeFunctionCallContext): Expression = withOrigin(ctx) { - ctx.name.getType match { - case SqlBaseParser.CURRENT_DATE => - CurrentDate() - case SqlBaseParser.CURRENT_TIMESTAMP => - CurrentTimestamp() - } - } - /** * Create a function database (optional) and name pair. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 76c79b3d0760c..2b9783a3295c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -592,11 +592,6 @@ class ExpressionParserSuite extends PlanTest { intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") } - test("current date/timestamp braceless expressions") { - assertEqual("current_date", CurrentDate()) - assertEqual("current_timestamp", CurrentTimestamp()) - } - test("SPARK-17364, fully qualified column name which starts with number") { assertEqual("123_", UnresolvedAttribute("123_")) assertEqual("1a.123_", UnresolvedAttribute("1a.123_")) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 616b6caee3f20..adea2bfa82cd3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -8,3 +8,20 @@ select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd') select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd'); select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); + +-- [SPARK-22333]: timeFunctionCall has conflicts with columnReference +create temporary view ttf1 as select * from values + (1, 2), + (2, 3) + as ttf1(current_date, current_timestamp); + +select current_date, current_timestamp from ttf1; + +create temporary view ttf2 as select * from values + (1, 2), + (2, 3) + as ttf2(a, b); + +select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; + +select a, b from ttf2 order by a, current_date; diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index a28b91c77324b..7b2f46f6c2a66 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 9 -- !query 0 @@ -32,3 +32,50 @@ select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27') struct -- !query 3 output 7 5 7 NULL 6 + + +-- !query 4 +create temporary view ttf1 as select * from values + (1, 2), + (2, 3) + as ttf1(current_date, current_timestamp) +-- !query 4 schema +struct<> +-- !query 4 output + + +-- !query 5 +select current_date, current_timestamp from ttf1 +-- !query 5 schema +struct +-- !query 5 output +1 2 +2 3 + + +-- !query 6 +create temporary view ttf2 as select * from values + (1, 2), + (2, 3) + as ttf2(a, b) +-- !query 6 schema +struct<> +-- !query 6 output + + +-- !query 7 +select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2 +-- !query 7 schema +struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean,a:int,b:int> +-- !query 7 output +true true 1 2 +true true 2 3 + + +-- !query 8 +select a, b from ttf2 order by a, current_date +-- !query 8 schema +struct +-- !query 8 output +1 2 +2 3 From 
d28d5732ae205771f1f443b15b10e64dcffb5ff0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 Oct 2017 23:44:24 -0700 Subject: [PATCH 132/712] [SPARK-21619][SQL] Fail the execution of canonicalized plans explicitly ## What changes were proposed in this pull request? Canonicalized plans are not supposed to be executed. I ran into a case in which there's some code that accidentally calls execute on a canonicalized plan. This patch throws a more explicit exception when that happens. ## How was this patch tested? Added a test case in SparkPlanSuite. Author: Reynold Xin Closes #18828 from rxin/SPARK-21619. --- .../sql/catalyst/catalog/interface.scala | 5 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 30 +++++++++++++--- .../plans/logical/basicLogicalOperators.scala | 2 +- .../sql/catalyst/plans/logical/hints.scala | 2 +- .../sql/execution/DataSourceScanExec.scala | 4 +-- .../spark/sql/execution/SparkPlan.scala | 6 ++++ .../spark/sql/execution/SparkSqlParser.scala | 2 +- .../execution/basicPhysicalOperators.scala | 2 +- .../spark/sql/execution/command/cache.scala | 5 ++- .../datasources/LogicalRelation.scala | 2 +- .../exchange/BroadcastExchangeExec.scala | 2 +- .../sql/execution/exchange/Exchange.scala | 2 +- .../spark/sql/execution/SparkPlanSuite.scala | 36 +++++++++++++++++++ .../hive/execution/HiveTableScanExec.scala | 4 +-- 14 files changed, 86 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 1dbae4d37d8f5..b87bbb4874670 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -438,7 +438,7 @@ case class HiveTableRelation( def isPartitioned: Boolean = partitionCols.nonEmpty - override lazy val canonicalized: HiveTableRelation = copy( + override def doCanonicalize(): HiveTableRelation = copy( tableMeta = tableMeta.copy( storage = CatalogStorageFormat.empty, createTime = -1 @@ -448,7 +448,8 @@ case class HiveTableRelation( }, partitionCols = partitionCols.zipWithIndex.map { case (attr, index) => attr.withExprId(ExprId(index + dataCols.length)) - }) + } + ) override def computeStats(): Statistics = { tableMeta.stats.map(_.toPlanStats(output)).getOrElse { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c7952e3ff8280..d21b4afa2f06c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -180,6 +180,15 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override protected def innerChildren: Seq[QueryPlan[_]] = subqueries + /** + * A private mutable variable to indicate whether this plan is the result of canonicalization. + * This is used solely for making sure we wouldn't execute a canonicalized plan. + * See [[canonicalized]] on how this is set. 
+ */ + @transient private var _isCanonicalizedPlan: Boolean = false + + protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan + /** * Returns a plan where a best effort attempt has been made to transform `this` in a way * that preserves the result but removes cosmetic variations (case sensitivity, ordering for @@ -188,10 +197,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same * result. * - * Some nodes should overwrite this to provide proper canonicalize logic, but they should remove - * expressions cosmetic variations themselves. + * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. + * They should remove expressions cosmetic variations themselves. + */ + @transient final lazy val canonicalized: PlanType = { + var plan = doCanonicalize() + // If the plan has not been changed due to canonicalization, make a copy of it so we don't + // mutate the original plan's _isCanonicalizedPlan flag. + if (plan eq this) { + plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) + } + plan._isCanonicalizedPlan = true + plan + } + + /** + * Defines how the canonicalization should work for the current plan. */ - lazy val canonicalized: PlanType = { + protected def doCanonicalize(): PlanType = { val canonicalizedChildren = children.map(_.canonicalized) var id = -1 mapExpressions { @@ -213,7 +236,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }.withNewChildren(canonicalizedChildren) } - /** * Returns true when the given query plan will return the same results as this query plan. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 80243d3d356ca..c2750c3079814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -760,7 +760,7 @@ case class SubqueryAlias( child: LogicalPlan) extends UnaryNode { - override lazy val canonicalized: LogicalPlan = child.canonicalized + override def doCanonicalize(): LogicalPlan = child.canonicalized override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 29a43528124d8..cbb626590d1d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -41,7 +41,7 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def output: Seq[Attribute] = child.output - override lazy val canonicalized: LogicalPlan = child.canonicalized + override def doCanonicalize(): LogicalPlan = child.canonicalized } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 8d0fc32feac99..e9f65031143b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -139,7 +139,7 @@ case class RowDataSourceScanExec( } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. - override lazy val canonicalized: SparkPlan = + override def doCanonicalize(): SparkPlan = copy( fullOutput.map(QueryPlan.normalizeExprId(_, fullOutput)), rdd = null, @@ -522,7 +522,7 @@ case class FileSourceScanExec( } } - override lazy val canonicalized: FileSourceScanExec = { + override def doCanonicalize(): FileSourceScanExec = { FileSourceScanExec( relation, output.map(QueryPlan.normalizeExprId(_, output)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2ffd948f984bf..657b265260135 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -111,6 +111,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override `doExecute`. */ final def execute(): RDD[InternalRow] = executeQuery { + if (isCanonicalizedPlan) { + throw new IllegalStateException("A canonicalized plan is not supposed to be executed.") + } doExecute() } @@ -121,6 +124,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override `doExecuteBroadcast`. */ final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery { + if (isCanonicalizedPlan) { + throw new IllegalStateException("A canonicalized plan is not supposed to be executed.") + } doExecuteBroadcast() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 6de9ea0efd2c6..29b584b55972c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -286,7 +286,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * Create a [[ClearCacheCommand]] logical plan. 
*/ override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { - ClearCacheCommand + ClearCacheCommand() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d15ece304cac4..e58c3cec2df15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -350,7 +350,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override lazy val canonicalized: SparkPlan = { + override def doCanonicalize(): SparkPlan = { RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 140f920eaafae..687994d82a003 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -66,10 +66,13 @@ case class UncacheTableCommand( /** * Clear all cached data from the in-memory cache. */ -case object ClearCacheCommand extends RunnableCommand { +case class ClearCacheCommand() extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.catalog.clearCache() Seq.empty[Row] } + + /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */ + override def makeCopy(newArgs: Array[AnyRef]): ClearCacheCommand = ClearCacheCommand() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3e98cb28453a2..236995708a12f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -35,7 +35,7 @@ case class LogicalRelation( extends LeafNode with MultiInstanceRelation { // Only care about relation when canonicalizing. 
- override lazy val canonicalized: LogicalPlan = copy( + override def doCanonicalize(): LogicalPlan = copy( output = output.map(QueryPlan.normalizeExprId(_, output)), catalogTable = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 880e18c6808b0..daea6c39624d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -48,7 +48,7 @@ case class BroadcastExchangeExec( override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) - override lazy val canonicalized: SparkPlan = { + override def doCanonicalize(): SparkPlan = { BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 4b52f3e4c49b0..09f79a2de0ba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -50,7 +50,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan extends LeafExecNode { // Ignore this wrapper for canonicalizing. - override lazy val canonicalized: SparkPlan = child.canonicalized + override def doCanonicalize(): SparkPlan = child.canonicalized def doExecute(): RDD[InternalRow] = { child.execute() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala new file mode 100644 index 0000000000000..750d9e4adf8b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +class SparkPlanSuite extends QueryTest with SharedSQLContext { + + test("SPARK-21619 execution of a canonicalized plan should fail") { + val plan = spark.range(10).queryExecution.executedPlan.canonicalized + + intercept[IllegalStateException] { plan.execute() } + intercept[IllegalStateException] { plan.executeCollect() } + intercept[IllegalStateException] { plan.executeCollectPublic() } + intercept[IllegalStateException] { plan.executeToIterator() } + intercept[IllegalStateException] { plan.executeBroadcast() } + intercept[IllegalStateException] { plan.executeTake(1) } + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 4f8dab9cd6172..7dcaf170f9693 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -203,11 +203,11 @@ case class HiveTableScanExec( } } - override lazy val canonicalized: HiveTableScanExec = { + override def doCanonicalize(): HiveTableScanExec = { val input: AttributeSeq = relation.output HiveTableScanExec( requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), - relation.canonicalized, + relation.canonicalized.asInstanceOf[HiveTableRelation], QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession) } From 683ffe0620e69fd6e9f92c1037eef7996029aba8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 28 Oct 2017 21:47:15 +0900 Subject: [PATCH 133/712] [SPARK-22335][SQL] Clarify union behavior on Dataset of typed objects in the document ## What changes were proposed in this pull request? Seems that end users can be confused by the union's behavior on Dataset of typed objects. We can clarity it more in the document of `union` function. ## How was this patch tested? Only document change. Author: Liang-Chi Hsieh Closes #19570 from viirya/SPARK-22335. --- .../scala/org/apache/spark/sql/Dataset.scala | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fe4e192e43dfe..bd99ec52ce93f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1747,7 +1747,26 @@ class Dataset[T] private[sql]( * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does * deduplication of elements), use this function followed by a [[distinct]]. * - * Also as standard in SQL, this function resolves columns by position (not by name). + * Also as standard in SQL, this function resolves columns by position (not by name): + * + * {{{ + * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") + * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") + * df1.union(df2).show + * + * // output: + * // +----+----+----+ + * // |col0|col1|col2| + * // +----+----+----+ + * // | 1| 2| 3| + * // | 4| 5| 6| + * // +----+----+----+ + * }}} + * + * Notice that the column positions in the schema aren't necessarily matched with the + * fields in the strongly typed objects in a Dataset. This function resolves columns + * by their positions in the schema, not the fields in the strongly typed objects. 
Use + * [[unionByName]] to resolve columns by field name in the typed objects. * * @group typedrel * @since 2.0.0 From 4c5269f1aa529e6a397b68d6dc409d89e32685bd Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 28 Oct 2017 18:33:09 +0100 Subject: [PATCH 134/712] [SPARK-22370][SQL][PYSPARK] Config values should be captured in Driver. ## What changes were proposed in this pull request? `ArrowEvalPythonExec` and `FlatMapGroupsInPandasExec` are refering config values of `SQLConf` in function for `mapPartitions`/`mapPartitionsInternal`, but we should capture them in Driver. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #19587 from ueshin/issues/SPARK-22370. --- python/pyspark/sql/tests.py | 20 +++++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++++++ .../python/ArrowEvalPythonExec.scala | 6 ++++-- .../python/FlatMapGroupsInPandasExec.scala | 3 ++- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98afae662b42d..8ed37c9da98b8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3476,6 +3476,26 @@ def gen_timestamps(id): expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) self.assertEquals(expected, ts) + def test_vectorized_udf_check_config(self): + from pyspark.sql.functions import pandas_udf, col + orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) + self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) + try: + df = self.spark.range(10, numPartitions=1) + + @pandas_udf(returnType=LongType()) + def check_records_per_batch(x): + self.assertTrue(x.size <= 3) + return x + + result = df.select(check_records_per_batch(col("id"))) + self.assertEquals(df.collect(), result.collect()) + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + else: + self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d21b4afa2f06c..ddf2cbf2ab911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -25,6 +25,12 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => + /** + * The active config object within the current scope. + * Note that if you want to refer config values during execution, you have to capture them + * in Driver and use the captured values in Executors. + * See [[SQLConf.get]] for more information. 
+ */ def conf: SQLConf = SQLConf.get def output: Seq[Attribute] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 0db463a5fbd89..bcda2dae92e53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -61,6 +61,9 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends EvalPythonExec(udfs, output, child) { + private val batchSize = conf.arrowMaxRecordsPerBatch + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + protected override def evaluate( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, @@ -73,13 +76,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) - val batchSize = conf.arrowMaxRecordsPerBatch // DO NOT use iter.grouped(). See BatchIterator. val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index cc93fda9f81da..e1e04e34e0c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -77,6 +77,7 @@ case class FlatMapGroupsInPandasExec( val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) val schema = StructType(child.schema.drop(groupingAttributes.length)) + val sessionLocalTimeZone = conf.sessionLocalTimeZone inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { @@ -94,7 +95,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) From e80da8129a6b8ebaeac0eeac603ddc461144aec3 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Sat, 28 Oct 2017 17:20:35 -0700 Subject: [PATCH 135/712] [MINOR] Remove false comment from planStreamingAggregation ## What changes were proposed in this pull request? AggUtils.planStreamingAggregation has some comments about DISTINCT aggregates, while streaming aggregation does not support DISTINCT. This seems to have been wrongly copy-pasted over. ## How was this patch tested? Only a comment change. 
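As a side note on the SPARK-22370 change above: the capture-the-config-on-the-driver pattern it applies can be sketched in a self-contained way. The object name and the use of a plain RDD below are illustrative stand-ins for the patched Arrow operators, not a reproduction of them:

```scala
import org.apache.spark.sql.SparkSession

object DriverCaptureSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("driver-capture-sketch")
      .getOrCreate()

    // Read the config value once, on the driver. Only this captured Int is
    // serialized with the closure below; executors never consult SQLConf.
    val maxRecordsPerBatch =
      spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", "10000").toInt

    val batchCounts = spark.range(0, 100, 1, numPartitions = 2).rdd
      .mapPartitions { iter =>
        // Referring to spark.conf here instead of the captured value would pull
        // driver-side state into the executor closure.
        Iterator.single(iter.grouped(maxRecordsPerBatch).size)
      }
      .collect()

    println(batchCounts.mkString(", "))
    spark.stop()
  }
}
```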
Author: Juliusz Sompolski Closes #18937 from juliuszsompolski/streaming-agg-doc. --- .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 12f8cffb6774a..ebbdf1aaa024d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -263,9 +263,6 @@ object AggUtils { val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - // We will group by the original grouping expression, plus an additional expression for the - // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping - // expressions will be [key, value]. createAggregate( groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, From 7fdacbc77bbcf98c2c045a1873e749129769dcc0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 28 Oct 2017 18:24:18 -0700 Subject: [PATCH 136/712] [SPARK-19727][SQL][FOLLOWUP] Fix for round function that modifies original column ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/17075 , to fix the bug in codegen path. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #19576 from cloud-fan/bug. --- .../sql/catalyst/CatalystTypeConverters.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../expressions/decimalExpressions.scala | 2 +- .../expressions/mathExpressions.scala | 10 ++---- .../org/apache/spark/sql/types/Decimal.scala | 31 ++++++++++--------- .../apache/spark/sql/types/DecimalSuite.scala | 2 +- .../apache/spark/sql/MathFunctionsSuite.scala | 12 +++++++ 7 files changed, 36 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4ebdb139fe0f..474ec592201d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,7 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - decimal.toPrecision(dataType.precision, dataType.scale).orNull + decimal.toPrecision(dataType.precision, dataType.scale) } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d949b8f1d6696..bc809f559d586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String /** * Create new `Decimal` with precision and scale given in `decimalType` (if any), * returning null if it overflows or creating a new `value` and returning it if successful. 
- * */ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = - value.toPrecision(decimalType.precision, decimalType.scale).orNull + value.toPrecision(decimalType.precision, decimalType.scale) private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c2211ae5d594b..752dea23e1f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull + input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 547d5be0e908e..d8dc0862f1141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1044,7 +1044,7 @@ abstract class RoundBase(child: Expression, scale: Expression, dataType match { case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, s, mode).orNull + decimal.toPrecision(decimal.precision, s, mode) case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1076,12 +1076,8 @@ abstract class RoundBase(child: Expression, scale: Expression, val evaluationCode = dataType match { case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, - java.math.BigDecimal.${modeStr})) { - ${ev.value} = ${ce.value}; - } else { - ${ev.isNull} = true; - }""" + ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr()); + ${ev.isNull} = ${ev.value} == null;""" case ByteType => if (_scale < 0) { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 1f1fb51addfd8..6da4f28b12962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -234,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable { changePrecision(precision, scale, ROUND_HALF_UP) } - def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match { - case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP) - case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) - } - /** * Create new `Decimal` with given precision and scale. * - * @return `Some(decimal)` if successful or `None` if overflow would occur + * @return a non-null `Decimal` value if successful or `null` if overflow would occur. 
*/ private[sql] def toPrecision( precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = { val copy = clone() - if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + if (copy.changePrecision(precision, scale, roundMode)) copy else null } /** @@ -257,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * * @return true if successful, false if overflow would occur */ - private[sql] def changePrecision(precision: Int, scale: Int, - roundMode: BigDecimal.RoundingMode.Value): Boolean = { + private[sql] def changePrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { def floor: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_FLOOR) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } def ceil: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision - toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( - throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) + val res = toPrecision(newPrecision, 0, ROUND_CEILING) + if (res == null) { + throw new AnalysisException(s"Overflow when setting precision to $newPrecision") + } + res } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 3193d1320ad9d..10de90c6a44ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") - val copy = d.toPrecision(10, 0, mode).orNull + val copy = d.toPrecision(10, 0, mode) assert(copy !== null) assert(d.ne(copy)) assert(d === copy) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index c2d08a06569bf..5be8c581e9ddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with table columns") { + withTable("t") { + Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t") + checkAnswer( + sql("select i, round(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + checkAnswer( + sql("select i, bround(i) from t"), + Seq(Row(BigDecimal("5.9"), BigDecimal("6")))) + } + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } From 544a1ba678810b331d78fe9e63c7bb2342ab3d99 Mon Sep 17 00:00:00 2001 From: Xin Lu Date: Sun, 29 Oct 2017 15:29:23 +0900 Subject: [PATCH 137/712] 
[SPARK-22375][TEST] Test script can fail if eggs are installed by setup.py during test process

## What changes were proposed in this pull request?

Ignore the python/.eggs folder when running lint-python

## How was this patch tested?

1) put a bad python file in python/.eggs and ran the original script. results were:

xins-MBP:spark xinlu$ dev/lint-python
PEP8 checks failed.
./python/.eggs/worker.py:33:4: E121 continuation line under-indented for hanging indent
./python/.eggs/worker.py:34:5: E131 continuation line unaligned for hanging indent

2) test same situation with change:

xins-MBP:spark xinlu$ dev/lint-python
PEP8 checks passed.
The sphinx-build command was not found. Skipping pydoc checks for now

Author: Xin Lu

Closes #19597 from xynny/SPARK-22375.
---
 dev/tox.ini | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/tox.ini b/dev/tox.ini
index eeeb637460cfb..eb8b1eb2c2886 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -16,4 +16,4 @@
 [pep8]
 ignore=E402,E731,E241,W503,E226
 max-line-length=100
-exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py
+exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*

From bc7ca9786e162e33f29d57c4aacb830761b97221 Mon Sep 17 00:00:00 2001
From: Jen-Ming Chung
Date: Sun, 29 Oct 2017 18:11:48 +0100
Subject: [SPARK-22291][SQL] Conversion error when transforming array types of
 uuid, inet and cidr to StringType in PostgreSQL

## What changes were proposed in this pull request?

This PR fixes the conversion error when reading data from a PostgreSQL table that contains columns of `uuid[]`, `inet[]` and `cidr[]` data types.

For example, create a table with the uuid[] data type, and insert the test data.
```SQL
CREATE TABLE users
(
    id smallint NOT NULL,
    name character varying(50),
    user_ids uuid[],
    PRIMARY KEY (id)
)

INSERT INTO users ("id", "name","user_ids")
VALUES (1, 'foo',
        ARRAY ['7be8aaf8-650e-4dbb-8186-0a749840ecf2'
              ,'205f9bfc-018c-4452-a605-609c0cfad228']::UUID[])
```
Then it will throw the following exceptions when trying to load the data.
```
java.lang.ClassCastException: [Ljava.util.UUID; cannot be cast to [Ljava.lang.String;
  at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$14.apply(JdbcUtils.scala:459)
  at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$14.apply(JdbcUtils.scala:458)
...
```

## How was this patch tested?

Added test in `PostgresIntegrationSuite`.

Author: Jen-Ming Chung

Closes #19567 from jmchung/SPARK-22291.
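The cast failure and the element-by-element fix can be reproduced outside Spark with a small self-contained sketch. The object name below is made up for illustration, and it leaves out the `UTF8String` wrapping that the real `JdbcUtils` code applies:

```scala
import java.util.UUID

object ArrayConversionSketch {
  def main(args: Array[String]): Unit = {
    // Roughly what a JDBC driver hands back for a uuid[] column: an Object
    // reference whose runtime type is UUID[], not String[].
    val fromDriver: AnyRef = Array(UUID.randomUUID(), UUID.randomUUID())

    // Old behaviour: a blind cast to String[] compiles, but at runtime throws
    // "[Ljava.util.UUID; cannot be cast to [Ljava.lang.String;".
    try {
      val asStrings = fromDriver.asInstanceOf[Array[java.lang.String]]
      println(asStrings.mkString(", "))
    } catch {
      case e: ClassCastException => println(s"cast-based conversion fails: $e")
    }

    // Fixed behaviour: treat the elements as plain objects and stringify each
    // one, preserving nulls; this also covers inet, cidr, json and jsonb.
    val converted = fromDriver.asInstanceOf[Array[AnyRef]]
      .map(obj => if (obj == null) null else obj.toString)
    println(converted.mkString(", "))
  }
}
```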
--- .../sql/jdbc/PostgresIntegrationSuite.scala | 37 ++++++++++++++++++- .../datasources/jdbc/JdbcUtils.scala | 5 ++- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index eb3c458360e7b..48aba90afc787 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -60,7 +60,22 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { "(id integer, tstz TIMESTAMP WITH TIME ZONE, ttz TIME WITH TIME ZONE)") .executeUpdate() conn.prepareStatement("INSERT INTO ts_with_timezone VALUES " + - "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', TIME WITH TIME ZONE '17:22:31.949271+00')") + "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', " + + "TIME WITH TIME ZONE '17:22:31.949271+00')") + .executeUpdate() + + conn.prepareStatement("CREATE TABLE st_with_array (c0 uuid, c1 inet, c2 cidr," + + "c3 json, c4 jsonb, c5 uuid[], c6 inet[], c7 cidr[], c8 json[], c9 jsonb[])") + .executeUpdate() + conn.prepareStatement("INSERT INTO st_with_array VALUES ( " + + "'0a532531-cdf1-45e3-963d-5de90b6a30f1', '172.168.22.1', '192.168.100.128/25', " + + """'{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}', """ + + "ARRAY['7be8aaf8-650e-4dbb-8186-0a749840ecf2'," + + "'205f9bfc-018c-4452-a605-609c0cfad228']::uuid[], ARRAY['172.16.0.41', " + + "'172.16.0.42']::inet[], ARRAY['192.168.0.0/24', '10.1.0.0/16']::cidr[], " + + """ARRAY['{"a": "foo", "b": "bar"}', '{"a": 1, "b": 2}']::json[], """ + + """ARRAY['{"a": 1, "b": 2, "c": 3}']::jsonb[])""" + ) .executeUpdate() } @@ -134,11 +149,29 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(schema(1).dataType == ShortType) } - test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE should be recognized") { + test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " + + "should be recognized") { val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) val rows = dfRead.collect() val types = rows(0).toSeq.map(x => x.getClass.toString) assert(types(1).equals("class java.sql.Timestamp")) assert(types(2).equals("class java.sql.Timestamp")) } + + test("SPARK-22291: Conversion error when transforming array types of " + + "uuid, inet and cidr to StingType in PostgreSQL") { + val df = sqlContext.read.jdbc(jdbcUrl, "st_with_array", new Properties) + val rows = df.collect() + assert(rows(0).getString(0) == "0a532531-cdf1-45e3-963d-5de90b6a30f1") + assert(rows(0).getString(1) == "172.168.22.1") + assert(rows(0).getString(2) == "192.168.100.128/25") + assert(rows(0).getString(3) == "{\"a\": \"foo\", \"b\": \"bar\"}") + assert(rows(0).getString(4) == "{\"a\": 1, \"b\": 2}") + assert(rows(0).getSeq(5) == Seq("7be8aaf8-650e-4dbb-8186-0a749840ecf2", + "205f9bfc-018c-4452-a605-609c0cfad228")) + assert(rows(0).getSeq(6) == Seq("172.16.0.41", "172.16.0.42")) + assert(rows(0).getSeq(7) == Seq("192.168.0.0/24", "10.1.0.0/16")) + assert(rows(0).getSeq(8) == Seq("""{"a": "foo", "b": "bar"}""", """{"a": 1, "b": 2}""")) + assert(rows(0).getSeq(9) == Seq("""{"a": 1, "b": 2, "c": 3}""")) + } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 9debc4ff82748..75c94fc486493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -456,8 +456,9 @@ object JdbcUtils extends Logging { case StringType => (array: Object) => - array.asInstanceOf[Array[java.lang.String]] - .map(UTF8String.fromString) + // some underling types are not String such as uuid, inet, cidr, etc. + array.asInstanceOf[Array[java.lang.Object]] + .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString)) case DateType => (array: Object) => From 659acf18daf0d91fc0595227b7e29d732b99f4aa Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 29 Oct 2017 10:37:25 -0700 Subject: [PATCH 139/712] Revert "[SPARK-22308] Support alternative unit testing styles in external applications" This reverts commit 592cfeab9caeff955d115a1ca5014ede7d402907. --- .../org/apache/spark/SharedSparkContext.scala | 17 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../spark/sql/test/GenericFlatSpecSuite.scala | 45 ----- .../spark/sql/test/GenericFunSpecSuite.scala | 47 ----- .../spark/sql/test/GenericWordSpecSuite.scala | 51 ------ .../apache/spark/sql/test/SQLTestUtils.scala | 173 ++++++++---------- .../spark/sql/test/SharedSQLContext.scala | 84 ++++++++- .../spark/sql/test/SharedSparkSession.scala | 119 ------------ 8 files changed, 165 insertions(+), 381 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 1aa1c421d792e..6aedcb1271ff6 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,23 +29,10 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) - /** - * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in - * test using styles other than FunSuite, there is often code that relies on the session between - * test group constructs and the actual tests, which may need this session. It is purely a - * semantic difference, but semantically, it makes more sense to call 'initializeContext' between - * a 'describe' and an 'it' call than it does to call 'beforeAll'. 
- */ - protected def initializeContext(): Unit = { - if (null == _sc) { - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) - } - } - override def beforeAll() { super.beforeAll() - initializeContext() + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) } override def afterAll() { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 82c5307d54360..10bdfafd6f933 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.Suite - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -31,13 +29,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PlanTestBase - -/** - * Provides helper methods for comparing plans, but without the overhead of - * mandating a FunSuite. - */ -trait PlanTestBase extends PredicateHelper { self: Suite => +trait PlanTest extends SparkFunSuite with PredicateHelper { // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala deleted file mode 100644 index 6179585a0d39a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.test - -import org.scalatest.FlatSpec - -/** - * The purpose of this suite is to make sure that generic FlatSpec-based scala - * tests work with a shared spark session - */ -class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { - import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS - - "A Simple Dataset" should "have the specified number of elements" in { - assert(8 === ds.count) - } - it should "have the specified number of unique elements" in { - assert(8 === ds.distinct.count) - } - it should "have the specified number of elements in each column" in { - assert(8 === ds.select("_1").count) - assert(8 === ds.select("_2").count) - } - it should "have the correct number of distinct elements in each column" in { - assert(8 === ds.select("_1").distinct.count) - assert(4 === ds.select("_2").distinct.count) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala deleted file mode 100644 index 15139ee8b3047..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import org.scalatest.FunSpec - -/** - * The purpose of this suite is to make sure that generic FunSpec-based scala - * tests work with a shared spark session - */ -class GenericFunSpecSuite extends FunSpec with SharedSparkSession { - import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS - - describe("Simple Dataset") { - it("should have the specified number of elements") { - assert(8 === ds.count) - } - it("should have the specified number of unique elements") { - assert(8 === ds.distinct.count) - } - it("should have the specified number of elements in each column") { - assert(8 === ds.select("_1").count) - assert(8 === ds.select("_2").count) - } - it("should have the correct number of distinct elements in each column") { - assert(8 === ds.select("_1").distinct.count) - assert(4 === ds.select("_2").distinct.count) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala deleted file mode 100644 index b6548bf95fec8..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import org.scalatest.WordSpec - -/** - * The purpose of this suite is to make sure that generic WordSpec-based scala - * tests work with a shared spark session - */ -class GenericWordSpecSuite extends WordSpec with SharedSparkSession { - import testImplicits._ - initializeSession() - val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS - - "A Simple Dataset" when { - "looked at as complete rows" should { - "have the specified number of elements" in { - assert(8 === ds.count) - } - "have the specified number of unique elements" in { - assert(8 === ds.distinct.count) - } - } - "refined to specific columns" should { - "have the specified number of elements in each column" in { - assert(8 === ds.select("_1").count) - assert(8 === ds.select("_2").count) - } - "have the correct number of distinct elements in each column" in { - assert(8 === ds.select("_1").distinct.count) - assert(4 === ds.select("_2").distinct.count) - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b4248b74f50ab..a14a1441a4313 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.{BeforeAndAfterAll, Suite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,17 +36,14 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.UninterruptibleThread -import org.apache.spark.util.Utils +import org.apache.spark.util.{UninterruptibleThread, Utils} /** - * Helper trait that should be extended by all SQL test suites within the Spark - * code base. + * Helper trait that should be extended by all SQL test suites. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. 
@@ -55,99 +52,17 @@ import org.apache.spark.util.Utils * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - - /** - * Materialize the test data immediately after the `SQLContext` is set up. - * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. - */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } -} - -/** - * Helper trait that can be extended by all external SQL test suites. - * - * This allows subclasses to plugin a custom `SQLContext`. - * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. - * - * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is - * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. - */ -private[sql] trait SQLTestUtilsBase - extends Eventually +private[sql] trait SQLTestUtils + extends SparkFunSuite with Eventually with BeforeAndAfterAll with SQLTestData - with PlanTestBase { self: Suite => + with PlanTest { self => protected def sparkContext = spark.sparkContext + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -162,6 +77,21 @@ private[sql] trait SQLTestUtilsBase protected override def _sqlContext: SQLContext = self.spark.sqlContext } + /** + * Materialize the test data immediately after the `SQLContext` is set up. 
+ * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -367,6 +297,61 @@ private[sql] trait SQLTestUtilsBase Dataset.ofRows(spark, plan) } + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to termniate, and + // we rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4d578e21f5494..cd8d0708d8a32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,4 +17,86 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. 
+ */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. + */ + protected override def beforeAll(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala deleted file mode 100644 index e0568a3c5c99f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.concurrent.duration._ - -import org.scalatest.{BeforeAndAfterEach, Suite} -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSparkSession - extends SQLTestUtilsBase - with BeforeAndAfterEach - with Eventually { self: Suite => - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. 
- * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. Generally, this is just called from - * beforeAll; however, in test using styles other than FunSuite, there is - * often code that relies on the session between test group constructs and - * the actual tests, which may need this session. It is purely a semantic - * difference, but semantically, it makes more sense to call - * 'initializeSession' between a 'describe' and an 'it' call than it does to - * call 'beforeAll'. - */ - protected def initializeSession(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - } - - /** - * Make sure the [[TestSparkSession]] is initialized before any tests are run. - */ - protected override def beforeAll(): Unit = { - initializeSession() - - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} From 1fe27612d7bcb8b6478a36bc16ddd4802e4ee2fc Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 29 Oct 2017 18:53:47 -0700 Subject: [PATCH 140/712] [SPARK-22344][SPARKR] Set java.io.tmpdir for SparkR tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR sets the java.io.tmpdir for CRAN checks and also disables the hsperfdata for the JVM when running CRAN checks. Together this prevents files from being left behind in `/tmp` ## How was this patch tested? Tested manually on a clean EC2 machine Author: Shivaram Venkataraman Closes #19589 from shivaram/sparkr-tmpdir-clean. 
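For illustration only, here is the same pair of settings expressed through Spark's generic configuration surface as a hypothetical PySpark sketch (the patch itself builds the equivalent list in R and passes it to `sparkR.session(sparkConfig = ...)`; the path below merely stands in for `tempdir()`):

```
# Hypothetical illustration of the options being set; the real change is on the R side.
from pyspark.sql import SparkSession

tmp_arg = "-Djava.io.tmpdir=/some/scratch/dir"  # assumed path, stands in for tempdir()
spark = (SparkSession.builder
         .master("local[1]")
         .config("spark.driver.extraJavaOptions", tmp_arg)
         .config("spark.executor.extraJavaOptions", tmp_arg)
         .getOrCreate())
```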
--- R/pkg/inst/tests/testthat/test_basic.R | 6 ++++-- R/pkg/tests/run-all.R | 9 +++++++++ R/pkg/vignettes/sparkr-vignettes.Rmd | 8 +++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index de47162d5325f..823d26f12feee 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,7 +18,8 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) i <- 4 df <- createDataFrame(data.frame(dummy = 1:i)) @@ -49,7 +50,8 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { - sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, + sparkConfig = sparkRTestConfig) training <- suppressWarnings(createDataFrame(iris)) # gaussian family diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index a1834a220261d..a7f913e5fad11 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -36,8 +36,17 @@ invisible(lapply(sparkRWhitelistSQLDirs, sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRTestMaster <- "local[1]" +sparkRTestConfig <- list() if (identical(Sys.getenv("NOT_CRAN"), "true")) { sparkRTestMaster <- "" +} else { + # Disable hsperfdata on CRAN + old_java_opt <- Sys.getenv("_JAVA_OPTIONS") + Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt)) + tmpDir <- tempdir() + tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) + sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, + spark.executor.extraJavaOptions = tmpArg) } test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index caeae72e37bbf..907bbb3d66018 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -36,6 +36,12 @@ opts_hooks$set(eval = function(options) { } options }) +r_tmp_dir <- tempdir() +tmp_arg <- paste("-Djava.io.tmpdir=", r_tmp_dir, sep = "") +sparkSessionConfig <- list(spark.driver.extraJavaOptions = tmp_arg, + spark.executor.extraJavaOptions = tmp_arg) +old_java_opt <- Sys.getenv("_JAVA_OPTIONS") +Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ``` ## Overview @@ -57,7 +63,7 @@ We use default settings in which it runs in local mode. It auto downloads Spark ```{r, include=FALSE} install.spark() -sparkR.session(master = "local[1]") +sparkR.session(master = "local[1]", sparkConfig = sparkSessionConfig, enableHiveSupport = FALSE) ``` ```{r, eval=FALSE} sparkR.session() From 188b47e68350731da775efccc2cda9c61610aa14 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Oct 2017 11:50:22 +0900 Subject: [PATCH 141/712] [SPARK-22379][PYTHON] Reduce duplication setUpClass and tearDownClass in PySpark SQL tests ## What changes were proposed in this pull request? This PR propose to add `ReusedSQLTestCase` which deduplicate `setUpClass` and `tearDownClass` in `sql/tests.py`. ## How was this patch tested? Jenkins tests and manual tests. Author: hyukjinkwon Closes #19595 from HyukjinKwon/reduce-dupe. 
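With the shared base class in place, a new suite only needs to subclass it and use the session it provides, roughly like this (a hypothetical suite for illustration, assuming the `ReusedSQLTestCase` defined below is in scope; it is not part of this patch):

```
# Hypothetical suite built on the new base class. ReusedSQLTestCase's
# setUpClass/tearDownClass create and stop cls.sc and cls.spark, so tests
# can use self.spark directly without repeating the boilerplate.
class MyNewSQLTests(ReusedSQLTestCase):

    def test_range_count(self):
        df = self.spark.range(10)
        self.assertEqual(df.count(), 10)
```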
--- python/pyspark/sql/tests.py | 63 +++++++++++++------------------------ 1 file changed, 21 insertions(+), 42 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8ed37c9da98b8..483f39aeef66a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -179,6 +179,18 @@ def __init__(self, key, value): self.value = value +class ReusedSQLTestCase(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -214,21 +226,19 @@ def test_struct_field_type_name(self): self.assertRaises(TypeError, struct_field.typeName) -class SQLTests(ReusedPySparkTestCase): +class SQLTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() + ReusedSQLTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(cls.tempdir.name) - cls.spark = SparkSession(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] cls.df = cls.spark.createDataFrame(cls.testData) @classmethod def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() + ReusedSQLTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) def test_sqlcontext_reuses_sparksession(self): @@ -2623,17 +2633,7 @@ def test_hivecontext(self): self.assertTrue(os.path.exists(metastore_path)) -class SQLTests2(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class SQLTests2(ReusedSQLTestCase): # We can't include this test into SQLTests because we will stop class's SparkContext and cause # other tests failed. 
@@ -3082,12 +3082,12 @@ def __init__(self, **kwargs): @unittest.skipIf(not _have_arrow, "Arrow not installed") -class ArrowTests(ReusedPySparkTestCase): +class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): from datetime import datetime - ReusedPySparkTestCase.setUpClass() + ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java cls.tz_prev = os.environ.get("TZ", None) # save current tz if set @@ -3095,7 +3095,6 @@ def setUpClass(cls): os.environ["TZ"] = tz time.tzset() - cls.spark = SparkSession(cls.sc) cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ @@ -3116,8 +3115,7 @@ def tearDownClass(cls): if cls.tz_prev is not None: os.environ["TZ"] = cls.tz_prev time.tzset() - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() + ReusedSQLTestCase.tearDownClass() def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -3169,17 +3167,7 @@ def test_filtered_frame(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class VectorizedUDFTests(ReusedSQLTestCase): def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col @@ -3498,16 +3486,7 @@ def check_records_per_batch(x): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedPySparkTestCase): - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class GroupbyApplyTests(ReusedSQLTestCase): def assertFramesEqual(self, expected, result): msg = ("DataFrames are not equal: " + From 6eda55f728a6f2e265ae12a7e01dae88e4172715 Mon Sep 17 00:00:00 2001 From: tengpeng Date: Mon, 30 Oct 2017 07:24:55 +0000 Subject: [PATCH 142/712] Added more information to Imputer Often times we want to impute custom values other than 'NaN'. My addition helps people locate this function without reading the API. ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: tengpeng Closes #19600 from tengpeng/patch-5. --- docs/ml-features.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 86a0e09997b8e..72643137d96b1 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1373,7 +1373,9 @@ for more details on the API. The `Imputer` transformer completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly -creates incorrect values for columns containing categorical features. 
+creates incorrect values for columns containing categorical features. Imputer can impute custom values +other than 'NaN' by `.setMissingValue(custom_value)`. For example, `.setMissingValue(0)` will impute +all occurrences of (0). **Note** all `null` values in the input columns are treated as missing, and so are also imputed. From 9f5c77ae32890b892b69b45a2833b9d6d6866aea Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 30 Oct 2017 07:45:54 +0000 Subject: [PATCH 143/712] [SPARK-21983][SQL] Fix Antlr 4.7 deprecation warnings ## What changes were proposed in this pull request? Fix three deprecation warnings introduced by move to ANTLR 4.7: * Use ParserRuleContext.addChild(TerminalNode) in preference to deprecated ParserRuleContext.addChild(Token) interface. * TokenStream.reset() is deprecated in favour of seek(0) * Replace use of deprecated ANTLRInputStream with stream returned by CharStreams.fromString() The last item changed the way we construct ANTLR's input stream (from direct instantiation to factory construction), so necessitated a change to how we override the LA() method to always return an upper-case char. The ANTLR object is now wrapped, rather than inherited-from. * Also fix incorrect usage of CharStream.getText() which expects the rhs of the supplied interval to be the last char to be returned, i.e. the interval is inclusive, and work around bug in ANTLR 4.7 where empty streams or intervals may cause getText() to throw an error. ## How was this patch tested? Ran all the sql tests. Confirmed that LA() override has coverage by breaking it, and noting that tests failed. Author: Henry Robinson Closes #19578 from henryr/spark-21983. --- .../sql/catalyst/parser/ParseDriver.scala | 39 +++++++++++++++---- .../sql/catalyst/parser/ParserUtils.scala | 4 +- .../catalyst/parser/ParserUtilsSuite.scala | 4 +- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 0d9ad218e48db..4c20f2368bded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.parser import org.antlr.v4.runtime._ import org.antlr.v4.runtime.atn.PredictionMode -import org.antlr.v4.runtime.misc.ParseCancellationException +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.antlr.v4.runtime.tree.TerminalNodeImpl import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -80,7 +81,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { logDebug(s"Parsing command: $command") - val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) @@ -99,7 +100,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { catch { case e: ParseCancellationException => // if we fail, parse with LL mode - tokenStream.reset() // rewind input stream + tokenStream.seek(0) // rewind input stream parser.reset() // Try Again. 
@@ -148,12 +149,33 @@ object CatalystSqlParser extends AbstractSqlParser { * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead * function and is purely used for matching lexical rules. This also means that the grammar will * only accept capitalized tokens in case it is run from other tools like antlrworks which do not - * have the ANTLRNoCaseStringStream implementation. + * have the UpperCaseCharStream implementation. */ -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { +private[parser] class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = { + // ANTLR 4.7's CodePointCharStream implementations have bugs when + // getText() is called with an empty stream, or intervals where + // the start > end. See + // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix + // that is not yet in a released ANTLR artifact. + if (size() > 0 && (interval.b - interval.a >= 0)) { + wrapped.getText(interval) + } else { + "" + } + } + override def LA(i: Int): Int = { - val la = super.LA(i) + val la = wrapped.LA(i) if (la == 0 || la == IntStream.EOF) la else Character.toUpperCase(la) } @@ -244,11 +266,12 @@ case object PostProcessor extends SqlBaseBaseListener { val parent = ctx.getParent parent.removeLastChild() val token = ctx.getChild(0).getPayload.asInstanceOf[Token] - parent.addChild(f(new CommonToken( + val newToken = new CommonToken( new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), SqlBaseParser.IDENTIFIER, token.getChannel, token.getStartIndex + stripMargins, - token.getStopIndex - stripMargins))) + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9c1031e8033e7..9b127f91648e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -32,7 +32,7 @@ object ParserUtils { /** Get the command which created the token. */ def command(ctx: ParserRuleContext): String = { val stream = ctx.getStart.getInputStream - stream.getText(Interval.of(0, stream.size())) + stream.getText(Interval.of(0, stream.size() - 1)) } def operationNotAllowed(message: String, ctx: ParserRuleContext): Nothing = { @@ -58,7 +58,7 @@ object ParserUtils { /** Get all the text which comes after the given token. 
*/ def remainder(token: Token): String = { val stream = token.getInputStream - val interval = Interval.of(token.getStopIndex + 1, stream.size()) + val interval = Interval.of(token.getStopIndex + 1, stream.size() - 1) stream.getText(interval) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index d5748a4ff18f8..768030f0a9bc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.parser -import org.antlr.v4.runtime.{CommonTokenStream, ParserRuleContext} +import org.antlr.v4.runtime.{CharStreams, CommonTokenStream, ParserRuleContext} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -57,7 +57,7 @@ class ParserUtilsSuite extends SparkFunSuite { } private def buildContext[T](command: String)(toResult: SqlBaseParser => T): T = { - val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) toResult(parser) From 9f02d7dc537b73988468b11337dbb14a8602f246 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Oct 2017 11:00:44 +0100 Subject: [PATCH 144/712] [SPARK-22385][SQL] MapObjects should not access list element by index ## What changes were proposed in this pull request? This issue was discovered and investigated by Ohad Raviv and Sean Owen in https://issues.apache.org/jira/browse/SPARK-21657. The input data of `MapObjects` may be a `List` which has O(n) complexity for accessing by index. When converting input data to catalyst array, `MapObjects` gets element by index in each loop, and results to bad performance. This PR fixes this issue by accessing elements via Iterator. ## How was this patch tested? using the test script in https://issues.apache.org/jira/browse/SPARK-21657 ``` val BASE = 100000000 val N = 100000 val df = sc.parallelize(List(("1234567890", (BASE to (BASE+N)).map(x => (x.toString, (x+1).toString, (x+2).toString, (x+3).toString)).toList ))).toDF("c1", "c_arr") spark.time(df.queryExecution.toRdd.foreach(_ => ())) ``` We can see 50x speed up. Author: Wenchen Fan Closes #19603 from cloud-fan/map-objects. --- .../expressions/objects/objects.scala | 40 +++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9b28a18035b1c..6ae3490a3f863 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -591,18 +591,43 @@ case class MapObjects private( case _ => inputData.dataType } - val (getLength, getLoopVar) = inputDataType match { + // `MapObjects` generates a while loop to traverse the elements of the input collection. We + // need to take care of Seq and List because they may have O(n) complexity for indexed accessing + // like `list.get(1)`. Here we use Iterator to traverse Seq and List. 
+ val (getLength, prepareLoop, getLoopVar) = inputDataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" + val it = ctx.freshName("it") + ( + s"${genInputData.value}.size()", + s"scala.collection.Iterator $it = ${genInputData.value}.toIterator();", + s"$it.next()" + ) case ObjectType(cls) if cls.isArray => - s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" + ( + s"${genInputData.value}.length", + "", + s"${genInputData.value}[$loopIndex]" + ) case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" + val it = ctx.freshName("it") + ( + s"${genInputData.value}.size()", + s"java.util.Iterator $it = ${genInputData.value}.iterator();", + s"$it.next()" + ) case ArrayType(et, _) => - s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) + ( + s"${genInputData.value}.numElements()", + "", + ctx.getValue(genInputData.value, et, loopIndex) + ) case ObjectType(cls) if cls == classOf[Object] => - s"$seq == null ? $array.length : $seq.size()" -> - s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" + val it = ctx.freshName("it") + ( + s"$seq == null ? $array.length : $seq.size()", + s"scala.collection.Iterator $it = $seq == null ? null : $seq.toIterator();", + s"$it == null ? $array[$loopIndex] : $it.next()" + ) } // Make a copy of the data if it's unsafe-backed @@ -676,6 +701,7 @@ case class MapObjects private( $initCollection int $loopIndex = 0; + $prepareLoop while ($loopIndex < $dataLength) { $loopValue = ($elementJavaType) ($getLoopVar); $loopNullCheck From 3663764254615cef442d4af55c11808445e5b03a Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 30 Oct 2017 12:14:38 +0000 Subject: [PATCH 145/712] [WEB-UI] Add count in fair scheduler pool page ## What changes were proposed in this pull request? Add count in fair scheduler pool page. The purpose is to know the statistics clearly. For specific reasons, please refer to PR of https://github.com/apache/spark/pull/18525 fix before: ![1](https://user-images.githubusercontent.com/26266482/31641589-4b17b970-b318-11e7-97eb-f5a36db428f6.png) ![2](https://user-images.githubusercontent.com/26266482/31641643-97b6345a-b318-11e7-8c20-4b164ade228d.png) fix after: ![3](https://user-images.githubusercontent.com/26266482/31641688-e6ceacb6-b318-11e7-8204-6a816c581a29.png) ![4](https://user-images.githubusercontent.com/26266482/31641766-7310b0c0-b319-11e7-871d-a57f874f1e8b.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19507 from guoxiaolongzte/add_count_in_fair_scheduler_pool_page. 
--- .../src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index a30c13592947c..dc5b03c5269a9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -113,7 +113,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { var content = summary ++ { if (sc.isDefined && isFairScheduler) { -

     <h4>{pools.size} Fair Scheduler Pools</h4> ++ poolTable.toNodeSeq
+     <h4>Fair Scheduler Pools ({pools.size})</h4>
    ++ poolTable.toNodeSeq } else { Seq.empty[Node] } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 819fe57e14b2d..4b8c7b203771d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -57,7 +57,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { var content =

     <h4>Summary </h4>
    ++ poolTable.toNodeSeq if (shouldShowActiveStages) { - content ++=

     <h4>{activeStages.size} Active Stages</h4> ++ activeStagesTable.toNodeSeq
+      content ++= <h4>Active Stages ({activeStages.size})</h4>
    ++ activeStagesTable.toNodeSeq } UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) From 079a2609d7ad0a7dd2ec3eaa594e6ed8801a8008 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Oct 2017 17:53:06 +0100 Subject: [PATCH 146/712] [SPARK-17788][SPARK-21033][SQL] fix the potential OOM in UnsafeExternalSorter and ShuffleExternalSorter ## What changes were proposed in this pull request? In `UnsafeInMemorySorter`, one record may take 32 bytes: 1 `long` for pointer, 1 `long` for key-prefix, and another 2 `long`s as the temporary buffer for radix sort. In `UnsafeExternalSorter`, we set the `DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD` to be `1024 * 1024 * 1024 / 2`, and hoping the max size of point array to be 8 GB. However this is wrong, `1024 * 1024 * 1024 / 2 * 32` is actually 16 GB, and if we grow the point array before reach this limitation, we may hit the max-page-size error. Users may see exception like this on large dataset: ``` Caused by: java.lang.IllegalArgumentException: Cannot allocate a page with more than 17179869176 bytes at org.apache.spark.memory.TaskMemoryManager.allocatePage(TaskMemoryManager.java:241) at org.apache.spark.memory.MemoryConsumer.allocatePage(MemoryConsumer.java:121) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.acquireNewPageIfNecessary(UnsafeExternalSorter.java:374) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.insertRecord(UnsafeExternalSorter.java:396) at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:94) ... ``` Setting `DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD` to a smaller number is not enough, users can still set the config to a big number and trigger the too large page size issue. This PR fixes it by explicitly handling the too large page size exception in the sorter and spill. This PR also change the type of `spark.shuffle.spill.numElementsForceSpillThreshold` to int, because it's only compared with `numRecords`, which is an int. This is an internal conf so we don't have a serious compatibility issue. ## How was this patch tested? TODO Author: Wenchen Fan Closes #18251 from cloud-fan/sort. 
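For reference, a quick back-of-the-envelope check of why the old default collides with the page-size cap quoted in the stack trace above (illustrative arithmetic only, not code from this patch):

```
# 32 bytes per record: 8-byte pointer + 8-byte key prefix + two longs of radix-sort scratch.
bytes_per_record = 8 + 8 + 16
old_threshold = 1024 * 1024 * 1024 // 2           # old DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD
pointer_array_bytes = old_threshold * bytes_per_record
print(pointer_array_bytes)                         # 17179869184, i.e. 16 GB
print(pointer_array_bytes > 17179869176)           # True: exceeds the maximum page size above
```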
--- .../apache/spark/memory/MemoryConsumer.java | 8 ++++++- .../spark/memory/TaskMemoryManager.java | 5 ++-- .../spark/memory/TooLargePageException.java | 24 +++++++++++++++++++ .../shuffle/sort/ShuffleExternalSorter.java | 16 ++++++++----- .../unsafe/sort/UnsafeExternalSorter.java | 21 +++++++++------- .../spark/internal/config/package.scala | 10 ++++++++ .../sort/UnsafeExternalSorterSuite.java | 11 +++++---- .../execution/UnsafeExternalRowSorter.java | 5 ++-- .../apache/spark/sql/internal/SQLConf.scala | 6 ++--- .../UnsafeFixedWidthAggregationMap.java | 6 ++--- .../sql/execution/UnsafeKVExternalSorter.java | 4 ++-- .../aggregate/ObjectAggregationIterator.scala | 6 ++--- .../aggregate/ObjectAggregationMap.scala | 5 ++-- ...nalAppendOnlyUnsafeRowArrayBenchmark.scala | 3 ++- .../UnsafeKVExternalSorterSuite.scala | 4 ++-- 15 files changed, 92 insertions(+), 42 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/memory/TooLargePageException.java diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 0efae16e9838c..2dff241900e82 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -83,7 +83,13 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Allocates a LongArray of `size`. + * Allocates a LongArray of `size`. Note that this method may throw `OutOfMemoryError` if Spark + * doesn't have enough memory for this allocation, or throw `TooLargePageException` if this + * `LongArray` is too large to fit in a single page. The caller side should take care of these + * two exceptions, or make sure the `size` is small enough that won't trigger exceptions. + * + * @throws OutOfMemoryError + * @throws TooLargePageException */ public LongArray allocateArray(long size) { long required = size * 8L; diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 44b60c1e4e8c8..f6b5ea3c0ad26 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -270,13 +270,14 @@ public long pageSizeBytes() { * * Returns `null` if there was not enough memory to allocate the page. May return a page that * contains fewer bytes than requested, so callers should verify the size of returned pages. + * + * @throws TooLargePageException */ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { assert(consumer != null); assert(consumer.getMode() == tungstenMemoryMode); if (size > MAXIMUM_PAGE_SIZE_BYTES) { - throw new IllegalArgumentException( - "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); + throw new TooLargePageException(size); } long acquired = acquireExecutionMemory(size, consumer); diff --git a/core/src/main/java/org/apache/spark/memory/TooLargePageException.java b/core/src/main/java/org/apache/spark/memory/TooLargePageException.java new file mode 100644 index 0000000000000..4abee77ff67b2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/TooLargePageException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +public class TooLargePageException extends RuntimeException { + TooLargePageException(long size) { + super("Cannot allocate a page of " + size + " bytes."); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index b4f46306f2827..e80f9734ecf7b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -31,8 +31,10 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.internal.config.package$; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockManager; @@ -43,7 +45,6 @@ import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; -import org.apache.spark.internal.config.package$; /** * An external sorter that is specialized for sort-based shuffle. @@ -75,10 +76,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { private final ShuffleWriteMetrics writeMetrics; /** - * Force this sorter to spill when there are this many elements in memory. The default value is - * 1024 * 1024 * 1024, which allows the maximum size of the pointer array to be 8G. + * Force this sorter to spill when there are this many elements in memory. */ - private final long numElementsForSpillThreshold; + private final int numElementsForSpillThreshold; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -123,7 +123,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.fileBufferSizeBytes = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = - conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024); + (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); @@ -325,7 +325,7 @@ public void cleanupResources() { * array and grows the array if additional space is required. If the required space cannot be * obtained, then the in-memory data will be spilled to disk. 
*/ - private void growPointerArrayIfNecessary() { + private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); @@ -333,6 +333,10 @@ private void growPointerArrayIfNecessary() { try { // could trigger spilling array = allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + return; } catch (OutOfMemoryError e) { // should have trigger spilling if (!inMemSorter.hasSpaceForAnotherRecord()) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index e749f7ba87c6e..8b8e15e3f78ed 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -32,6 +32,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; @@ -68,12 +69,10 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final int fileBufferSizeBytes; /** - * Force this sorter to spill when there are this many elements in memory. The default value is - * 1024 * 1024 * 1024 / 2 which allows the maximum size of the pointer array to be 8G. + * Force this sorter to spill when there are this many elements in memory. */ - public static final long DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD = 1024 * 1024 * 1024 / 2; + private final int numElementsForSpillThreshold; - private final long numElementsForSpillThreshold; /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -103,11 +102,11 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, - numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */); + pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. 
sorter.inMemSorter = null; @@ -123,7 +122,7 @@ public static UnsafeExternalSorter create( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, @@ -139,7 +138,7 @@ private UnsafeExternalSorter( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, @Nullable UnsafeInMemorySorter existingInMemorySorter, boolean canUseRadixSort) { super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); @@ -338,7 +337,7 @@ public void cleanupResources() { * array and grows the array if additional space is required. If the required space cannot be * obtained, then the in-memory data will be spilled to disk. */ - private void growPointerArrayIfNecessary() { + private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); @@ -346,6 +345,10 @@ private void growPointerArrayIfNecessary() { try { // could trigger spilling array = allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + return; } catch (OutOfMemoryError e) { // should have trigger spilling if (!inMemSorter.hasSpaceForAnotherRecord()) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 6f0247b73070d..57e2da8353d6d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -475,4 +475,14 @@ package object config { .stringConf .toSequence .createOptional + + private[spark] val SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD = + ConfigBuilder("spark.shuffle.spill.numElementsForceSpillThreshold") + .internal() + .doc("The maximum number of elements in memory before forcing the shuffle sorter to spill. 
" + + "By default it's Integer.MAX_VALUE, which means we never force the sorter to spill, " + + "until we reach some limitations, like the max page size limitation for the pointer " + + "array in the sorter.") + .intConf + .createWithDefault(Integer.MAX_VALUE) } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index d0d0334add0bf..af4975c888d65 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.JavaSerializer; @@ -86,6 +87,9 @@ public int compare( private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m"); + private final int spillThreshold = + (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -159,7 +163,7 @@ private UnsafeExternalSorter newSorter() throws IOException { prefixComparator, /* initialSize */ 1024, pageSizeBytes, - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, + spillThreshold, shouldUseRadixSort()); } @@ -383,7 +387,7 @@ public void forcedSpillingWithoutComparator() throws Exception { null, /* initialSize */ 1024, pageSizeBytes, - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, + spillThreshold, shouldUseRadixSort()); long[] record = new long[100]; int recordSize = record.length * 8; @@ -445,7 +449,7 @@ public void testPeakMemoryUsed() throws Exception { prefixComparator, 1024, pageSizeBytes, - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, + spillThreshold, shouldUseRadixSort()); // Peak memory should be monotonically increasing. 
More specifically, every time @@ -548,4 +552,3 @@ private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end) } } } - diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 12a123ee0bcff..6b002f0d3f8e8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -27,6 +27,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.internal.config.package$; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructType; @@ -89,8 +90,8 @@ public UnsafeExternalRowSorter( sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), pageSizeBytes, - SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + (int) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), canUseRadixSort ); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5203e8833fbbb..ede116e964a03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -884,7 +884,7 @@ object SQLConf { .internal() .doc("Threshold for number of rows to be spilled by window operator") .intConf - .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") @@ -899,7 +899,7 @@ object SQLConf { .internal() .doc("Threshold for number of rows to be spilled by sort merge join operator") .intConf - .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") @@ -914,7 +914,7 @@ object SQLConf { .internal() .doc("Threshold for number of rows to be spilled by cartesian product operator") .intConf - .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames") .doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" + diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 8fea46a58e857..c7c4c7b3e7715 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import 
org.apache.spark.internal.config.package$; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; @@ -29,7 +30,6 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -238,8 +238,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti SparkEnv.get().blockManager(), SparkEnv.get().serializerManager(), map.getPageSizeBytes(), - SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + (int) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 6aa52f1aae048..eb2fe82007af3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -57,7 +57,7 @@ public UnsafeKVExternalSorter( BlockManager blockManager, SerializerManager serializerManager, long pageSizeBytes, - long numElementsForSpillThreshold) throws IOException { + int numElementsForSpillThreshold) throws IOException { this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, numElementsForSpillThreshold, null); } @@ -68,7 +68,7 @@ public UnsafeKVExternalSorter( BlockManager blockManager, SerializerManager serializerManager, long pageSizeBytes, - long numElementsForSpillThreshold, + int numElementsForSpillThreshold, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index c68dbc73f0447..43514f5271ac8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -315,9 +315,7 @@ class SortBasedAggregator( SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong( - "spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), null ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala index f2d4f6c6ebd5b..b5372bcca89dd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import java.{util => ju} import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.config import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate} @@ -73,9 +74,7 @@ class ObjectAggregationMap() { SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, - SparkEnv.get.conf.getLong( - "spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), null ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index efe28afab08e5..59397dbcb1cab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.internal.config import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.Benchmark @@ -231,6 +232,6 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X */ testAgainstRawUnsafeExternalSorter( - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4) + config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get, 10 * 1000, 1 << 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 3d869c77e9608..359525fcd05a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,13 +22,13 @@ import java.util.Properties import scala.util.Random import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. 
@@ -125,7 +125,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, - pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD) + pageSize, config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => From 65338de5fbaf90774fb3f4c51321359d324ace56 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 30 Oct 2017 10:19:34 -0700 Subject: [PATCH 147/712] [SPARK-22396][SQL] Better Error Message for InsertIntoDir using Hive format without enabling Hive Support ## What changes were proposed in this pull request? When Hive support is not on, users can hit unresolved plan node when trying to call `INSERT OVERWRITE DIRECTORY` using Hive format. ``` "unresolved operator 'InsertIntoDir true, Storage(Location: /private/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/spark-b4227606-9311-46a8-8c02-56355bf0e2bc, Serde Library: org.apache.hadoop.hive.ql.io.orc.OrcSerde, InputFormat: org.apache.hadoop.hive.ql.io.orc.OrcInputFormat, OutputFormat: org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat), hive, true;; ``` This PR is to issue a better error message. ## How was this patch tested? Added a test case. Author: gatorsmile Closes #19608 from gatorsmile/hivesupportInsertOverwrite. --- .../spark/sql/execution/datasources/rules.scala | 3 +++ .../apache/spark/sql/sources/InsertSuite.scala | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7a2c85e8e01f6..60c430bcfece2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -404,6 +404,9 @@ object HiveOnlyCheck extends (LogicalPlan => Unit) { plan.foreach { case CreateTable(tableDesc, _, _) if DDLUtils.isHiveTable(tableDesc) => throw new AnalysisException("Hive support is required to CREATE Hive TABLE (AS SELECT)") + case i: InsertIntoDir if DDLUtils.isHiveTable(i.provider) => + throw new AnalysisException( + "Hive support is required to INSERT OVERWRITE DIRECTORY with the Hive format") case _ => // OK } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 875b74551addb..8b7e2e5f45946 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -408,6 +408,22 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } + test("Insert overwrite directory using Hive serde without turning on Hive support") { + withTempDir { dir => + val path = dir.toURI.getPath + val e = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE LOCAL DIRECTORY '$path' + |STORED AS orc + |SELECT 1, 2 + """.stripMargin) + }.getMessage + assert(e.contains( + "Hive support is required to INSERT OVERWRITE DIRECTORY with the Hive format")) + } + } + test("insert overwrite directory to data source not providing FileFormat") { withTempDir { dir => val path = dir.toURI.getPath From 44c4003155c1d243ffe0f73d5537b4c8b3f3b564 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 30 Oct 2017 10:21:05 -0700 
Subject: [PATCH 148/712] [SPARK-22400][SQL] rename some APIs and classes to make their meaning clearer ## What changes were proposed in this pull request? Both `ReadSupport` and `ReadTask` have a method called `createReader`, but they create different things. This could cause some confusion for data source developers. The same issue exists between `WriteSupport` and `DataWriterFactory`, both of which have a method called `createWriter`. This PR renames the method of `ReadTask`/`DataWriterFactory` to `createDataReader`/`createDataWriter`. Besides, the name of `RowToInternalRowDataWriterFactory` is not correct, because it actually converts `InternalRow`s to `Row`s. It should be renamed `InternalRowDataWriterFactory`. ## How was this patch tested? Only renaming, should be covered by existing tests. Author: Zhenhua Wang Closes #19610 from wzhfy/rename. --- .../spark/sql/sources/v2/reader/DataReader.java | 4 ++-- .../spark/sql/sources/v2/reader/ReadTask.java | 2 +- .../spark/sql/sources/v2/writer/DataWriter.java | 6 +++--- .../sql/sources/v2/writer/DataWriterFactory.java | 2 +- .../execution/datasources/v2/DataSourceRDD.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 5 +++-- .../datasources/v2/WriteToDataSourceV2.scala | 14 +++++++------- .../sql/sources/v2/JavaAdvancedDataSourceV2.java | 2 +- .../sql/sources/v2/JavaSimpleDataSourceV2.java | 2 +- .../sql/sources/v2/JavaUnsafeRowDataSourceV2.java | 2 +- .../spark/sql/sources/v2/DataSourceV2Suite.scala | 8 +++++--- .../sql/sources/v2/SimpleWritableDataSource.scala | 9 ++++----- 12 files changed, 30 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 95e091569b614..52bb138673fc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -22,8 +22,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data - * for a RDD partition. + * A data reader returned by {@link ReadTask#createDataReader()} and is responsible for + * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 01362df0978cb..44786db419a32 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -45,5 +45,5 @@ default String[] preferredLocations() { /** * Returns a data reader to do the actual reading work for this read task. 
*/ - DataReader createReader(); + DataReader createDataReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 14261419af6f6..d84afbae32892 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createWriter(int, int)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -34,7 +34,7 @@ * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark will retry this writing task for some times, - * each time {@link DataWriterFactory#createWriter(int, int)} gets a different `attemptNumber`, + * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task @@ -64,7 +64,7 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which - * will be send back to driver side and pass to + * will be sent back to driver side and passed to * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index f812d102bda1a..fe56cc00d1c7a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -46,5 +46,5 @@ public interface DataWriterFactory extends Serializable { * tasks with the same task id running at the same time. Implementations can * use this attempt number to distinguish writers of different task attempts. 
*/ - DataWriter createWriter(int partitionId, int attemptNumber); + DataWriter createDataWriter(int partitionId, int attemptNumber); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index b8fe5ac8e3d94..5f30be5ed4af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -39,7 +39,7 @@ class DataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createReader() + val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[UnsafeRow] { private[this] var valuePrepared = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index addc12a3f0901..3f243dc44e043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -67,8 +67,9 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) override def preferredLocations: Array[String] = rowReadTask.preferredLocations - override def createReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema).resolveAndBind()) + override def createDataReader: DataReader[UnsafeRow] = { + new RowToUnsafeDataReader( + rowReadTask.createDataReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 92c1e1f4a3383..b72d15ed15aed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -48,7 +48,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) override protected def doExecute(): RDD[InternalRow] = { val writeTask = writer match { case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new RowToInternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } val rdd = query.execute() @@ -93,7 +93,7 @@ object DataWritingSparkTask extends Logging { writeTask: DataWriterFactory[InternalRow], context: TaskContext, iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createWriter(context.partitionId(), context.attemptNumber()) + val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) // write the data and commit this writer. 
Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -111,18 +111,18 @@ object DataWritingSparkTask extends Logging { } } -class RowToInternalRowDataWriterFactory( +class InternalRowDataWriterFactory( rowWriterFactory: DataWriterFactory[Row], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new RowToInternalRowDataWriter( - rowWriterFactory.createWriter(partitionId, attemptNumber), + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new InternalRowDataWriter( + rowWriterFactory.createDataWriter(partitionId, attemptNumber), RowEncoder.apply(schema).resolveAndBind()) } } -class RowToInternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) +class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) extends DataWriter[InternalRow] { override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index da2c13f70c52a..1cfdc08217e6e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -100,7 +100,7 @@ static class JavaAdvancedReadTask implements ReadTask, DataReader { } @Override - public DataReader createReader() { + public DataReader createDataReader() { return new JavaAdvancedReadTask(start - 1, end, requiredSchema); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 08469f14c257a..2d458b7f7e906 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -58,7 +58,7 @@ static class JavaSimpleReadTask implements ReadTask, DataReader { } @Override - public DataReader createReader() { + public DataReader createDataReader() { return new JavaSimpleReadTask(start - 1, end); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index 9efe7c791a936..f6aa00869a681 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -58,7 +58,7 @@ static class JavaUnsafeRowReadTask implements ReadTask, DataReader createReader() { + public DataReader createDataReader() { return new JavaUnsafeRowReadTask(start - 1, end); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 092702a1d5173..ab37e4984bd1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -167,7 +167,7 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { 
private var current = start - 1 - override def createReader(): DataReader[Row] = new SimpleReadTask(start, end) + override def createDataReader(): DataReader[Row] = new SimpleReadTask(start, end) override def next(): Boolean = { current += 1 @@ -233,7 +233,9 @@ class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) private var current = start - 1 - override def createReader(): DataReader[Row] = new AdvancedReadTask(start, end, requiredSchema) + override def createDataReader(): DataReader[Row] = { + new AdvancedReadTask(start, end, requiredSchema) + } override def close(): Unit = {} @@ -273,7 +275,7 @@ class UnsafeRowReadTask(start: Int, end: Int) private var current = start - 1 - override def createReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) override def next(): Boolean = { current += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 6fb60f4d848d7..cd7252eb2e3d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.text.SimpleDateFormat -import java.util.{Collections, Date, List => JList, Locale, Optional, UUID} +import java.util.{Collections, List => JList, Optional} import scala.collection.JavaConverters._ @@ -157,7 +156,7 @@ class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createReader(): DataReader[Row] = { + override def createDataReader(): DataReader[Row] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) @@ -185,7 +184,7 @@ class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { - override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) @@ -218,7 +217,7 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { - override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) From ded3ed97337427477d33bd5fad76649e96f6e50a Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 30 Oct 2017 21:44:24 -0700 Subject: [PATCH 149/712] [SPARK-22327][SPARKR][TEST] check for version warning ## What changes were proposed in this pull request? 
Will need to port to this to branch-1.6, -2.0, -2.1, -2.2 ## How was this patch tested? manually Jenkins, AppVeyor Author: Felix Cheung Closes #19549 from felixcheung/rcranversioncheck. --- R/run-tests.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/R/run-tests.sh b/R/run-tests.sh index 29764f48bd156..f38c86e3e6b1d 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -38,6 +38,7 @@ FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)" NUM_CRAN_NOTES="$(grep -c NOTE$ $CRAN_CHECK_LOG_FILE)" +HAS_PACKAGE_VERSION_WARN="$(grep -c "Insufficient package version" $CRAN_CHECK_LOG_FILE)" if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then cat $LOGFILE @@ -46,9 +47,10 @@ if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then echo -en "\033[0m" # No color exit -1 else - # We have 2 existing NOTEs for new maintainer, attach() - # We have one more NOTE in Jenkins due to "No repository set" - if [[ $NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3 ]]; then + # We have 2 NOTEs for RoxygenNote, attach(); and one in Jenkins only "No repository set" + # For non-latest version branches, one WARNING for package version + if [[ ($NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3) && + ($HAS_PACKAGE_VERSION_WARN != 1 || $NUM_CRAN_WARNING != 1 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 2) ]]; then cat $CRAN_CHECK_LOG_FILE echo -en "\033[31m" # Red echo "Had CRAN check errors; see logs." From 1ff41d8693ade5bf34ee41a1140254488abfbbc7 Mon Sep 17 00:00:00 2001 From: "pj.fanning" Date: Tue, 31 Oct 2017 08:16:54 +0000 Subject: [PATCH 150/712] [SPARK-21708][BUILD] update some sbt plugins ## What changes were proposed in this pull request? These are just some straightforward upgrades to use the latest versions of some sbt plugins that also support sbt 1.0. The remaining sbt plugins that need upgrading will require bigger changes. ## How was this patch tested? Tested sbt use manually. Author: pj.fanning Closes #19609 from pjfanning/SPARK-21708. 
--- project/plugins.sbt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 3c5442b04b8e4..96bdb9067ae59 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,11 +1,9 @@ // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -// sbt 1.0.0 support: https://github.com/typesafehub/sbteclipse/issues/343 -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.1.0") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.3") -// sbt 1.0.0 support: https://github.com/jrudolph/sbt-dependency-graph/issues/134 -addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.0") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") @@ -20,8 +18,7 @@ addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") // need to make changes to uptake sbt 1.0 support in "com.cavorite" % "sbt-avro-1-7" % "1.1.2" addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") -// sbt 1.0.0 support: https://github.com/spray/sbt-revolver/issues/62 -addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0") +addSbtPlugin("io.spray" % "sbt-revolver" % "0.9.1") libraryDependencies += "org.ow2.asm" % "asm" % "5.1" From aa6db57e39d4931658089d9237dbf2a29acfe5ed Mon Sep 17 00:00:00 2001 From: bomeng Date: Tue, 31 Oct 2017 08:20:23 +0000 Subject: [PATCH 151/712] [SPARK-22399][ML] update the location of reference paper ## What changes were proposed in this pull request? Update the url of reference paper. ## How was this patch tested? It is comments, so nothing tested. Author: bomeng Closes #19614 from bomeng/22399. --- docs/mllib-clustering.md | 2 +- .../examples/mllib/PowerIterationClusteringExample.scala | 3 ++- .../spark/mllib/clustering/PowerIterationClustering.scala | 6 +++--- python/pyspark/mllib/clustering.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 8990e95796b67..df2be92d860e4 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -134,7 +134,7 @@ Refer to the [`GaussianMixture` Python docs](api/python/pyspark.mllib.html#pyspa Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a graph given pairwise similarities as edge properties, -described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). +described in [Lin and Cohen, Power Iteration Clustering](http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. `spark.mllib` includes an implementation of PIC using GraphX as its backend. 
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 986496c0d9435..65603252c4384 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -28,7 +28,8 @@ import org.apache.spark.mllib.clustering.PowerIterationClustering import org.apache.spark.rdd.RDD /** - * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. + * An example Power Iteration Clustering app. + * http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf * Takes an input of K concentric circles and the number of points in the innermost circle. * The output should be K clusters - each cluster containing precisely the points associated * with each of the input circles. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index b2437b845f826..9444f29a91ed8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -103,9 +103,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode /** * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by - * Lin and Cohen. From the abstract: PIC finds - * a very low-dimensional embedding of a dataset using truncated power iteration on a normalized - * pair-wise similarity matrix of the data. + * Lin and Cohen. + * From the abstract: PIC finds a very low-dimensional embedding of a dataset using + * truncated power iteration on a normalized pair-wise similarity matrix of the data. * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 91123ace3387e..bb687a7da6ffd 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -636,7 +636,7 @@ def load(cls, sc, path): class PowerIterationClustering(object): """ Power Iteration Clustering (PIC), a scalable graph clustering algorithm - developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. + developed by [[http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf Lin and Cohen]]. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data. From 59589bc6545b6665432febfa9ee4891a96d119c4 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 31 Oct 2017 11:13:48 +0100 Subject: [PATCH 152/712] [SPARK-22310][SQL] Refactor join estimation to incorporate estimation logic for different kinds of statistics ## What changes were proposed in this pull request? The current join estimation logic is only based on basic column statistics (such as ndv, etc). If we want to add estimation for other kinds of statistics (such as histograms), it's not easy to incorporate into the current algorithm: 1. When we have multiple pairs of join keys, the current algorithm computes cardinality in a single formula. 
But if different join keys have different kinds of stats, the computation logic for each pair of join keys become different, so the previous formula does not apply. 2. Currently it computes cardinality and updates join keys' column stats separately. It's better to do these two steps together, since both computation and update logic are different for different kinds of stats. ## How was this patch tested? Only refactor, covered by existing tests. Author: Zhenhua Wang Closes #19531 from wzhfy/join_est_refactor. --- .../BasicStatsPlanVisitor.scala | 4 +- .../statsEstimation/JoinEstimation.scala | 172 +++++++++--------- 2 files changed, 85 insertions(+), 91 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 4cff72d45a400..ca0775a2e8408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.LongType /** * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. @@ -54,7 +52,7 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitIntersect(p: Intersect): Statistics = fallback(p) override def visitJoin(p: Join): Statistics = { - JoinEstimation.estimate(p).getOrElse(fallback(p)) + JoinEstimation(p).estimate.getOrElse(fallback(p)) } override def visitLocalLimit(p: LocalLimit): Statistics = fallback(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index dcbe36da91dfc..b073108c26ee5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -28,60 +28,58 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -object JoinEstimation extends Logging { +case class JoinEstimation(join: Join) extends Logging { + + private val leftStats = join.left.stats + private val rightStats = join.right.stats + /** * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. 
*/ - def estimate(join: Join): Option[Statistics] = { + def estimate: Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => - InnerOuterEstimation(join).doEstimate() + estimateInnerOuterJoin() case LeftSemi | LeftAnti => - LeftSemiAntiEstimation(join).doEstimate() + estimateLeftSemiAntiJoin() case _ => logDebug(s"[CBO] Unsupported join type: ${join.joinType}") None } } -} - -case class InnerOuterEstimation(join: Join) extends Logging { - - private val leftStats = join.left.stats - private val rightStats = join.right.stats /** * Estimate output size and number of rows after a join operator, and update output column stats. */ - def doEstimate(): Option[Statistics] = join match { + private def estimateInnerOuterJoin(): Option[Statistics] = join match { case _ if !rowCountsExist(join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) - val selectivity = joinSelectivity(joinKeyPairs) + val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs) // 2. Estimate the number of output rows val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - val innerJoinedRows = ceil(BigDecimal(leftRows * rightRows) * selectivity) // Make sure outputRows won't be too small based on join type. val outputRows = joinType match { case LeftOuter => // All rows from left side should be in the result. - leftRows.max(innerJoinedRows) + leftRows.max(numInnerJoinedRows) case RightOuter => // All rows from right side should be in the result. - rightRows.max(innerJoinedRows) + rightRows.max(numInnerJoinedRows) case FullOuter => // T(A FOJ B) = T(A LOJ B) + T(A ROJ B) - T(A IJ B) - leftRows.max(innerJoinedRows) + rightRows.max(innerJoinedRows) - innerJoinedRows + leftRows.max(numInnerJoinedRows) + rightRows.max(numInnerJoinedRows) - numInnerJoinedRows case _ => + assert(joinType == Inner || joinType == Cross) // Don't change for inner or cross join - innerJoinedRows + numInnerJoinedRows } // 3. Update statistics based on the output of join @@ -93,7 +91,7 @@ case class InnerOuterEstimation(join: Join) extends Logging { val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { // The output is empty, we don't need to keep column stats. Nil - } else if (selectivity == 0) { + } else if (numInnerJoinedRows == 0) { joinType match { // For outer joins, if the join selectivity is 0, the number of output rows is the // same as that of the outer side. And column stats of join keys from the outer side @@ -113,26 +111,28 @@ case class InnerOuterEstimation(join: Join) extends Logging { val oriColStat = inputAttrStats(a) (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) } - case _ => Nil + case _ => + assert(joinType == Inner || joinType == Cross) + Nil } - } else if (selectivity == 1) { + } else if (numInnerJoinedRows == leftRows * rightRows) { // Cartesian product, just propagate the original column stats inputAttrStats.toSeq } else { - val joinKeyStats = getIntersectedStats(joinKeyPairs) join.joinType match { // For outer joins, don't update column stats from the outer side. 
case LeftOuter => fromLeft.map(a => (a, inputAttrStats(a))) ++ - updateAttrStats(outputRows, fromRight, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, fromRight, inputAttrStats, keyStatsAfterJoin) case RightOuter => - updateAttrStats(outputRows, fromLeft, inputAttrStats, joinKeyStats) ++ + updateOutputStats(outputRows, fromLeft, inputAttrStats, keyStatsAfterJoin) ++ fromRight.map(a => (a, inputAttrStats(a))) case FullOuter => inputAttrStats.toSeq case _ => + assert(joinType == Inner || joinType == Cross) // Update column stats from both sides for inner or cross join. - updateAttrStats(outputRows, attributesWithStat, inputAttrStats, joinKeyStats) + updateOutputStats(outputRows, attributesWithStat, inputAttrStats, keyStatsAfterJoin) } } @@ -157,64 +157,90 @@ case class InnerOuterEstimation(join: Join) extends Logging { // scalastyle:off /** * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula: - * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of - * that column. The underlying assumption for this formula is: each value of the smaller domain - * is included in the larger domain. - * Generally, inner join with multiple join keys can also be estimated based on the above - * formula: + * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), + * where V is the number of distinct values (ndv) of that column. The underlying assumption for + * this formula is: each value of the smaller domain is included in the larger domain. + * + * Generally, inner join with multiple join keys can be estimated based on the above formula: * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn))) * However, the denominator can become very large and excessively reduce the result, so we use a * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator. + * + * That is, join estimation is based on the most selective join keys. We follow this strategy + * when different types of column statistics are available. E.g., if card1 is the cardinality + * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms + * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2). + * + * @param keyPairs pairs of join keys + * + * @return join cardinality, and column stats for join keys after the join */ // scalastyle:on - def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = { - var ndvDenom: BigInt = -1 + private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)]) + : (BigInt, AttributeMap[ColumnStat]) = { + // If there's no column stats available for join keys, estimate as cartesian product. 
+ var joinCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get + val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]() var i = 0 - while(i < joinKeyPairs.length && ndvDenom != 0) { - val (leftKey, rightKey) = joinKeyPairs(i) + while(i < keyPairs.length && joinCard != 0) { + val (leftKey, rightKey) = keyPairs(i) // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { - // Get the largest ndv among pairs of join keys - val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) - if (maxNdv > ndvDenom) ndvDenom = maxNdv + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) + val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax) + keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat) + // Return cardinality estimated from the most selective join keys. + if (card < joinCard) joinCard = card } else { - // Set ndvDenom to zero to indicate that this join should have no output - ndvDenom = 0 + // One of the join key pairs is disjoint, thus the two sides of join is disjoint. + joinCard = 0 } i += 1 } + (joinCard, AttributeMap(keyStatsAfterJoin.toSeq)) + } - if (ndvDenom < 0) { - // We can't find any join key pairs with column stats, estimate it as cartesian join. - 1 - } else if (ndvDenom == 0) { - // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - 0 - } else { - 1 / BigDecimal(ndvDenom) - } + /** Returns join cardinality and the column stat for this pair of join keys. */ + private def computeByNdv( + leftKey: AttributeReference, + rightKey: AttributeReference, + newMin: Option[Any], + newMax: Option[Any]): (BigInt, ColumnStat) = { + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + // Compute cardinality by the basic formula. + val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) + + // Get the intersected column stat. + val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) + val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) + val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + + (ceil(card), newStats) } /** * Propagate or update column stats for output attributes. 
*/ - private def updateAttrStats( + private def updateOutputStats( outputRows: BigInt, - attributes: Seq[Attribute], + output: Seq[Attribute], oldAttrStats: AttributeMap[ColumnStat], - joinKeyStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { + keyStatsAfterJoin: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = { val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() val leftRows = leftStats.rowCount.get val rightRows = rightStats.rowCount.get - attributes.foreach { a => + output.foreach { a => // check if this attribute is a join key - if (joinKeyStats.contains(a)) { - outputAttrStats += a -> joinKeyStats(a) + if (keyStatsAfterJoin.contains(a)) { + outputAttrStats += a -> keyStatsAfterJoin(a) } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount @@ -231,34 +257,6 @@ case class InnerOuterEstimation(join: Join) extends Logging { outputAttrStats } - /** Get intersected column stats for join keys. */ - private def getIntersectedStats(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) - : AttributeMap[ColumnStat] = { - - val intersectedStats = new mutable.HashMap[Attribute, ColumnStat]() - joinKeyPairs.foreach { case (leftKey, rightKey) => - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) - // When we reach here, join selectivity is not zero, so each pair of join keys should be - // intersected. - assert(ValueInterval.isIntersected(lInterval, rInterval)) - - // Update intersected column stats - assert(leftKey.dataType.sameType(rightKey.dataType)) - val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) - val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) - val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) - - intersectedStats.put(leftKey, newStats) - intersectedStats.put(rightKey, newStats) - } - AttributeMap(intersectedStats.toSeq) - } - private def extractJoinKeysWithColStats( leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[(AttributeReference, AttributeReference)] = { @@ -270,10 +268,8 @@ case class InnerOuterEstimation(join: Join) extends Logging { if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } -} -case class LeftSemiAntiEstimation(join: Join) { - def doEstimate(): Option[Statistics] = { + private def estimateLeftSemiAntiJoin(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available. From 4d9ebf3835dde1abbf9cff29a55675d9f4227620 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 31 Oct 2017 11:35:32 +0100 Subject: [PATCH 153/712] [SPARK-19611][SQL][FOLLOWUP] set dataSchema correctly in HiveMetastoreCatalog.convertToLogicalRelation ## What changes were proposed in this pull request? We made a mistake in https://github.com/apache/spark/pull/16944 . In `HiveMetastoreCatalog#inferIfNeeded` we infer the data schema, merge with full schema, and return the new full schema. 
At caller side we treat the full schema as data schema and set it to `HadoopFsRelation`. This doesn't cause any problem because both parquet and orc can work with a wrong data schema that has extra columns, but it's better to fix this mistake. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19615 from cloud-fan/infer. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f0f2c493498b3..5ac65973e70e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -164,13 +164,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - val (dataSchema, updatedTable) = - inferIfNeeded(relation, options, fileFormat, Option(fileIndex)) + val updatedTable = inferIfNeeded(relation, options, fileFormat, Option(fileIndex)) val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, - dataSchema = dataSchema, + dataSchema = updatedTable.dataSchema, bucketSpec = None, fileFormat = fileFormat, options = options)(sparkSession = sparkSession) @@ -191,13 +190,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormatClass, None) val logicalRelation = cached.getOrElse { - val (dataSchema, updatedTable) = inferIfNeeded(relation, options, fileFormat) + val updatedTable = inferIfNeeded(relation, options, fileFormat) val created = LogicalRelation( DataSource( sparkSession = sparkSession, paths = rootPath.toString :: Nil, - userSpecifiedSchema = Option(dataSchema), + userSpecifiedSchema = Option(updatedTable.dataSchema), bucketSpec = None, options = options, className = fileType).resolveRelation(), @@ -224,7 +223,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log relation: HiveTableRelation, options: Map[String, String], fileFormat: FileFormat, - fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = { + fileIndexOpt: Option[FileIndex] = None): CatalogTable = { val inferenceMode = sparkSession.sessionState.conf.caseSensitiveInferenceMode val shouldInfer = (inferenceMode != NEVER_INFER) && !relation.tableMeta.schemaPreservesCase val tableName = relation.tableMeta.identifier.unquotedString @@ -241,21 +240,22 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log sparkSession, options, fileIndex.listFiles(Nil, Nil).flatMap(_.files)) - .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _)) + .map(mergeWithMetastoreSchema(relation.tableMeta.dataSchema, _)) inferredSchema match { - case Some(schema) => + case Some(dataSchema) => + val schema = StructType(dataSchema ++ relation.tableMeta.partitionSchema) if (inferenceMode == INFER_AND_SAVE) { updateCatalogSchema(relation.tableMeta.identifier, schema) } - (schema, relation.tableMeta.copy(schema = schema)) + relation.tableMeta.copy(schema = schema) case None => logWarning(s"Unable to infer schema for table $tableName from file format " + s"$fileFormat (inference mode: $inferenceMode). 
Using metastore schema.") - (relation.tableMeta.schema, relation.tableMeta) + relation.tableMeta } } else { - (relation.tableMeta.schema, relation.tableMeta) + relation.tableMeta } } From 7986cc09b1b2100fc061d0aea8aa2e1e1b162c75 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Tue, 31 Oct 2017 09:49:58 -0700 Subject: [PATCH 154/712] [SPARK-11334][CORE] Fix bug in Executor allocation manager in running tasks calculation ## What changes were proposed in this pull request? We often see the issue of Spark jobs stuck because the Executor Allocation Manager does not ask for any executor even if there are pending tasks in case dynamic allocation is turned on. Looking at the logic in Executor Allocation Manager, which calculates the running tasks, it can happen that the calculation will be wrong and the number of running tasks can become negative. ## How was this patch tested? Added unit test Author: Sital Kedia Closes #19580 from sitalkedia/skedia/fix_stuck_job. --- .../spark/ExecutorAllocationManager.scala | 29 ++++++++++++------- .../ExecutorAllocationManagerSuite.scala | 22 ++++++++++++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 119b426a9af34..5bc2d9ef1b949 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -267,6 +267,10 @@ private[spark] class ExecutorAllocationManager( (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor } + private def totalRunningTasks(): Int = synchronized { + listener.totalRunningTasks + } + /** * This is called at a fixed interval to regulate the number of pending executor requests * and number of executors running. @@ -602,12 +606,11 @@ private[spark] class ExecutorAllocationManager( private class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] + // Number of running tasks per stage including speculative tasks. + // Should be 0 when no stages are active. + private val stageIdToNumRunningTask = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] - // Number of tasks currently running on the cluster including speculative tasks. - // Should be 0 when no stages are active. 
- private var numRunningTasks: Int = _ - // Number of speculative tasks to be scheduled in each stage private val stageIdToNumSpeculativeTasks = new mutable.HashMap[Int, Int] // The speculative tasks started in each stage @@ -625,6 +628,7 @@ private[spark] class ExecutorAllocationManager( val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks + stageIdToNumRunningTask(stageId) = 0 allocationManager.onSchedulerBacklogged() // Compute the number of tasks requested by the stage on each host @@ -651,6 +655,7 @@ private[spark] class ExecutorAllocationManager( val stageId = stageCompleted.stageInfo.stageId allocationManager.synchronized { stageIdToNumTasks -= stageId + stageIdToNumRunningTask -= stageId stageIdToNumSpeculativeTasks -= stageId stageIdToTaskIndices -= stageId stageIdToSpeculativeTaskIndices -= stageId @@ -663,10 +668,6 @@ private[spark] class ExecutorAllocationManager( // This is needed in case the stage is aborted for any reason if (stageIdToNumTasks.isEmpty && stageIdToNumSpeculativeTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() - if (numRunningTasks != 0) { - logWarning("No stages are running, but numRunningTasks != 0") - numRunningTasks = 0 - } } } } @@ -678,7 +679,9 @@ private[spark] class ExecutorAllocationManager( val executorId = taskStart.taskInfo.executorId allocationManager.synchronized { - numRunningTasks += 1 + if (stageIdToNumRunningTask.contains(stageId)) { + stageIdToNumRunningTask(stageId) += 1 + } // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is // possible because these events are posted in different threads. (see SPARK-4951) @@ -709,7 +712,9 @@ private[spark] class ExecutorAllocationManager( val taskIndex = taskEnd.taskInfo.index val stageId = taskEnd.stageId allocationManager.synchronized { - numRunningTasks -= 1 + if (stageIdToNumRunningTask.contains(stageId)) { + stageIdToNumRunningTask(stageId) -= 1 + } // If the executor is no longer running any scheduled tasks, mark it as idle if (executorIdToTaskIds.contains(executorId)) { executorIdToTaskIds(executorId) -= taskId @@ -787,7 +792,9 @@ private[spark] class ExecutorAllocationManager( /** * The number of tasks currently running across all stages. */ - def totalRunningTasks(): Int = numRunningTasks + def totalRunningTasks(): Int = { + stageIdToNumRunningTask.values.sum + } /** * Return true if an executor is not currently running a task, and false otherwise. 
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index a91e09b7cb69f..90b7ec4384abd 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -227,6 +227,23 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + test("ignore task end events from completed stages") { + sc = createSparkContext(0, 10, 0) + val manager = sc.executorAllocationManager.get + val stage = createStageInfo(0, 5) + post(sc.listenerBus, SparkListenerStageSubmitted(stage)) + val taskInfo1 = createTaskInfo(0, 0, "executor-1") + val taskInfo2 = createTaskInfo(1, 1, "executor-1") + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfo1)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfo2)) + + post(sc.listenerBus, SparkListenerStageCompleted(stage)) + + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, Success, taskInfo1, null)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, taskInfo2, null)) + assert(totalRunningTasks(manager) === 0) + } + test("cancel pending executors when no longer needed") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get @@ -1107,6 +1124,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) private val _onSpeculativeTaskSubmitted = PrivateMethod[Unit]('onSpeculativeTaskSubmitted) + private val _totalRunningTasks = PrivateMethod[Int]('totalRunningTasks) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate _numExecutorsToAdd() @@ -1190,6 +1208,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _localityAwareTasks() } + private def totalRunningTasks(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _totalRunningTasks() + } + private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = { manager invokePrivate _hostToLocalTaskCount() } From 73231860baaa40f6001db347e5dcb6b5bb65e032 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 31 Oct 2017 11:53:50 -0700 Subject: [PATCH 155/712] [SPARK-22305] Write HDFSBackedStateStoreProvider.loadMap non-recursively ## What changes were proposed in this pull request? Write HDFSBackedStateStoreProvider.loadMap non-recursively. This prevents stack overflow if too many deltas stack up in a low memory environment. ## How was this patch tested? existing unit tests for functional equivalence, new unit test to check for stack overflow Author: Jose Torres Closes #19611 from joseph-torres/SPARK-22305. 
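A note for readers skimming the change below: the previous implementation built the map for version N by recursively loading version N - 1, so the call depth grew with the number of delta files since the last snapshot. The following is a minimal, hypothetical Python sketch of the iterative pattern (it is not the provider's actual Scala code; `snapshots` and `deltas` stand in for the snapshot and delta files): walk backwards in a loop to find a usable base version, then replay the missing deltas forward.

```
def load_map(version, snapshots, deltas):
    # Walk backwards (no recursion) until a base version is available.
    base_version = version
    base_state = None
    while base_state is None:
        if base_version <= 0:
            base_state = {}                   # versions <= 0 start from an empty map
        else:
            base_state = snapshots.get(base_version)
            if base_state is None:
                base_version -= 1             # no snapshot here, keep looking further back
    # Replay the missing delta files forward in a plain loop.
    state = dict(base_state)
    for v in range(base_version + 1, version + 1):
        state.update(deltas[v])
    return state
```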
--- .../state/HDFSBackedStateStoreProvider.scala | 45 +++++++++++++++---- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 36d6569a4187a..3f5002a4e6937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -297,17 +297,44 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { - if (version <= 0) return new MapType - synchronized { loadedMaps.get(version) }.getOrElse { - val mapFromFile = readSnapshotFile(version).getOrElse { - val prevMap = loadMap(version - 1) - val newMap = new MapType(prevMap) - updateFromDeltaFile(version, newMap) - newMap + + // Shortcut if the map for this version is already there to avoid a redundant put. + val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) } + if (loadedCurrentVersionMap.isDefined) { + return loadedCurrentVersionMap.get + } + val snapshotCurrentVersionMap = readSnapshotFile(version) + if (snapshotCurrentVersionMap.isDefined) { + synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + return snapshotCurrentVersionMap.get + } + + // Find the most recent map before this version that we can. + // [SPARK-22305] This must be done iteratively to avoid stack overflow. + var lastAvailableVersion = version + var lastAvailableMap: Option[MapType] = None + while (lastAvailableMap.isEmpty) { + lastAvailableVersion -= 1 + + if (lastAvailableVersion <= 0) { + // Use an empty map for versions 0 or less. + lastAvailableMap = Some(new MapType) + } else { + lastAvailableMap = + synchronized { loadedMaps.get(lastAvailableVersion) } + .orElse(readSnapshotFile(lastAvailableVersion)) } - loadedMaps.put(version, mapFromFile) - mapFromFile } + + // Load all the deltas from the version after the last available one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. + val resultMap = new MapType(lastAvailableMap.get) + for (deltaVersion <- lastAvailableVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) + } + + synchronized { loadedMaps.put(version, resultMap) } + resultMap } private def writeUpdateToDeltaFile( From 556b5d21512b17027a6e451c6a82fb428940e95a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 1 Nov 2017 08:45:11 +0000 Subject: [PATCH 156/712] [SPARK-5484][FOLLOWUP] PeriodicRDDCheckpointer doc cleanup ## What changes were proposed in this pull request? PeriodicRDDCheckpointer was already moved out of mllib in Spark-5484 ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #19618 from zhengruifeng/checkpointer_doc. 
--- .../org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index facbb830a60d8..5e181a9822534 100644 --- a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -73,8 +73,6 @@ import org.apache.spark.util.PeriodicCheckpointer * * @param checkpointInterval RDDs will be checkpointed at this interval * @tparam T RDD element type - * - * TODO: Move this out of MLlib? */ private[spark] class PeriodicRDDCheckpointer[T]( checkpointInterval: Int, From 96798d14f07208796fa0a90af0ab369879bacd6c Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Wed, 1 Nov 2017 18:07:39 +0800 Subject: [PATCH 157/712] [SPARK-22172][CORE] Worker hangs when the external shuffle service port is already in use ## What changes were proposed in this pull request? Handling the NonFatal exceptions while starting the external shuffle service, if there are any NonFatal exceptions it logs and continues without the external shuffle service. ## How was this patch tested? I verified it manually, it logs the exception and continues to serve without external shuffle service when BindException occurs. Author: Devaraj K Closes #19396 from devaraj-kavali/SPARK-22172. --- .../org/apache/spark/deploy/worker/Worker.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ed5fa4b839cd4..3962d422f81d3 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -199,7 +199,7 @@ private[deploy] class Worker( logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - shuffleService.startIfEnabled() + startExternalShuffleService() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -367,6 +367,16 @@ private[deploy] class Worker( } } + private def startExternalShuffleService() { + try { + shuffleService.startIfEnabled() + } catch { + case e: Exception => + logError("Failed to start external shuffle service", e) + System.exit(1) + } + } + private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { masterEndpoint.send(RegisterWorker( workerId, From 07f390a27d7b793291c352a643d4bbd5f47294a6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 1 Nov 2017 13:09:35 +0100 Subject: [PATCH 158/712] [SPARK-22347][PYSPARK][DOC] Add document to notice users for using udfs with conditional expressions ## What changes were proposed in this pull request? Under the current execution mode of Python UDFs, we don't well support Python UDFs as branch values or else value in CaseWhen expression. Since to fix it might need the change not small (e.g., #19592) and this issue has simpler workaround. We should just notice users in the document about this. ## How was this patch tested? Only document change. Author: Liang-Chi Hsieh Closes #19617 from viirya/SPARK-22347-3. 
--- python/pyspark/sql/functions.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0d40368c9cd6e..39815497f3956 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2185,6 +2185,13 @@ def udf(f=None, returnType=StringType()): duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. + .. note:: The user-defined functions do not support conditional execution by using them with + SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no + matter the conditions are met or not. So the output is correct if the functions can be + correctly run on all rows without failure. If the functions can cause runtime failure on the + rows that do not satisfy the conditions, the suggested workaround is to incorporate the + condition logic into the functions. + :param f: python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object @@ -2278,6 +2285,13 @@ def pandas_udf(f=None, returnType=StringType()): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. + + .. note:: The user-defined functions do not support conditional execution by using them with + SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no + matter the conditions are met or not. So the output is correct if the functions can be + correctly run on all rows without failure. If the functions can cause runtime failure on the + rows that do not satisfy the conditions, the suggested workaround is to incorporate the + condition logic into the functions. """ return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) From 444bce1c98c45147fe63e2132e9743a0c5e49598 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Wed, 1 Nov 2017 14:54:08 +0100 Subject: [PATCH 159/712] [SPARK-19112][CORE] Support for ZStandard codec ## What changes were proposed in this pull request? Using zstd compression for Spark jobs spilling 100s of TBs of data, we could reduce the amount of data written to disk by as much as 50%. This translates to significant latency gain because of reduced disk io operations. There is a degradation CPU time by 2 - 5% because of zstd compression overhead, but for jobs which are bottlenecked by disk IO, this hit can be taken. ## Benchmark Please note that this benchmark is using real world compute heavy production workload spilling TBs of data to disk | | zstd performance as compred to LZ4 | | ------------- | -----:| | spill/shuffle bytes | -48% | | cpu time | + 3% | | cpu reservation time | -40%| | latency | -40% | ## How was this patch tested? Tested by running few jobs spilling large amount of data on the cluster and amount of intermediate data written to disk reduced by as much as 50%. Author: Sital Kedia Closes #18805 from sitalkedia/skedia/upstream_zstd. 
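For users, adopting the new codec is a configuration-only change. The sketch below assumes only the configuration keys added by this patch (spark.io.compression.codec, spark.io.compression.zstd.level, spark.io.compression.zstd.bufferSize); application code is otherwise untouched, and the local master is just for illustration.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Minimal sketch: switch block/shuffle compression to zstd and tune the new knobs.
val conf = new SparkConf()
  .setAppName("zstd-compression-example")
  .setMaster("local[*]")                                  // for a self-contained run
  .set("spark.io.compression.codec", "zstd")              // short name registered by this patch
  .set("spark.io.compression.zstd.level", "1")            // default; higher = better ratio, more CPU
  .set("spark.io.compression.zstd.bufferSize", "32k")     // buffering amortizes JNI call overhead

val spark = SparkSession.builder.config(conf).getOrCreate()
```

Raising the level trades CPU for a better compression ratio; the buffer size exists to limit per-call JNI overhead, as described in the docs/configuration.md change below.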
--- LICENSE | 2 ++ core/pom.xml | 4 +++ .../apache/spark/io/CompressionCodec.scala | 36 +++++++++++++++++-- .../spark/io/CompressionCodecSuite.scala | 18 ++++++++++ dev/deps/spark-deps-hadoop-2.6 | 1 + dev/deps/spark-deps-hadoop-2.7 | 1 + docs/configuration.md | 20 ++++++++++- licenses/LICENSE-zstd-jni.txt | 26 ++++++++++++++ licenses/LICENSE-zstd.txt | 30 ++++++++++++++++ pom.xml | 5 +++ 10 files changed, 140 insertions(+), 3 deletions(-) create mode 100644 licenses/LICENSE-zstd-jni.txt create mode 100644 licenses/LICENSE-zstd.txt diff --git a/LICENSE b/LICENSE index 39fe0dc462385..c2b0d72663b55 100644 --- a/LICENSE +++ b/LICENSE @@ -269,6 +269,8 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) + (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) + (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) ======================================================================== MIT licenses diff --git a/core/pom.xml b/core/pom.xml index 54f7a34a6c37e..fa138d3e7a4e0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -198,6 +198,10 @@ org.lz4 lz4-java
    + + com.github.luben + zstd-jni + org.roaringbitmap RoaringBitmap diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 27f2e429395db..7722db56ee297 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -20,6 +20,7 @@ package org.apache.spark.io import java.io._ import java.util.Locale +import com.github.luben.zstd.{ZstdInputStream, ZstdOutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} @@ -50,13 +51,14 @@ private[spark] object CompressionCodec { private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] - || codec.isInstanceOf[LZ4CompressionCodec]) + || codec.isInstanceOf[LZ4CompressionCodec] || codec.isInstanceOf[ZStdCompressionCodec]) } private val shortCompressionCodecNames = Map( "lz4" -> classOf[LZ4CompressionCodec].getName, "lzf" -> classOf[LZFCompressionCodec].getName, - "snappy" -> classOf[SnappyCompressionCodec].getName) + "snappy" -> classOf[SnappyCompressionCodec].getName, + "zstd" -> classOf[ZStdCompressionCodec].getName) def getCodecName(conf: SparkConf): String = { conf.get(configKey, DEFAULT_COMPRESSION_CODEC) @@ -219,3 +221,33 @@ private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends Ou } } } + +/** + * :: DeveloperApi :: + * ZStandard implementation of [[org.apache.spark.io.CompressionCodec]]. For more + * details see - http://facebook.github.io/zstd/ + * + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. + */ +@DeveloperApi +class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec { + + private val bufferSize = conf.getSizeAsBytes("spark.io.compression.zstd.bufferSize", "32k").toInt + // Default compression level for zstd compression to 1 because it is + // fastest of all with reasonably high compression ratio. + private val level = conf.getInt("spark.io.compression.zstd.level", 1) + + override def compressedOutputStream(s: OutputStream): OutputStream = { + // Wrap the zstd output stream in a buffered output stream, so that we can + // avoid overhead excessive of JNI call while trying to compress small amount of data. + new BufferedOutputStream(new ZstdOutputStream(s, level), bufferSize) + } + + override def compressedInputStream(s: InputStream): InputStream = { + // Wrap the zstd input stream in a buffered input stream so that we can + // avoid overhead excessive of JNI call while trying to uncompress small amount of data. 
+ new BufferedInputStream(new ZstdInputStream(s), bufferSize) + } +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 9e9c2b0165e13..7b40e3e58216d 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -104,6 +104,24 @@ class CompressionCodecSuite extends SparkFunSuite { testConcatenationOfSerializedStreams(codec) } + test("zstd compression codec") { + val codec = CompressionCodec.createCodec(conf, classOf[ZStdCompressionCodec].getName) + assert(codec.getClass === classOf[ZStdCompressionCodec]) + testCodec(codec) + } + + test("zstd compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "zstd") + assert(codec.getClass === classOf[ZStdCompressionCodec]) + testCodec(codec) + } + + test("zstd supports concatenation of serialized zstd") { + val codec = CompressionCodec.createCodec(conf, classOf[ZStdCompressionCodec].getName) + assert(codec.getClass === classOf[ZStdCompressionCodec]) + testConcatenationOfSerializedStreams(codec) + } + test("bad compression codec") { intercept[IllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 6e2fc63d67108..21c8a75796387 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -189,3 +189,4 @@ xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar zookeeper-3.4.6.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index c2bbc253d723a..7173426c7bf74 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -190,3 +190,4 @@ xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar zookeeper-3.4.6.jar +zstd-jni-1.3.2-2.jar diff --git a/docs/configuration.md b/docs/configuration.md index d3c358bb74173..9b9583d9165ef 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -889,7 +889,8 @@ Apart from these, the following properties are also available, and may be useful e.g. org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, - and org.apache.spark.io.SnappyCompressionCodec. + org.apache.spark.io.SnappyCompressionCodec, + and org.apache.spark.io.ZstdCompressionCodec. @@ -908,6 +909,23 @@ Apart from these, the following properties are also available, and may be useful is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. + + spark.io.compression.zstd.level + 1 + + Compression level for Zstd compression codec. Increasing the compression level will result in better + compression at the expense of more CPU and memory. + + + + spark.io.compression.zstd.bufferSize + 32k + + Buffer size used in Zstd compression, in the case when Zstd compression codec + is used. Lowering this size will lower the shuffle memory usage when Zstd is used, but it + might increase the compression cost because of excessive JNI call overhead. + + spark.kryo.classesToRegister (none) diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses/LICENSE-zstd-jni.txt new file mode 100644 index 0000000000000..32c6bbdd980d6 --- /dev/null +++ b/licenses/LICENSE-zstd-jni.txt @@ -0,0 +1,26 @@ +Zstd-jni: JNI bindings to Zstd Library + +Copyright (c) 2015-2016, Luben Karavelov/ All rights reserved. 
+ +BSD License + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE-zstd.txt b/licenses/LICENSE-zstd.txt new file mode 100644 index 0000000000000..a793a80289256 --- /dev/null +++ b/licenses/LICENSE-zstd.txt @@ -0,0 +1,30 @@ +BSD License + +For Zstandard software + +Copyright (c) 2016-present, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pom.xml b/pom.xml index 2d59f06811a82..652aed4d12f7a 100644 --- a/pom.xml +++ b/pom.xml @@ -537,6 +537,11 @@ lz4-java 1.4.0 + + com.github.luben + zstd-jni + 1.3.2-2 + com.clearspring.analytics stream From 1ffe03d9e87fb784cc8a0bae232c81c7b14deac9 Mon Sep 17 00:00:00 2001 From: LucaCanali Date: Wed, 1 Nov 2017 15:40:25 +0100 Subject: [PATCH 160/712] [SPARK-22190][CORE] Add Spark executor task metrics to Dropwizard metrics ## What changes were proposed in this pull request? 
This proposed patch is about making Spark executor task metrics available as Dropwizard metrics. This is intended to be of aid in monitoring Spark jobs and when drilling down on performance troubleshooting issues. ## How was this patch tested? Manually tested on a Spark cluster (see JIRA for an example screenshot). Author: LucaCanali Closes #19426 from LucaCanali/SparkTaskMetricsDropWizard. --- .../org/apache/spark/executor/Executor.scala | 41 ++++++++++++++++ .../spark/executor/ExecutorSource.scala | 48 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2ecbb749d1fb7..e3e555eaa0277 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -406,6 +406,47 @@ private[spark] class Executor( task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization) + // Expose task metrics using the Dropwizard metrics system. + // Update task metrics counters + executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime) + executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime) + executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime) + executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime) + executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime) + executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime) + executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME + .inc(task.metrics.shuffleReadMetrics.fetchWaitTime) + executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime) + executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.totalBytesRead) + executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.remoteBytesRead) + executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK + .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk) + executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.localBytesRead) + executorSource.METRIC_SHUFFLE_RECORDS_READ + .inc(task.metrics.shuffleReadMetrics.recordsRead) + executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED + .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched) + executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED + .inc(task.metrics.shuffleReadMetrics.localBlocksFetched) + executorSource.METRIC_SHUFFLE_BYTES_WRITTEN + .inc(task.metrics.shuffleWriteMetrics.bytesWritten) + executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN + .inc(task.metrics.shuffleWriteMetrics.recordsWritten) + executorSource.METRIC_INPUT_BYTES_READ + .inc(task.metrics.inputMetrics.bytesRead) + executorSource.METRIC_INPUT_RECORDS_READ + .inc(task.metrics.inputMetrics.recordsRead) + executorSource.METRIC_OUTPUT_BYTES_WRITTEN + .inc(task.metrics.outputMetrics.bytesWritten) + executorSource.METRIC_OUTPUT_RECORDS_WRITTEN + .inc(task.metrics.inputMetrics.recordsRead) + executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize) + executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled) + executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled) + // Note: accumulator updates must be collected after TaskMetrics is updated val accumUpdates = task.collectAccumulatorUpdates() // TODO: do not serialize value twice diff --git 
a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index d16f4a1fc4e3b..669ce63325d0e 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -72,4 +72,52 @@ class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends registerFileSystemStat(scheme, "largeRead_ops", _.getLargeReadOps(), 0) registerFileSystemStat(scheme, "write_ops", _.getWriteOps(), 0) } + + // Expose executor task metrics using the Dropwizard metrics system. + // The list is taken from TaskMetrics.scala + val METRIC_CPU_TIME = metricRegistry.counter(MetricRegistry.name("cpuTime")) + val METRIC_RUN_TIME = metricRegistry.counter(MetricRegistry.name("runTime")) + val METRIC_JVM_GC_TIME = metricRegistry.counter(MetricRegistry.name("jvmGCTime")) + val METRIC_DESERIALIZE_TIME = + metricRegistry.counter(MetricRegistry.name("deserializeTime")) + val METRIC_DESERIALIZE_CPU_TIME = + metricRegistry.counter(MetricRegistry.name("deserializeCpuTime")) + val METRIC_RESULT_SERIALIZE_TIME = + metricRegistry.counter(MetricRegistry.name("resultSerializationTime")) + val METRIC_SHUFFLE_FETCH_WAIT_TIME = + metricRegistry.counter(MetricRegistry.name("shuffleFetchWaitTime")) + val METRIC_SHUFFLE_WRITE_TIME = + metricRegistry.counter(MetricRegistry.name("shuffleWriteTime")) + val METRIC_SHUFFLE_TOTAL_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("shuffleTotalBytesRead")) + val METRIC_SHUFFLE_REMOTE_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("shuffleRemoteBytesRead")) + val METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK = + metricRegistry.counter(MetricRegistry.name("shuffleRemoteBytesReadToDisk")) + val METRIC_SHUFFLE_LOCAL_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("shuffleLocalBytesRead")) + val METRIC_SHUFFLE_RECORDS_READ = + metricRegistry.counter(MetricRegistry.name("shuffleRecordsRead")) + val METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED = + metricRegistry.counter(MetricRegistry.name("shuffleRemoteBlocksFetched")) + val METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED = + metricRegistry.counter(MetricRegistry.name("shuffleLocalBlocksFetched")) + val METRIC_SHUFFLE_BYTES_WRITTEN = + metricRegistry.counter(MetricRegistry.name("shuffleBytesWritten")) + val METRIC_SHUFFLE_RECORDS_WRITTEN = + metricRegistry.counter(MetricRegistry.name("shuffleRecordsWritten")) + val METRIC_INPUT_BYTES_READ = + metricRegistry.counter(MetricRegistry.name("bytesRead")) + val METRIC_INPUT_RECORDS_READ = + metricRegistry.counter(MetricRegistry.name("recordsRead")) + val METRIC_OUTPUT_BYTES_WRITTEN = + metricRegistry.counter(MetricRegistry.name("bytesWritten")) + val METRIC_OUTPUT_RECORDS_WRITTEN = + metricRegistry.counter(MetricRegistry.name("recordsWritten")) + val METRIC_RESULT_SIZE = + metricRegistry.counter(MetricRegistry.name("resultSize")) + val METRIC_DISK_BYTES_SPILLED = + metricRegistry.counter(MetricRegistry.name("diskBytesSpilled")) + val METRIC_MEMORY_BYTES_SPILLED = + metricRegistry.counter(MetricRegistry.name("memoryBytesSpilled")) } From d43e1f06bd545d00bfcaf1efb388b469effd5d64 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Nov 2017 18:39:15 +0100 Subject: [PATCH 161/712] [MINOR] Data source v2 docs update. ## What changes were proposed in this pull request? This patch includes some doc updates for data source API v2. I was reading the code and noticed some minor issues. ## How was this patch tested? 
This is a doc only change. Author: Reynold Xin Closes #19626 from rxin/dsv2-update. --- .../org/apache/spark/sql/sources/v2/DataSourceV2.java | 9 ++++----- .../org/apache/spark/sql/sources/v2/WriteSupport.java | 4 ++-- .../sql/sources/v2/reader/DataSourceV2Reader.java | 10 +++++----- .../v2/reader/SupportsPushDownCatalystFilters.java | 2 -- .../sql/sources/v2/reader/SupportsScanUnsafeRow.java | 2 -- .../sql/sources/v2/writer/DataSourceV2Writer.java | 11 +++-------- .../spark/sql/sources/v2/writer/DataWriter.java | 10 +++++----- 7 files changed, 19 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index dbcbe326a7510..6234071320dc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -20,12 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * The base interface for data source v2. Implementations must have a public, no arguments - * constructor. + * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * - * Note that this is an empty interface, data source implementations should mix-in at least one of - * the plug-in interfaces like {@link ReadSupport}. Otherwise it's just a dummy data source which is - * un-readable/writable. + * Note that this is an empty interface. Data source implementations should mix-in at least one of + * the plug-in interfaces like {@link ReadSupport} and {@link WriteSupport}. Otherwise it's just + * a dummy data source which is un-readable/writable. */ @InterfaceStability.Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index a8a961598bde3..8fdfdfd19ea1e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -36,8 +36,8 @@ public interface WriteSupport { * sources can return None if there is no writing needed to be done according to the save mode. * * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceV2Writer} should - * use this job id to distinguish itself with writers of other jobs. + * jobs running at the same time, and the returned {@link DataSourceV2Writer} can + * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 5989a4ac8440b..88c3219a75c1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -34,11 +34,11 @@ * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column - * pruning), etc. These push-down interfaces are named like `SupportsPushDownXXX`. - * 2. 
Information Reporting. E.g., statistics reporting, ordering reporting, etc. These - * reporting interfaces are named like `SupportsReportingXXX`. - * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. These scan interfaces are named - * like `SupportsScanXXX`. + * pruning), etc. Names of these interfaces start with `SupportsPushDown`. + * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. + * Names of these interfaces start with `SupportsReporting`. + * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. + * Names of these interfaces start with `SupportsScan`. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index d6091774d75aa..efc42242f4421 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -31,8 +31,6 @@ * {@link SupportsPushDownFilters}, Spark will ignore {@link SupportsPushDownFilters} and only * process this interface. */ -@InterfaceStability.Evolving -@Experimental @InterfaceStability.Unstable public interface SupportsPushDownCatalystFilters { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index d5eada808a16c..6008fb5f71cc1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -30,8 +30,6 @@ * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get * changed in the future Spark versions. */ -@InterfaceStability.Evolving -@Experimental @InterfaceStability.Unstable public interface SupportsScanUnsafeRow extends DataSourceV2Reader { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 8d8e33633fb0d..37bb15f87c59a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -40,15 +40,10 @@ * some writers are aborted, or the job failed with an unknown reason, call * {@link #abort(WriterCommitMessage[])}. * - * Spark won't retry failed writing jobs, users should do it manually in their Spark applications if - * they want to retry. + * While Spark will retry failed writing tasks, Spark won't retry failed writing jobs. Users should + * do it manually in their Spark applications if they want to retry. * - * Please refer to the document of commit/abort methods for detailed specifications. - * - * Note that, this interface provides a protocol between Spark and data sources for transactional - * data writing, but the transaction here is Spark-level transaction, which may not be the - * underlying storage transaction. 
For example, Spark successfully writes data to a Cassandra data - * source, but Cassandra may need some more time to reach consistency at storage level. + * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving public interface DataSourceV2Writer { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index d84afbae32892..dc1aab33bdcef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -57,8 +57,8 @@ public interface DataWriter { /** * Writes one record. * - * If this method fails(throw exception), {@link #abort()} will be called and this data writer is - * considered to be failed. + * If this method fails (by throwing an exception), {@link #abort()} will be called and this + * data writer is considered to have been failed. */ void write(T record); @@ -70,10 +70,10 @@ public interface DataWriter { * The written data should only be visible to data source readers after * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to - * do the final commitment via {@link WriterCommitMessage}. + * do the final commit via {@link WriterCommitMessage}. * - * If this method fails(throw exception), {@link #abort()} will be called and this data writer is - * considered to be failed. + * If this method fails (by throwing an exception), {@link #abort()} will be called and this + * data writer is considered to have been failed. */ WriterCommitMessage commit(); From b04eefae49b96e2ef5a8d75334db29ef4e19ce58 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 2 Nov 2017 09:30:03 +0900 Subject: [PATCH 162/712] [MINOR][DOC] automatic type inference supports also Date and Timestamp ## What changes were proposed in this pull request? Easy fix in the documentation, which is reporting that only numeric types and string are supported in type inference for partition columns, while Date and Timestamp are supported too since 2.1.0, thanks to SPARK-17388. ## How was this patch tested? n/a Author: Marco Gaido Closes #19628 from mgaido91/SPARK-22398. --- docs/sql-programming-guide.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 639a8ea7bb8ad..ce377875ff2b1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -800,10 +800,11 @@ root {% endhighlight %} Notice that the data types of the partitioning columns are automatically inferred. Currently, -numeric data types and string type are supported. Sometimes users may not want to automatically -infer the data types of the partitioning columns. For these use cases, the automatic type inference -can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to -`true`. When type inference is disabled, string type will be used for the partitioning columns. +numeric data types, date, timestamp and string type are supported. Sometimes users may not want +to automatically infer the data types of the partitioning columns. For these use cases, the +automatic type inference can be configured by +`spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to `true`. 
When type +inference is disabled, string type will be used for the partitioning columns. Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths by default. For the above example, if users pass `path/to/table/gender=male` to either From 849b465bbf472d6ca56308fb3ccade86e2244e01 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 2 Nov 2017 09:45:34 +0000 Subject: [PATCH 163/712] [SPARK-14650][REPL][BUILD] Compile Spark REPL for Scala 2.12 ## What changes were proposed in this pull request? Spark REPL changes for Scala 2.12.4: use command(), not processLine() in ILoop; remove direct dependence on older jline. Not sure whether this became needed in 2.12.4 or just missed this before. This makes spark-shell work in 2.12. ## How was this patch tested? Existing tests; manual run of spark-shell in 2.11, 2.12 builds Author: Sean Owen Closes #19612 from srowen/SPARK-14650.2. --- pom.xml | 15 ++++++++++----- repl/pom.xml | 4 ---- .../scala/org/apache/spark/repl/SparkILoop.scala | 13 +++++-------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/pom.xml b/pom.xml index 652aed4d12f7a..8570338df6878 100644 --- a/pom.xml +++ b/pom.xml @@ -735,6 +735,12 @@ scalap ${scala.version} + + + jline + jline + 2.12.1 + org.scalatest scalatest_${scala.binary.version} @@ -1188,6 +1194,10 @@ org.jboss.netty netty + + jline + jline + @@ -1925,11 +1935,6 @@ antlr4-runtime ${antlr4.version} - - jline - jline - 2.12.1 - org.apache.commons commons-crypto diff --git a/repl/pom.xml b/repl/pom.xml index bd2cfc465aaf0..1cb0098d0eca3 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -69,10 +69,6 @@ org.scala-lang scala-reflect ${scala.version} - - - jline - jline org.slf4j diff --git a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 413594021987d..900edd63cb90e 100644 --- a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -19,9 +19,6 @@ package org.apache.spark.repl import java.io.BufferedReader -// scalastyle:off println -import scala.Predef.{println => _, _} -// scalastyle:on println import scala.tools.nsc.Settings import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} import scala.tools.nsc.util.stringFromStream @@ -37,7 +34,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { - processLine(""" + command(""" @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { org.apache.spark.repl.Main.sparkSession } else { @@ -64,10 +61,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) _sc } """) - processLine("import org.apache.spark.SparkContext._") - processLine("import spark.implicits._") - processLine("import spark.sql") - processLine("import org.apache.spark.sql.functions._") + command("import org.apache.spark.SparkContext._") + command("import spark.implicits._") + command("import spark.sql") + command("import org.apache.spark.sql.functions._") } } From 277b1924b46a70ab25414f5670eb784906dbbfdf Mon Sep 17 00:00:00 2001 From: Patrick Woody Date: Thu, 2 Nov 2017 14:19:21 +0100 Subject: [PATCH 164/712] [SPARK-22408][SQL] RelationalGroupedDataset's distinct pivot value calculation launches unnecessary stages ## What changes were proposed in this pull request? 
Adding a global limit on top of the distinct values before sorting and collecting will reduce the overall work in the case where we have more distinct values. We will also eagerly perform a collect rather than a take because we know we only have at most (maxValues + 1) rows. ## How was this patch tested? Existing tests cover sorted order Author: Patrick Woody Closes #19629 from pwoody/SPARK-22408. --- .../scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 21e94fa8bb0b1..3e4edd4ea8cf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -321,10 +321,10 @@ class RelationalGroupedDataset protected[sql]( // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() + .limit(maxValues + 1) .sort(pivotColumn) // ensure that the output columns are in a consistent logical order - .rdd + .collect() .map(_.get(0)) - .take(maxValues + 1) .toSeq if (values.length > maxValues) { From b2463fad718d25f564d62c50d587610de3d0c5bd Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 2 Nov 2017 13:25:48 +0000 Subject: [PATCH 165/712] [SPARK-22145][MESOS] fix supervise with checkpointing on mesos ## What changes were proposed in this pull request? - Fixes the issue with the frameworkId being recovered by checkpointed data overwriting the one sent by the dipatcher. - Keeps submission driver id as the only index for all data structures in the dispatcher. Allocates a different task id per driver retry to satisfy the mesos requirements. Check the relevant ticket for the details on that. ## How was this patch tested? Manually tested this with DC/OS 1.10. Launched a streaming job with checkpointing to hdfs, made the driver fail several times and observed behavior: ![image](https://user-images.githubusercontent.com/7945591/30940500-f7d2a744-a3e9-11e7-8c56-f2ccbb271e80.png) ![image](https://user-images.githubusercontent.com/7945591/30940550-19bc15de-a3ea-11e7-8a11-f48abfe36720.png) ![image](https://user-images.githubusercontent.com/7945591/30940524-083ea308-a3ea-11e7-83ae-00d3fa17b928.png) ![image](https://user-images.githubusercontent.com/7945591/30940579-2f0fb242-a3ea-11e7-82f9-86179da28b8c.png) ![image](https://user-images.githubusercontent.com/7945591/30940591-3b561b0e-a3ea-11e7-9dbd-e71912bb2ef3.png) ![image](https://user-images.githubusercontent.com/7945591/30940605-49c810ca-a3ea-11e7-8af5-67930851fd38.png) ![image](https://user-images.githubusercontent.com/7945591/30940631-59f4a288-a3ea-11e7-88cb-c3741b72bb13.png) ![image](https://user-images.githubusercontent.com/7945591/30940642-62346c9e-a3ea-11e7-8935-82e494925f67.png) ![image](https://user-images.githubusercontent.com/7945591/30940653-6c46d53c-a3ea-11e7-8dd1-5840d484d28c.png) Author: Stavros Kontopoulos Closes #19374 from skonto/fix_retry. 
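The heart of the fix is separating the stable submission id, which stays the single key for all dispatcher state, from the per-attempt Mesos task id. Below is a distilled sketch of that id scheme, mirroring getDriverTaskId/getSubmissionIdFromTaskId in the diff that follows; the object and method names here are illustrative, not the dispatcher API.

```scala
// Sketch of the id scheme only; not the MesosClusterScheduler code itself.
object DriverIds {
  private val RetrySep = "-retry-"

  // A fresh Mesos task id is allocated for every retry of the same submission.
  def taskId(submissionId: String, retries: Option[Int]): String =
    retries.map(r => s"$submissionId$RetrySep$r").getOrElse(submissionId)

  // All dispatcher state stays keyed by the submission id, recovered from any task id.
  def submissionId(taskId: String): String = taskId.split(RetrySep).head
}

// e.g. DriverIds.taskId("driver-20170926223339-0001", Some(2))
//      == "driver-20170926223339-0001-retry-2"
```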
--- .../scala/org/apache/spark/SparkContext.scala | 1 + .../cluster/mesos/MesosClusterScheduler.scala | 90 +++++++++++-------- .../apache/spark/streaming/Checkpoint.scala | 3 +- 3 files changed, 57 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6f25d346e6e54..c7dd635ad4c96 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -310,6 +310,7 @@ class SparkContext(config: SparkConf) extends Logging { * (i.e. * in case of local spark app something like 'local-1433865536131' * in case of YARN something like 'application_1433865536131_34483' + * in case of MESOS something like 'driver-20170926223339-0001' * ) */ def applicationId: String = _applicationId diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 82470264f2a4a..de846c85d53a6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -134,22 +134,24 @@ private[spark] class MesosClusterScheduler( private val useFetchCache = conf.getBoolean("spark.mesos.fetchCache.enable", false) private val schedulerState = engineFactory.createEngine("scheduler") private val stateLock = new Object() + // Keyed by submission id private val finishedDrivers = new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) private var frameworkId: String = null - // Holds all the launched drivers and current launch state, keyed by driver id. + // Holds all the launched drivers and current launch state, keyed by submission id. private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. // All drivers that are loaded after failover are added here, as we need get the latest - // state of the tasks from Mesos. + // state of the tasks from Mesos. Keyed by task Id. private val pendingRecover = new mutable.HashMap[String, SlaveID]() - // Stores all the submitted drivers that hasn't been launched. + // Stores all the submitted drivers that hasn't been launched, keyed by submission id private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() - // All supervised drivers that are waiting to retry after termination. + // All supervised drivers that are waiting to retry after termination, keyed by submission id private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]() private val queuedDriversState = engineFactory.createEngine("driverQueue") private val launchedDriversState = engineFactory.createEngine("launchedDrivers") private val pendingRetryDriversState = engineFactory.createEngine("retryList") + private final val RETRY_SEP = "-retry-" // Flag to mark if the scheduler is ready to be called, which is until the scheduler // is registered with Mesos master. @volatile protected var ready = false @@ -192,8 +194,8 @@ private[spark] class MesosClusterScheduler( // 3. Check if it's in the retry list. // 4. Check if it has already completed. 
if (launchedDrivers.contains(submissionId)) { - val task = launchedDrivers(submissionId) - schedulerDriver.killTask(task.taskId) + val state = launchedDrivers(submissionId) + schedulerDriver.killTask(state.taskId) k.success = true k.message = "Killing running driver" } else if (removeFromQueuedDrivers(submissionId)) { @@ -275,7 +277,7 @@ private[spark] class MesosClusterScheduler( private def recoverState(): Unit = { stateLock.synchronized { launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => - launchedDrivers(state.taskId.getValue) = state + launchedDrivers(state.driverDescription.submissionId) = state pendingRecover(state.taskId.getValue) = state.slaveId } queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) @@ -353,7 +355,8 @@ private[spark] class MesosClusterScheduler( .setSlaveId(slaveId) .setState(MesosTaskState.TASK_STAGING) .build() - launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + launchedDrivers.get(getSubmissionIdFromTaskId(taskId)) + .map(_.mesosTaskStatus.getOrElse(newStatus)) .getOrElse(newStatus) } // TODO: Page the status updates to avoid trying to reconcile @@ -369,10 +372,19 @@ private[spark] class MesosClusterScheduler( } private def getDriverFrameworkID(desc: MesosDriverDescription): String = { - val retries = desc.retryState.map { d => s"-retry-${d.retries.toString}" }.getOrElse("") + val retries = desc.retryState.map { d => s"${RETRY_SEP}${d.retries.toString}" }.getOrElse("") s"${frameworkId}-${desc.submissionId}${retries}" } + private def getDriverTaskId(desc: MesosDriverDescription): String = { + val sId = desc.submissionId + desc.retryState.map(state => sId + s"${RETRY_SEP}${state.retries.toString}").getOrElse(sId) + } + + private def getSubmissionIdFromTaskId(taskId: String): String = { + taskId.split(s"${RETRY_SEP}").head + } + private def adjust[A, B](m: collection.Map[A, B], k: A, default: B)(f: B => B) = { m.updated(k, f(m.getOrElse(k, default))) } @@ -551,7 +563,7 @@ private[spark] class MesosClusterScheduler( } private def createTaskInfo(desc: MesosDriverDescription, offer: ResourceOffer): TaskInfo = { - val taskId = TaskID.newBuilder().setValue(desc.submissionId).build() + val taskId = TaskID.newBuilder().setValue(getDriverTaskId(desc)).build() val (remainingResources, cpuResourcesToUse) = partitionResources(offer.remainingResources, "cpus", desc.cores) @@ -604,7 +616,7 @@ private[spark] class MesosClusterScheduler( val task = createTaskInfo(submission, offer) queuedTasks += task logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + - submission.submissionId) + submission.submissionId + s" with taskId: ${task.getTaskId.toString}") val newState = new MesosClusterSubmissionState( submission, task.getTaskId, @@ -718,45 +730,51 @@ private[spark] class MesosClusterScheduler( logInfo(s"Received status update: taskId=${taskId}" + s" state=${status.getState}" + s" message=${status.getMessage}" + - s" reason=${status.getReason}"); + s" reason=${status.getReason}") stateLock.synchronized { - if (launchedDrivers.contains(taskId)) { + val subId = getSubmissionIdFromTaskId(taskId) + if (launchedDrivers.contains(subId)) { if (status.getReason == Reason.REASON_RECONCILIATION && !pendingRecover.contains(taskId)) { // Task has already received update and no longer requires reconciliation. return } - val state = launchedDrivers(taskId) + val state = launchedDrivers(subId) // Check if the driver is supervise enabled and can be relaunched. 
if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { - removeFromLaunchedDrivers(taskId) + removeFromLaunchedDrivers(subId) state.finishDate = Some(new Date()) val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState val (retries, waitTimeSec) = retryState .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } .getOrElse{ (1, 1) } val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) - val newDriverDescription = state.driverDescription.copy( retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) - addDriverToPending(newDriverDescription, taskId); + addDriverToPending(newDriverDescription, newDriverDescription.submissionId) } else if (TaskState.isFinished(mesosToTaskState(status.getState))) { - removeFromLaunchedDrivers(taskId) - state.finishDate = Some(new Date()) - if (finishedDrivers.size >= retainedDrivers) { - val toRemove = math.max(retainedDrivers / 10, 1) - finishedDrivers.trimStart(toRemove) - } - finishedDrivers += state + retireDriver(subId, state) } state.mesosTaskStatus = Option(status) } else { - logError(s"Unable to find driver $taskId in status update") + logError(s"Unable to find driver with $taskId in status update") } } } + private def retireDriver( + submissionId: String, + state: MesosClusterSubmissionState) = { + removeFromLaunchedDrivers(submissionId) + state.finishDate = Some(new Date()) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + override def frameworkMessage( driver: SchedulerDriver, executorId: ExecutorID, @@ -769,31 +787,31 @@ private[spark] class MesosClusterScheduler( slaveId: SlaveID, status: Int): Unit = {} - private def removeFromQueuedDrivers(id: String): Boolean = { - val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + private def removeFromQueuedDrivers(subId: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(subId)) if (index != -1) { queuedDrivers.remove(index) - queuedDriversState.expunge(id) + queuedDriversState.expunge(subId) true } else { false } } - private def removeFromLaunchedDrivers(id: String): Boolean = { - if (launchedDrivers.remove(id).isDefined) { - launchedDriversState.expunge(id) + private def removeFromLaunchedDrivers(subId: String): Boolean = { + if (launchedDrivers.remove(subId).isDefined) { + launchedDriversState.expunge(subId) true } else { false } } - private def removeFromPendingRetryDrivers(id: String): Boolean = { - val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + private def removeFromPendingRetryDrivers(subId: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(subId)) if (index != -1) { pendingRetryDrivers.remove(index) - pendingRetryDriversState.expunge(id) + pendingRetryDriversState.expunge(subId) true } else { false @@ -810,8 +828,8 @@ private[spark] class MesosClusterScheduler( revive() } - private def addDriverToPending(desc: MesosDriverDescription, taskId: String) = { - pendingRetryDriversState.persist(taskId, desc) + private def addDriverToPending(desc: MesosDriverDescription, subId: String) = { + pendingRetryDriversState.persist(subId, desc) pendingRetryDrivers += desc revive() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b8c780db07c98..40a0b8e3a407d 100644 
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -58,7 +58,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.credentials.file", "spark.yarn.credentials.renewalTime", "spark.yarn.credentials.updateTime", - "spark.ui.filters") + "spark.ui.filters", + "spark.mesos.driver.frameworkId") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") From 41b60125b673bad0c133cd5c825d353ac2e6dfd6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 2 Nov 2017 15:22:52 +0100 Subject: [PATCH 166/712] [SPARK-22369][PYTHON][DOCS] Exposes catalog API documentation in PySpark ## What changes were proposed in this pull request? This PR proposes to add a link from `spark.catalog(..)` to `Catalog` and expose Catalog APIs in PySpark as below: 2017-10-29 12 25 46 2017-10-29 12 26 33 Note that this is not shown in the list on the top - https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#module-pyspark.sql 2017-10-29 12 30 58 This is basically similar with `DataFrameReader` and `DataFrameWriter`. ## How was this patch tested? Manually built the doc. Author: hyukjinkwon Closes #19596 from HyukjinKwon/SPARK-22369. --- python/pyspark/sql/__init__.py | 3 ++- python/pyspark/sql/session.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 22ec416f6c584..c3c06c8124362 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -46,6 +46,7 @@ from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration from pyspark.sql.session import SparkSession from pyspark.sql.column import Column +from pyspark.sql.catalog import Catalog from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions from pyspark.sql.group import GroupedData from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter @@ -54,7 +55,7 @@ __all__ = [ 'SparkSession', 'SQLContext', 'HiveContext', 'UDFRegistration', - 'DataFrame', 'GroupedData', 'Column', 'Row', + 'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 2cc0e2d1d7b8d..c3dc1a46fd3c1 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -271,6 +271,8 @@ def conf(self): def catalog(self): """Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc. + + :return: :class:`Catalog` """ if not hasattr(self, "_catalog"): self._catalog = Catalog(self) From e3f67a97f126abfb7eeb864f657bfc9221bb195e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 2 Nov 2017 18:28:56 +0100 Subject: [PATCH 167/712] [SPARK-22416][SQL] Move OrcOptions from `sql/hive` to `sql/core` ## What changes were proposed in this pull request? According to the [discussion](https://github.com/apache/spark/pull/19571#issuecomment-339472976) on SPARK-15474, we will add new OrcFileFormat in `sql/core` module and allow users to use both old and new OrcFileFormat. To do that, `OrcOptions` should be visible in `sql/core` module, too. Previously, it was `private[orc]` in `sql/hive`. 
This PR removes `private[orc]` because we don't use `private[sql]` in `sql/execution` package after [SPARK-16964](https://github.com/apache/spark/pull/14554). ## How was this patch tested? Pass the Jenkins with the existing tests. Author: Dongjoon Hyun Closes #19636 from dongjoon-hyun/SPARK-22416. --- .../spark/sql/execution/datasources}/orc/OrcOptions.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala | 1 + .../org/apache/spark/sql/hive/orc/OrcSourceSuite.scala | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) rename sql/{hive/src/main/scala/org/apache/spark/sql/hive => core/src/main/scala/org/apache/spark/sql/execution/datasources}/orc/OrcOptions.scala (95%) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala similarity index 95% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index 6ce90c07b4921..c866dd834a525 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.util.Locale @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the ORC data source. */ -private[orc] class OrcOptions( +class OrcOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -59,7 +59,7 @@ private[orc] class OrcOptions( } } -private[orc] object OrcOptions { +object OrcOptions { // The ORC compression short names private val shortOrcCompressionCodecNames = Map( "none" -> "NONE", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index d26ec15410d95..3b33a9ff082f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index ef9e67c743837..2a086be57f517 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -24,6 +24,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ From 882079f5c6de40913a27873f2fa3306e5d827393 Mon Sep 17 00:00:00 2001 From: ZouChenjun 
Date: Thu, 2 Nov 2017 11:06:37 -0700 Subject: [PATCH 168/712] [SPARK-22243][DSTREAM] spark.yarn.jars should reload from config when checkpoint recovery ## What changes were proposed in this pull request? The previous [PR](https://github.com/apache/spark/pull/19469) was deleted by mistake. The solution is straightforward: add "spark.yarn.jars" to propertiesToReload so that this property is reloaded from the config. ## How was this patch tested? manual tests Author: ZouChenjun Closes #19637 from ChenjunZou/checkpoint-yarn-jars. --- .../src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 40a0b8e3a407d..9ebb91b8cab3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,6 +53,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.port", "spark.master", + "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", "spark.yarn.credentials.file", From 2fd12af4372a1e2c3faf0eb5d0a1cf530abc0016 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Nov 2017 23:41:16 +0100 Subject: [PATCH 169/712] [SPARK-22306][SQL] alter table schema should not erase the bucketing metadata at hive side Forward-port https://github.com/apache/spark/pull/19622 to the master branch. This bug doesn't exist in master because we've added hive bucketing support and the hive bucketing metadata can be recognized by Spark, but we should still port it to master: 1) there may be other unsupported hive metadata removed by Spark. 2) reduce the code difference between master and 2.2 to ease backports in the future. *** When we alter a table schema, we set the new schema on the Spark `CatalogTable`, convert it to a hive table, and finally call `hive.alterTable`. This causes a problem in Spark 2.2, because hive bucketing metadata is not recognized by Spark, which means a Spark `CatalogTable` representing a hive table is always non-bucketed, and when we convert it to a hive table and call `hive.alterTable`, the original hive bucketing metadata will be removed. To fix this bug, we should read out the raw hive table metadata, update its schema, and call `hive.alterTable`. By doing this we can guarantee only the schema is changed, and nothing else. Author: Wenchen Fan Closes #19644 from cloud-fan/infer. 
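As a rough usage sketch of the new `alterTableDataSchema` API (this mirrors the HiveExternalCatalogSuite test added in this patch; `externalCatalog` stands in for a `HiveExternalCatalog` instance and is an assumption of the sketch, not code from the patch):

```scala
import org.apache.spark.sql.types.StructType

// The table was created at the Hive side with bucketing metadata, e.g.
//   CREATE TABLE db1.t(a string, b string)
//   CLUSTERED BY (a, b) SORTED BY (a, b) INTO 10 BUCKETS STORED AS PARQUET

// Only the data columns are passed in; partition columns and Hive-only
// metadata such as bucketing are preserved, because the raw Hive table is
// updated in place instead of being round-tripped through CatalogTable.
val newDataSchema = new StructType()
  .add("a", "string")
  .add("b", "string")
  .add("c", "string")  // newly added column

externalCatalog.alterTableDataSchema("db1", "t", newDataSchema)
```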
--- .../catalyst/catalog/ExternalCatalog.scala | 12 ++--- .../catalyst/catalog/InMemoryCatalog.scala | 7 +-- .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++----- .../catalog/ExternalCatalogSuite.scala | 10 ++-- .../catalog/SessionCatalogSuite.scala | 8 ++-- .../spark/sql/execution/command/ddl.scala | 14 ++++-- .../spark/sql/execution/command/tables.scala | 14 ++---- .../datasources/DataSourceStrategy.scala | 4 +- .../datasources/orc/OrcFileFormat.scala | 5 +- .../parquet/ParquetFileFormat.scala | 5 +- .../parquet/ParquetSchemaConverter.scala | 5 +- .../spark/sql/hive/HiveExternalCatalog.scala | 47 ++++++++++++------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 11 ++--- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../spark/sql/hive/client/HiveClient.scala | 11 +++++ .../sql/hive/client/HiveClientImpl.scala | 45 +++++++++++------- .../sql/hive/HiveExternalCatalogSuite.scala | 18 +++++++ .../sql/hive/MetastoreDataSourcesSuite.scala | 4 +- .../hive/execution/Hive_2_1_DDLSuite.scala | 2 +- 19 files changed, 146 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index d4c58db3708e3..223094d485936 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -150,17 +150,15 @@ abstract class ExternalCatalog def alterTable(tableDefinition: CatalogTable): Unit /** - * Alter the schema of a table identified by the provided database and table name. The new schema - * should still contain the existing bucket columns and partition columns used by the table. This - * method will also update any Spark SQL-related parameters stored as Hive table properties (such - * as the schema itself). + * Alter the data schema of a table identified by the provided database and table name. The new + * data schema should not have conflict column names with the existing partition columns, and + * should still contain all the existing data columns. * * @param db Database that table to alter schema for exists in * @param table Name of table to alter schema for - * @param schema Updated schema to be used for the table (must contain existing partition and - * bucket columns) + * @param newDataSchema Updated data schema to be used for the table. */ - def alterTableSchema(db: String, table: String, schema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. 
*/ def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 98370c12a977c..9504140d51e99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -303,13 +303,14 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def alterTableSchema( + override def alterTableDataSchema( db: String, table: String, - schema: StructType): Unit = synchronized { + newDataSchema: StructType): Unit = synchronized { requireTableExists(db, table) val origTable = catalog(db).tables(table).table - catalog(db).tables(table).table = origTable.copy(schema = schema) + val newSchema = StructType(newDataSchema ++ origTable.partitionSchema) + catalog(db).tables(table).table = origTable.copy(schema = newSchema) } override def alterTableStats( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 95bc3d674b4f8..a129896230775 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -324,18 +324,16 @@ class SessionCatalog( } /** - * Alter the schema of a table identified by the provided table identifier. The new schema - * should still contain the existing bucket columns and partition columns used by the table. This - * method will also update any Spark SQL-related parameters stored as Hive table properties (such - * as the schema itself). + * Alter the data schema of a table identified by the provided table identifier. The new data + * schema should not have conflict column names with the existing partition columns, and should + * still contain all the existing data columns. 
* * @param identifier TableIdentifier - * @param newSchema Updated schema to be used for the table (must contain existing partition and - * bucket columns, and partition columns need to be at the end) + * @param newDataSchema Updated data schema to be used for the table */ - def alterTableSchema( + def alterTableDataSchema( identifier: TableIdentifier, - newSchema: StructType): Unit = { + newDataSchema: StructType): Unit = { val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) @@ -343,10 +341,10 @@ class SessionCatalog( requireTableExists(tableIdentifier) val catalogTable = externalCatalog.getTable(db, table) - val oldSchema = catalogTable.schema - + val oldDataSchema = catalogTable.dataSchema // not supporting dropping columns yet - val nonExistentColumnNames = oldSchema.map(_.name).filterNot(columnNameResolved(newSchema, _)) + val nonExistentColumnNames = + oldDataSchema.map(_.name).filterNot(columnNameResolved(newDataSchema, _)) if (nonExistentColumnNames.nonEmpty) { throw new AnalysisException( s""" @@ -355,8 +353,7 @@ class SessionCatalog( """.stripMargin) } - // assuming the newSchema has all partition columns at the end as required - externalCatalog.alterTableSchema(db, table, newSchema) + externalCatalog.alterTableDataSchema(db, table, newDataSchema) } private def columnNameResolved(schema: StructType, colName: String): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 94593ef7efa50..b376108399c1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -245,14 +245,12 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter table schema") { val catalog = newBasicCatalog() - val newSchema = StructType(Seq( + val newDataSchema = StructType(Seq( StructField("col1", IntegerType), - StructField("new_field_2", StringType), - StructField("a", IntegerType), - StructField("b", StringType))) - catalog.alterTableSchema("db2", "tbl1", newSchema) + StructField("new_field_2", StringType))) + catalog.alterTableDataSchema("db2", "tbl1", newDataSchema) val newTbl1 = catalog.getTable("db2", "tbl1") - assert(newTbl1.schema == newSchema) + assert(newTbl1.dataSchema == newDataSchema) } test("alter table stats") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 1cce19989c60e..95c87ffa20cb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -463,9 +463,9 @@ abstract class SessionCatalogSuite extends AnalysisTest { withBasicCatalog { sessionCatalog => sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") - sessionCatalog.alterTableSchema( + sessionCatalog.alterTableDataSchema( TableIdentifier("t1", Some("default")), - StructType(oldTab.dataSchema.add("c3", IntegerType) ++ oldTab.partitionSchema)) + 
StructType(oldTab.dataSchema.add("c3", IntegerType))) val newTab = sessionCatalog.externalCatalog.getTable("default", "t1") // construct the expected table schema @@ -480,8 +480,8 @@ abstract class SessionCatalogSuite extends AnalysisTest { sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") val e = intercept[AnalysisException] { - sessionCatalog.alterTableSchema( - TableIdentifier("t1", Some("default")), StructType(oldTab.schema.drop(1))) + sessionCatalog.alterTableDataSchema( + TableIdentifier("t1", Some("default")), StructType(oldTab.dataSchema.drop(1))) }.getMessage assert(e.contains("We don't support dropping columns yet.")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 162e1d5be2938..a9cd65e3242c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -857,19 +857,23 @@ object DDLUtils { } } - private[sql] def checkDataSchemaFieldNames(table: CatalogTable): Unit = { + private[sql] def checkDataColNames(table: CatalogTable): Unit = { + checkDataColNames(table, table.dataSchema.fieldNames) + } + + private[sql] def checkDataColNames(table: CatalogTable, colNames: Seq[String]): Unit = { table.provider.foreach { _.toLowerCase(Locale.ROOT) match { case HIVE_PROVIDER => val serde = table.storage.serde if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) { - OrcFileFormat.checkFieldNames(table.dataSchema) + OrcFileFormat.checkFieldNames(colNames) } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde || serde == Some("parquet.hive.serde.ParquetHiveSerDe")) { - ParquetSchemaConverter.checkFieldNames(table.dataSchema) + ParquetSchemaConverter.checkFieldNames(colNames) } - case "parquet" => ParquetSchemaConverter.checkFieldNames(table.dataSchema) - case "orc" => OrcFileFormat.checkFieldNames(table.dataSchema) + case "parquet" => ParquetSchemaConverter.checkFieldNames(colNames) + case "orc" => OrcFileFormat.checkFieldNames(colNames) case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 38f91639c0422..95f16b0f4baea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -186,7 +186,7 @@ case class AlterTableRenameCommand( */ case class AlterTableAddColumnsCommand( table: TableIdentifier, - columns: Seq[StructField]) extends RunnableCommand { + colsToAdd: Seq[StructField]) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val catalogTable = verifyAlterTableAddColumn(catalog, table) @@ -199,17 +199,13 @@ case class AlterTableAddColumnsCommand( } catalog.refreshTable(table) - // make sure any partition columns are at the end of the fields - val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema - val newSchema = catalogTable.schema.copy(fields = reorderedSchema.toArray) - SchemaUtils.checkColumnNameDuplication( - reorderedSchema.map(_.name), "in the table definition of " + table.identifier, + (colsToAdd ++ catalogTable.schema).map(_.name), + "in the table definition of " + table.identifier, 
conf.caseSensitiveAnalysis) - DDLUtils.checkDataSchemaFieldNames(catalogTable.copy(schema = newSchema)) - - catalog.alterTableSchema(table, newSchema) + DDLUtils.checkDataColNames(catalogTable, colsToAdd.map(_.name)) + catalog.alterTableDataSchema(table, StructType(catalogTable.dataSchema ++ colsToAdd)) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 018f24e290b4b..04d6f3f56eb02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -133,12 +133,12 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc) + DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc.copy(schema = query.schema)) + DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 2eeb0065455f3..215740e90fe84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -35,8 +35,7 @@ private[sql] object OrcFileFormat { } } - def checkFieldNames(schema: StructType): StructType = { - schema.fieldNames.foreach(checkFieldName) - schema + def checkFieldNames(names: Seq[String]): Unit = { + names.foreach(checkFieldName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index c1535babbae1f..61bd65dd48144 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -23,7 +23,6 @@ import java.net.URI import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.parallel.ForkJoinTaskSupport -import scala.concurrent.forkjoin.ForkJoinPool import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration @@ -306,10 +305,10 @@ class ParquetFileFormat hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetSchemaConverter.checkFieldNames(requiredSchema).json) + requiredSchema.json) hadoopConf.set( ParquetWriteSupport.SPARK_ROW_SCHEMA, - ParquetSchemaConverter.checkFieldNames(requiredSchema).json) + requiredSchema.json) ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index b3781cfc4a607..cd384d17d0cda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -571,9 +571,8 @@ private[sql] object ParquetSchemaConverter { """.stripMargin.split("\n").mkString(" ").trim) } - def checkFieldNames(schema: StructType): StructType = { - schema.fieldNames.foreach(checkFieldName) - schema + def checkFieldNames(names: Seq[String]): Unit = { + names.foreach(checkFieldName) } def checkConversionRequirement(f: => Boolean, message: String): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 96dc983b0bfc6..f8a947bf527e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -138,16 +138,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Checks the validity of column names. Hive metastore disallows the table to use comma in + * Checks the validity of data column names. Hive metastore disallows the table to use comma in * data column names. Partition columns do not have such a restriction. Views do not have such * a restriction. */ - private def verifyColumnNames(table: CatalogTable): Unit = { - if (table.tableType != VIEW) { - table.dataSchema.map(_.name).foreach { colName => + private def verifyDataSchema( + tableName: TableIdentifier, tableType: CatalogTableType, dataSchema: StructType): Unit = { + if (tableType != VIEW) { + dataSchema.map(_.name).foreach { colName => if (colName.contains(",")) { throw new AnalysisException("Cannot create a table having a column whose name contains " + - s"commas in Hive metastore. Table: ${table.identifier}; Column: $colName") + s"commas in Hive metastore. Table: $tableName; Column: $colName") } } } @@ -218,7 +219,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val table = tableDefinition.identifier.table requireDbExists(db) verifyTableProperties(tableDefinition) - verifyColumnNames(tableDefinition) + verifyDataSchema( + tableDefinition.identifier, tableDefinition.tableType, tableDefinition.dataSchema) if (tableExists(db, table) && !ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) @@ -296,7 +298,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat storage = table.storage.copy( locationUri = None, properties = storagePropsWithLocation), - schema = table.partitionSchema, + schema = StructType(EMPTY_DATA_SCHEMA ++ table.partitionSchema), bucketSpec = None, properties = table.properties ++ tableProperties) } @@ -617,32 +619,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient { + override def alterTableDataSchema( + db: String, table: String, newDataSchema: StructType): Unit = withClient { requireTableExists(db, table) - val rawTable = getRawTable(db, table) - // Add table metadata such as table schema, partition columns, etc. 
to table properties. - val updatedProperties = rawTable.properties ++ tableMetaToTableProps(rawTable, schema) - val withNewSchema = rawTable.copy(properties = updatedProperties, schema = schema) - verifyColumnNames(withNewSchema) + val oldTable = getTable(db, table) + verifyDataSchema(oldTable.identifier, oldTable.tableType, newDataSchema) + val schemaProps = + tableMetaToTableProps(oldTable, StructType(newDataSchema ++ oldTable.partitionSchema)).toMap - if (isDatasourceTable(rawTable)) { + if (isDatasourceTable(oldTable)) { // For data source tables, first try to write it with the schema set; if that does not work, // try again with updated properties and the partition schema. This is a simplified version of // what createDataSourceTable() does, and may leave the table in a state unreadable by Hive // (for example, the schema does not match the data source schema, or does not match the // storage descriptor). try { - client.alterTable(withNewSchema) + client.alterTableDataSchema(db, table, newDataSchema, schemaProps) } catch { case NonFatal(e) => val warningMessage = - s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + + s"Could not alter schema of table ${oldTable.identifier.quotedString} in a Hive " + "compatible way. Updating Hive metastore in Spark SQL specific format." logWarning(warningMessage, e) - client.alterTable(withNewSchema.copy(schema = rawTable.partitionSchema)) + client.alterTableDataSchema(db, table, EMPTY_DATA_SCHEMA, schemaProps) } } else { - client.alterTable(withNewSchema) + client.alterTableDataSchema(db, table, newDataSchema, schemaProps) } } @@ -1297,6 +1299,15 @@ object HiveExternalCatalog { val CREATED_SPARK_VERSION = SPARK_SQL_PREFIX + "create.version" + // When storing data source tables in hive metastore, we need to set data schema to empty if the + // schema is hive-incompatible. However we need a hack to preserve existing behavior. Before + // Spark 2.0, we do not set a default serde here (this was done in Hive), and so if the user + // provides an empty schema Hive would automatically populate the schema with a single field + // "col". However, after SPARK-14388, we set the default serde to LazySimpleSerde so this + // implicit behavior no longer happens. Therefore, we need to do it in Spark ourselves. + val EMPTY_DATA_SCHEMA = new StructType() + .add("col", "array", nullable = true, comment = "from deserializer") + /** * Returns the fully qualified name used in table properties for a particular column stat. 
* For example, for column "mycol", and "min" stat, this should return diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5ac65973e70e1..8adfda07d29d5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -244,11 +244,11 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log inferredSchema match { case Some(dataSchema) => - val schema = StructType(dataSchema ++ relation.tableMeta.partitionSchema) if (inferenceMode == INFER_AND_SAVE) { - updateCatalogSchema(relation.tableMeta.identifier, schema) + updateDataSchema(relation.tableMeta.identifier, dataSchema) } - relation.tableMeta.copy(schema = schema) + val newSchema = StructType(dataSchema ++ relation.tableMeta.partitionSchema) + relation.tableMeta.copy(schema = newSchema) case None => logWarning(s"Unable to infer schema for table $tableName from file format " + s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.") @@ -259,10 +259,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - private def updateCatalogSchema(identifier: TableIdentifier, schema: StructType): Unit = try { - val db = identifier.database.get + private def updateDataSchema(identifier: TableIdentifier, newDataSchema: StructType): Unit = try { logInfo(s"Saving case-sensitive schema for table ${identifier.unquotedString}") - sparkSession.sharedState.externalCatalog.alterTableSchema(db, identifier.table, schema) + sparkSession.sessionState.catalog.alterTableDataSchema(identifier, newDataSchema) } catch { case NonFatal(ex) => logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 3592b8f4846d1..ee1f6ee173063 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -152,11 +152,11 @@ object HiveAnalysis extends Rule[LogicalPlan] { InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc) + DDLUtils.checkDataColNames(tableDesc) CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => - DDLUtils.checkDataSchemaFieldNames(tableDesc) + DDLUtils.checkDataColNames(tableDesc) CreateHiveTableAsSelectCommand(tableDesc, query, mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index ee3eb2ee8abe5..f69717441d615 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Expression +import 
org.apache.spark.sql.types.StructType /** @@ -100,6 +101,16 @@ private[hive] trait HiveClient { */ def alterTable(dbName: String, tableName: String, table: CatalogTable): Unit + /** + * Updates the given table with a new data schema and table properties, and keep everything else + * unchanged. + * + * TODO(cloud-fan): it's a little hacky to introduce the schema table properties here in + * `HiveClient`, but we don't have a cleaner solution now. + */ + def alterTableDataSchema( + dbName: String, tableName: String, newDataSchema: StructType, schemaProps: Map[String, String]) + /** Creates a new database with the given name. */ def createDatabase(database: CatalogDatabase, ignoreIfExists: Boolean): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 16c95c53b4201..b5a5890d47b03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -39,7 +39,6 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.AnalysisException @@ -51,7 +50,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveExternalCatalog.{DATASOURCE_SCHEMA, DATASOURCE_SCHEMA_NUMPARTS, DATASOURCE_SCHEMA_PART_PREFIX} import org.apache.spark.sql.hive.client.HiveClientImpl._ import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, Utils} @@ -515,6 +514,33 @@ private[hive] class HiveClientImpl( shim.alterTable(client, qualifiedTableName, hiveTable) } + override def alterTableDataSchema( + dbName: String, + tableName: String, + newDataSchema: StructType, + schemaProps: Map[String, String]): Unit = withHiveState { + val oldTable = client.getTable(dbName, tableName) + val hiveCols = newDataSchema.map(toHiveColumn) + oldTable.setFields(hiveCols.asJava) + + // remove old schema table properties + val it = oldTable.getParameters.entrySet.iterator + while (it.hasNext) { + val entry = it.next() + val isSchemaProp = entry.getKey.startsWith(DATASOURCE_SCHEMA_PART_PREFIX) || + entry.getKey == DATASOURCE_SCHEMA || entry.getKey == DATASOURCE_SCHEMA_NUMPARTS + if (isSchemaProp) { + it.remove() + } + } + + // set new schema table properties + schemaProps.foreach { case (k, v) => oldTable.setProperty(k, v) } + + val qualifiedTableName = s"$dbName.$tableName" + shim.alterTable(client, qualifiedTableName, oldTable) + } + override def createPartitions( db: String, table: String, @@ -896,20 +922,7 @@ private[hive] object HiveClientImpl { val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => table.partitionColumnNames.contains(c.getName) } - // after SPARK-19279, it is not allowed to create a hive table with an empty schema, - // so here we should not add a default col schema - if (schema.isEmpty && HiveExternalCatalog.isDatasourceTable(table)) { - // This is a hack to preserve existing behavior. 
Before Spark 2.0, we do not - // set a default serde here (this was done in Hive), and so if the user provides - // an empty schema Hive would automatically populate the schema with a single - // field "col". However, after SPARK-14388, we set the default serde to - // LazySimpleSerde so this implicit behavior no longer happens. Therefore, - // we need to do it in Spark ourselves. - hiveTable.setFields( - Seq(new FieldSchema("col", "array", "from deserializer")).asJava) - } else { - hiveTable.setFields(schema.asJava) - } + hiveTable.setFields(schema.asJava) hiveTable.setPartCols(partCols.asJava) userName.foreach(hiveTable.setOwner) hiveTable.setCreateTime((table.createTime / 1000).toInt) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index d43534d5914d1..2e35fdeba464d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -89,4 +89,22 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { assert(restoredTable.schema == newSchema) } } + + test("SPARK-22306: alter table schema should not erase the bucketing metadata at hive side") { + val catalog = newBasicCatalog() + externalCatalog.client.runSqlHive( + """ + |CREATE TABLE db1.t(a string, b string) + |CLUSTERED BY (a, b) SORTED BY (a, b) INTO 10 BUCKETS + |STORED AS PARQUET + """.stripMargin) + + val newSchema = new StructType().add("a", "string").add("b", "string").add("c", "string") + catalog.alterTableDataSchema("db1", "t", newSchema) + + assert(catalog.getTable("db1", "t").schema == newSchema) + val bucketString = externalCatalog.client.runSqlHive("DESC FORMATTED db1.t") + .filter(_.contains("Num Buckets")).head + assert(bucketString.contains("10")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f5d41c91270a5..a1060476f2211 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -741,7 +741,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val hiveTable = CatalogTable( identifier = TableIdentifier(tableName, Some("default")), tableType = CatalogTableType.MANAGED, - schema = new StructType, + schema = HiveExternalCatalog.EMPTY_DATA_SCHEMA, provider = Some("json"), storage = CatalogStorageFormat( locationUri = None, @@ -1266,7 +1266,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val hiveTable = CatalogTable( identifier = TableIdentifier("t", Some("default")), tableType = CatalogTableType.MANAGED, - schema = new StructType, + schema = HiveExternalCatalog.EMPTY_DATA_SCHEMA, provider = Some("json"), storage = CatalogStorageFormat.empty, properties = Map( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index 5c248b9acd04f..bc828877e35ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -117,7 +117,7 @@ class Hive_2_1_DDLSuite extends SparkFunSuite with 
TestHiveSingleton with Before spark.sql(createTableStmt) val oldTable = spark.sessionState.catalog.externalCatalog.getTable("default", tableName) catalog.createTable(oldTable, true) - catalog.alterTableSchema("default", tableName, updatedSchema) + catalog.alterTableDataSchema("default", tableName, updatedSchema) val updatedTable = catalog.getTable("default", tableName) assert(updatedTable.schema.fieldNames === updatedSchema.fieldNames) From 51145f13768ec04c14762397527b1bc2648e2374 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Fri, 3 Nov 2017 12:20:17 +0000 Subject: [PATCH 170/712] [SPARK-22407][WEB-UI] Add rdd id column on storage page to speed up navigating ## What changes were proposed in this pull request? Add rdd id column on storage page to speed up navigating. Example has attached on [SPARK-22407](https://issues.apache.org/jira/browse/SPARK-22407) An example below: ![add-rddid](https://user-images.githubusercontent.com/26762018/32361127-da0758ac-c097-11e7-9f8c-0ea7ffb87e12.png) ![rdd-cache](https://user-images.githubusercontent.com/26762018/32361128-da3c1574-c097-11e7-8ab1-2def66466f33.png) ## How was this patch tested? Current unit test and manually deploy an history server for testing Author: zhoukang Closes #19625 from caneGuy/zhoukang/add-rddid. --- .../scala/org/apache/spark/ui/storage/StoragePage.scala | 2 ++ .../org/apache/spark/ui/storage/StoragePageSuite.scala | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index aa84788f1df88..b6c764d1728e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -49,6 +49,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { /** Header fields for the RDD table */ private val rddHeader = Seq( + "ID", "RDD Name", "Storage Level", "Cached Partitions", @@ -60,6 +61,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private def rddRow(rdd: RDDInfo): Seq[Node] = { // scalastyle:off + {rdd.id} {rdd.name} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 350c174e24742..4a48b3c686725 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -57,6 +57,7 @@ class StoragePageSuite extends SparkFunSuite { val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) val headers = Seq( + "ID", "RDD Name", "Storage Level", "Cached Partitions", @@ -67,19 +68,19 @@ class StoragePageSuite extends SparkFunSuite { assert((xmlNodes \\ "tr").size === 3) assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) + Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) + Seq("2", "rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) 
=== Some("http://localhost:4040/storage/rdd?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === - Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) + Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=3")) From 89158866085ee3aa18759efaa7d3b3846b9c6504 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 3 Nov 2017 22:03:58 -0700 Subject: [PATCH 171/712] [SPARK-22418][SQL][TEST] Add test cases for NULL Handling ## What changes were proposed in this pull request? Added a test class to check NULL handling behavior. The expected behavior is defined as the one of the most well-known databases as specified here: https://sqlite.org/nulls.html. SparkSQL behaves like other DBs: - Adding anything to null gives null -> YES - Multiplying null by zero gives null -> YES - nulls are distinct in SELECT DISTINCT -> NO - nulls are distinct in a UNION -> NO - "CASE WHEN null THEN 1 ELSE 0 END" is 0? -> YES - "null OR true" is true -> YES - "not (null AND false)" is true -> YES - null in aggregation are skipped -> YES ## How was this patch tested? Added test class Author: Marco Gaido Closes #19653 from mgaido91/SPARK-22418. --- .../sql-tests/inputs/null-handling.sql | 48 +++ .../sql-tests/results/null-handling.sql.out | 305 ++++++++++++++++++ 2 files changed, 353 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/null-handling.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/null-handling.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql new file mode 100644 index 0000000000000..b90b0a6ac7500 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql @@ -0,0 +1,48 @@ +-- Create a test table with data +create table t1(a int, b int, c int) using parquet; +insert into t1 values(1,0,0); +insert into t1 values(2,0,1); +insert into t1 values(3,1,0); +insert into t1 values(4,1,1); +insert into t1 values(5,null,0); +insert into t1 values(6,null,1); +insert into t1 values(7,null,null); + +-- Adding anything to null gives null +select a, b+c from t1; + +-- Multiplying null by zero gives null +select a+10, b*0 from t1; + +-- nulls are NOT distinct in SELECT DISTINCT +select distinct b from t1; + +-- nulls are NOT distinct in UNION +select b from t1 union select b from t1; + +-- CASE WHEN null THEN 1 ELSE 0 END is 0 +select a+20, case b when c then 1 else 0 end from t1; +select a+30, case c when b then 1 else 0 end from t1; +select a+40, case when b<>0 then 1 else 0 end from t1; +select a+50, case when not b<>0 then 1 else 0 end from t1; +select a+60, case when b<>0 and c<>0 then 1 else 0 end from t1; + +-- "not (null AND false)" is true +select a+70, case when not (b<>0 and c<>0) then 1 else 0 end from t1; + +-- "null OR true" is true +select a+80, case when b<>0 or c<>0 then 1 else 0 end from t1; +select a+90, case when not (b<>0 or c<>0) then 1 else 0 end from t1; + +-- null with aggregate operators +select count(*), count(b), sum(b), avg(b), min(b), max(b) from t1; + +-- Check the behavior of NULLs in WHERE clauses +select a+100 from t1 where b<10; +select a+110 from t1 where not b>10; +select a+120 from t1 where b<10 OR c=1; +select a+130 from t1 where b<10 AND c=1; +select a+140 from t1 where not (b<10 AND c=1); 
+select a+150 from t1 where not (c=1 AND b<10); + +drop table t1; diff --git a/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out new file mode 100644 index 0000000000000..5005dfeb6cd14 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out @@ -0,0 +1,305 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query 0 +create table t1(a int, b int, c int) using parquet +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +insert into t1 values(1,0,0) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +insert into t1 values(2,0,1) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +insert into t1 values(3,1,0) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +insert into t1 values(4,1,1) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +insert into t1 values(5,null,0) +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +insert into t1 values(6,null,1) +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +insert into t1 values(7,null,null) +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +select a, b+c from t1 +-- !query 8 schema +struct +-- !query 8 output +1 0 +2 1 +3 1 +4 2 +5 NULL +6 NULL +7 NULL + + +-- !query 9 +select a+10, b*0 from t1 +-- !query 9 schema +struct<(a + 10):int,(b * 0):int> +-- !query 9 output +11 0 +12 0 +13 0 +14 0 +15 NULL +16 NULL +17 NULL + + +-- !query 10 +select distinct b from t1 +-- !query 10 schema +struct +-- !query 10 output +0 +1 +NULL + + +-- !query 11 +select b from t1 union select b from t1 +-- !query 11 schema +struct +-- !query 11 output +0 +1 +NULL + + +-- !query 12 +select a+20, case b when c then 1 else 0 end from t1 +-- !query 12 schema +struct<(a + 20):int,CASE WHEN (b = c) THEN 1 ELSE 0 END:int> +-- !query 12 output +21 1 +22 0 +23 0 +24 1 +25 0 +26 0 +27 0 + + +-- !query 13 +select a+30, case c when b then 1 else 0 end from t1 +-- !query 13 schema +struct<(a + 30):int,CASE WHEN (c = b) THEN 1 ELSE 0 END:int> +-- !query 13 output +31 1 +32 0 +33 0 +34 1 +35 0 +36 0 +37 0 + + +-- !query 14 +select a+40, case when b<>0 then 1 else 0 end from t1 +-- !query 14 schema +struct<(a + 40):int,CASE WHEN (NOT (b = 0)) THEN 1 ELSE 0 END:int> +-- !query 14 output +41 0 +42 0 +43 1 +44 1 +45 0 +46 0 +47 0 + + +-- !query 15 +select a+50, case when not b<>0 then 1 else 0 end from t1 +-- !query 15 schema +struct<(a + 50):int,CASE WHEN (NOT (NOT (b = 0))) THEN 1 ELSE 0 END:int> +-- !query 15 output +51 1 +52 1 +53 0 +54 0 +55 0 +56 0 +57 0 + + +-- !query 16 +select a+60, case when b<>0 and c<>0 then 1 else 0 end from t1 +-- !query 16 schema +struct<(a + 60):int,CASE WHEN ((NOT (b = 0)) AND (NOT (c = 0))) THEN 1 ELSE 0 END:int> +-- !query 16 output +61 0 +62 0 +63 0 +64 1 +65 0 +66 0 +67 0 + + +-- !query 17 +select a+70, case when not (b<>0 and c<>0) then 1 else 0 end from t1 +-- !query 17 schema +struct<(a + 70):int,CASE WHEN (NOT ((NOT (b = 0)) AND (NOT (c = 0)))) THEN 1 ELSE 0 END:int> +-- !query 17 output +71 1 +72 1 +73 1 +74 0 +75 1 +76 0 +77 0 + + +-- !query 18 +select a+80, case when b<>0 or c<>0 then 1 else 0 end from t1 +-- !query 18 schema +struct<(a + 80):int,CASE WHEN ((NOT (b = 0)) OR (NOT (c = 0))) THEN 1 ELSE 0 END:int> +-- !query 18 output +81 0 +82 1 +83 1 +84 1 +85 0 +86 1 +87 0 + + +-- !query 19 +select a+90, case when not (b<>0 or c<>0) then 1 else 0 end 
from t1 +-- !query 19 schema +struct<(a + 90):int,CASE WHEN (NOT ((NOT (b = 0)) OR (NOT (c = 0)))) THEN 1 ELSE 0 END:int> +-- !query 19 output +91 1 +92 0 +93 0 +94 0 +95 0 +96 0 +97 0 + + +-- !query 20 +select count(*), count(b), sum(b), avg(b), min(b), max(b) from t1 +-- !query 20 schema +struct +-- !query 20 output +7 4 2 0.5 0 1 + + +-- !query 21 +select a+100 from t1 where b<10 +-- !query 21 schema +struct<(a + 100):int> +-- !query 21 output +101 +102 +103 +104 + + +-- !query 22 +select a+110 from t1 where not b>10 +-- !query 22 schema +struct<(a + 110):int> +-- !query 22 output +111 +112 +113 +114 + + +-- !query 23 +select a+120 from t1 where b<10 OR c=1 +-- !query 23 schema +struct<(a + 120):int> +-- !query 23 output +121 +122 +123 +124 +126 + + +-- !query 24 +select a+130 from t1 where b<10 AND c=1 +-- !query 24 schema +struct<(a + 130):int> +-- !query 24 output +132 +134 + + +-- !query 25 +select a+140 from t1 where not (b<10 AND c=1) +-- !query 25 schema +struct<(a + 140):int> +-- !query 25 output +141 +143 +145 + + +-- !query 26 +select a+150 from t1 where not (c=1 AND b<10) +-- !query 26 schema +struct<(a + 150):int> +-- !query 26 output +151 +153 +155 + + +-- !query 27 +drop table t1 +-- !query 27 schema +struct<> +-- !query 27 output + From bc1e101039ae3700eab42e633571256440a42b9d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 3 Nov 2017 23:35:57 -0700 Subject: [PATCH 172/712] [SPARK-22254][CORE] Fix the arrayMax in BufferHolder ## What changes were proposed in this pull request? This PR replaces the old the maximum array size (`Int.MaxValue`) with the new one (`ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`). This PR also refactor the code to calculate the new array size to easily understand why we have to use `newSize - 2` for allocating a new array. ## How was this patch tested? Used the existing test Author: Kazuaki Ishizaki Closes #19650 from kiszk/SPARK-22254. --- .../spark/util/collection/CompactBuffer.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala index f5d2fa14e49cb..5d3693190cc1f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -19,6 +19,8 @@ package org.apache.spark.util.collection import scala.reflect.ClassTag +import org.apache.spark.unsafe.array.ByteArrayMethods + /** * An append-only buffer similar to ArrayBuffer, but more memory-efficient for small buffers. * ArrayBuffer always allocates an Object array to store the data, with 16 entries by default, @@ -126,16 +128,16 @@ private[spark] class CompactBuffer[T: ClassTag] extends Seq[T] with Serializable /** Increase our size to newSize and grow the backing array if needed. */ private def growToSize(newSize: Int): Unit = { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. 
- val arrayMax = Int.MaxValue - 8 - if (newSize < 0 || newSize - 2 > arrayMax) { + // since two fields are hold in element0 and element1, an array holds newSize - 2 elements + val newArraySize = newSize - 2 + val arrayMax = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + if (newSize < 0 || newArraySize > arrayMax) { throw new UnsupportedOperationException(s"Can't grow buffer past $arrayMax elements") } - val capacity = if (otherElements != null) otherElements.length + 2 else 2 - if (newSize > capacity) { + val capacity = if (otherElements != null) otherElements.length else 0 + if (newArraySize > capacity) { var newArrayLen = 8L - while (newSize - 2 > newArrayLen) { + while (newArraySize > newArrayLen) { newArrayLen *= 2 } if (newArrayLen > arrayMax) { From e7adb7d7a6d017bb95da566e76b39d9d96ab42c1 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 4 Nov 2017 16:59:58 +0900 Subject: [PATCH 173/712] [SPARK-22437][PYSPARK] default mode for jdbc is wrongly set to None ## What changes were proposed in this pull request? When writing using jdbc with python currently we are wrongly assigning by default None as writing mode. This is due to wrongly calling mode on the `_jwrite` object instead of `self` and it causes an exception. ## How was this patch tested? manual tests Author: Marco Gaido Closes #19654 from mgaido91/SPARK-22437. --- python/pyspark/sql/readwriter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f3092918abb54..3d87567ab673d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -915,7 +915,7 @@ def jdbc(self, url, table, mode=None, properties=None): jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) - self._jwrite.mode(mode).jdbc(url, table, jprop) + self.mode(mode)._jwrite.jdbc(url, table, jprop) def _test(): From 7a8412352e3aaf14527f97c82d0d62f9de39e753 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sat, 4 Nov 2017 11:51:10 +0000 Subject: [PATCH 174/712] [SPARK-22423][SQL] Scala test source files like TestHiveSingleton.scala should be in scala source root ## What changes were proposed in this pull request? Scala test source files like TestHiveSingleton.scala should be in scala source root ## How was this patch tested? Just move scala file from java directory to scala directory No new test case in this PR. ``` renamed: mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala -> mllib/src/test/scala/org/apache/spark/ml/util/IdentifiableSuite.scala renamed: streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala -> streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala renamed: streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala -> streaming/src/test/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala renamed: sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala ``` Author: xubo245 <601450868@qq.com> Closes #19639 from xubo245/scalaDirectory. 
--- .../org/apache/spark/ml/util/IdentifiableSuite.scala | 0 .../org/apache/spark/sql/hive/test/TestHiveSingleton.scala | 0 .../org/apache/spark/streaming/JavaTestUtils.scala | 0 .../streaming/api/java/JavaStreamingListenerWrapperSuite.scala | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename mllib/src/test/{java => scala}/org/apache/spark/ml/util/IdentifiableSuite.scala (100%) rename sql/hive/src/test/{java => scala}/org/apache/spark/sql/hive/test/TestHiveSingleton.scala (100%) rename streaming/src/test/{java => scala}/org/apache/spark/streaming/JavaTestUtils.scala (100%) rename streaming/src/test/{java => scala}/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala (100%) diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/IdentifiableSuite.scala similarity index 100% rename from mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/util/IdentifiableSuite.scala diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala similarity index 100% rename from streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala rename to streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala similarity index 100% rename from streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala From 0c2aee69b0efeea5ce8d39c0564e9e4511faf387 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 4 Nov 2017 13:11:09 +0100 Subject: [PATCH 175/712] [SPARK-22410][SQL] Remove unnecessary output from BatchEvalPython's children plans ## What changes were proposed in this pull request? When we insert `BatchEvalPython` for Python UDFs into a query plan, if its child has some outputs that are not used by the original parent node, `BatchEvalPython` will still take those outputs and save them into the queue. When the data for those outputs is large, this can easily generate a large spill to disk. For example, the following reproducible code is from the JIRA ticket. 
```python from pyspark.sql.functions import * from pyspark.sql.types import * lines_of_file = [ "this is a line" for x in xrange(10000) ] file_obj = [ "this_is_a_foldername/this_is_a_filename", lines_of_file ] data = [ file_obj for x in xrange(5) ] small_df = spark.sparkContext.parallelize(data).map(lambda x : (x[0], x[1])).toDF(["file", "lines"]) exploded = small_df.select("file", explode("lines")) def split_key(s): return s.split("/")[1] split_key_udf = udf(split_key, StringType()) with_filename = exploded.withColumn("filename", split_key_udf("file")) with_filename.explain(True) ``` The physical plan before/after this change: Before: ``` *Project [file#0, col#5, pythonUDF0#14 AS filename#9] +- BatchEvalPython [split_key(file#0)], [file#0, lines#1, col#5, pythonUDF0#14] +- Generate explode(lines#1), true, false, [col#5] +- Scan ExistingRDD[file#0,lines#1] ``` After: ``` *Project [file#0, col#5, pythonUDF0#14 AS filename#9] +- BatchEvalPython [split_key(file#0)], [col#5, file#0, pythonUDF0#14] +- *Project [col#5, file#0] +- Generate explode(lines#1), true, false, [col#5] +- Scan ExistingRDD[file#0,lines#1] ``` Before this change, `lines#1` is a redundant input to `BatchEvalPython`. This patch removes it by adding a Project. ## How was this patch tested? Manually test. Author: Liang-Chi Hsieh Closes #19642 from viirya/SPARK-22410. --- .../sql/execution/python/ExtractPythonUDFs.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index d6825369f7378..e15e760136e81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -127,8 +127,19 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { // If there aren't any, we are done. plan } else { + val inputsForPlan = plan.references ++ plan.outputSet + val prunedChildren = plan.children.map { child => + val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq + if (allNeededOutput.length != child.output.length) { + ProjectExec(allNeededOutput, child) + } else { + child + } + } + val planWithNewChildren = plan.withNewChildren(prunedChildren) + val attributeMap = mutable.HashMap[PythonUDF, Expression]() - val splitFilter = trySplitFilter(plan) + val splitFilter = trySplitFilter(planWithNewChildren) // Rewrite the child that has the input required for the UDF val newChildren = splitFilter.children.map { child => // Pick the UDF we are going to evaluate From f7f4e9c2db405b887832fcb592cd4522795d00ca Mon Sep 17 00:00:00 2001 From: Vinitha Gankidi Date: Sat, 4 Nov 2017 11:09:47 -0700 Subject: [PATCH 176/712] [SPARK-22412][SQL] Fix incorrect comment in DataSourceScanExec ## What changes were proposed in this pull request? Next fit decreasing bin packing algorithm is used to combine splits in DataSourceScanExec but the comment incorrectly states that first fit decreasing algorithm is used. The current implementation doesn't go back to a previously used bin other than the bin that the last element was put into. Author: Vinitha Gankidi Closes #19634 from vgankidi/SPARK-22412. 
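For illustration, here is a simplified, standalone sketch of next-fit-decreasing packing; the names and structure are illustrative only and are not the actual `FileSourceScanExec` code. Once the current partition is closed it is never revisited, which is what makes this next fit rather than first fit.

```scala
// Pack file split sizes into partitions of at most maxSplitBytes each,
// mimicking the shape of the packing loop described above.
def packSplits(splitSizes: Seq[Long], maxSplitBytes: Long): Seq[Seq[Long]] = {
  val partitions = scala.collection.mutable.ArrayBuffer.empty[Seq[Long]]
  var current = Vector.empty[Long]
  var currentSize = 0L

  // Sort decreasing, then always append to the *current* bin; a closed bin is
  // never reopened (next fit), even if a later small file would still fit into
  // it (which first fit would allow).
  for (size <- splitSizes.sortBy(-_)) {
    if (currentSize + size > maxSplitBytes && current.nonEmpty) {
      partitions += current
      current = Vector.empty
      currentSize = 0L
    }
    current = current :+ size
    currentSize += size
  }
  if (current.nonEmpty) partitions += current
  partitions
}

// packSplits(Seq(90L, 60L, 50L, 30L), maxSplitBytes = 100L)
//   => Seq(Seq(90), Seq(60), Seq(50, 30))
// First fit decreasing would instead have placed 30 back into the bin holding 60.
```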
--- .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index e9f65031143b7..a607ec0bf8c9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -469,7 +469,7 @@ case class FileSourceScanExec( currentSize = 0 } - // Assign files to partitions using "First Fit Decreasing" (FFD) + // Assign files to partitions using "Next Fit Decreasing" splitFiles.foreach { file => if (currentSize + file.length > maxSplitBytes) { closePartition() From 6c6626614e59b2e8d66ca853a74638d3d6267d73 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Sat, 4 Nov 2017 22:47:25 -0700 Subject: [PATCH 177/712] [SPARK-22211][SQL] Remove incorrect FOJ limit pushdown ## What changes were proposed in this pull request? It's not safe in all cases to push down a LIMIT below a FULL OUTER JOIN. If the limit is pushed to one side of the FOJ, the physical join operator can not tell if a row in the non-limited side would have a match in the other side. *If* the join operator guarantees that unmatched tuples from the limited side are emitted before any unmatched tuples from the other side, pushing down the limit is safe. But this is impractical for some join implementations, e.g. SortMergeJoin. For now, disable limit pushdown through a FULL OUTER JOIN, and we can evaluate whether a more complicated solution is necessary in the future. ## How was this patch tested? Ran org.apache.spark.sql.* tests. Altered full outer join tests in LimitPushdownSuite. Author: Henry Robinson Closes #19647 from henryr/spark-22211. --- .../sql/catalyst/optimizer/Optimizer.scala | 24 +++----------- .../optimizer/LimitPushdownSuite.scala | 33 +++++++++---------- 2 files changed, 21 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3273a61dc7b35..3a3ccd5ff5e60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -332,12 +332,11 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit. case LocalLimit(exp, Union(children)) => LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _)))) - // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the - // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side - // because we need to ensure that rows from the limited side still have an opportunity to match - // against all candidates from the non-limited side. We also need to ensure that this limit - // pushdown rule will not eventually introduce limits on both sides if it is applied multiple - // times. Therefore: + // Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to + // the left and right sides, respectively. It's not safe to push limits below FULL OUTER + // JOIN in the general case without a more invasive rewrite. + // We also need to ensure that this limit pushdown rule will not eventually introduce limits + // on both sides if it is applied multiple times. 
Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. // - If neither side is limited, limit the side that is estimated to be bigger. @@ -345,19 +344,6 @@ object LimitPushDown extends Rule[LogicalPlan] { val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) - case FullOuter => - (left.maxRows, right.maxRows) match { - case (None, None) => - if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { - join.copy(left = maybePushLocalLimit(exp, left)) - } else { - join.copy(right = maybePushLocalLimit(exp, right)) - } - case (Some(_), Some(_)) => join - case (Some(_), None) => join.copy(left = maybePushLocalLimit(exp, left)) - case (None, Some(_)) => join.copy(right = maybePushLocalLimit(exp, right)) - - } case _ => join } LocalLimit(exp, newJoin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index f50e2e86516f0..cc98d2350c777 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -113,35 +113,34 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and both sides have same statistics") { assert(x.stats.sizeInBytes === y.stats.sizeInBytes) - val originalQuery = x.join(y, FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = x.join(y, FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. + comparePlans(optimized, originalQuery) } test("full outer join where neither side is limited and left side has larger statistics") { val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes) - val originalQuery = xBig.join(y, FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = xBig.join(y, FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. + comparePlans(optimized, originalQuery) } test("full outer join where neither side is limited and right side has larger statistics") { val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes) - val originalQuery = x.join(yBig, FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = x.join(yBig, FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. 
+ comparePlans(optimized, originalQuery) } test("full outer join where both sides are limited") { - val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, Limit(2, x).join(Limit(2, y), FullOuter)).analyze - comparePlans(optimized, correctAnswer) + val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1).analyze + val optimized = Optimize.execute(originalQuery) + // No pushdown for FULL OUTER JOINS. + comparePlans(optimized, originalQuery) } } - From 3bba8621cf0a97f5c3134c9a160b1c8c5e97ba97 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 4 Nov 2017 22:57:12 -0700 Subject: [PATCH 178/712] [SPARK-22378][SQL] Eliminate redundant null check in generated code for extracting an element from complex types ## What changes were proposed in this pull request? This PR eliminates redundant null check in generated code for extracting an element from complex types `GetArrayItem`, `GetMapValue`, and `GetArrayStructFields`. Since these code generation does not take care of `nullable` in `DataType` such as `ArrayType`, the generated code always has `isNullAt(index)`. This PR avoids to generate `isNullAt(index)` if `nullable` is false in `DataType`. Example ``` val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false)) checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1) ``` Before this PR ``` /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ boolean isNull = true; /* 040 */ int value = -1; /* 041 */ /* 042 */ /* 043 */ /* 044 */ isNull = false; // resultCode could change nullability. /* 045 */ /* 046 */ final int index = (int) 0; /* 047 */ if (index >= ((ArrayData) references[0]).numElements() || index < 0 || ((ArrayData) references[0]).isNullAt(index)) { /* 048 */ isNull = true; /* 049 */ } else { /* 050 */ value = ((ArrayData) references[0]).getInt(index); /* 051 */ } /* 052 */ isNull_0 = isNull; /* 053 */ value_0 = value; /* 054 */ /* 055 */ // copy all the results into MutableRow /* 056 */ /* 057 */ if (!isNull_0) { /* 058 */ mutableRow.setInt(0, value_0); /* 059 */ } else { /* 060 */ mutableRow.setNullAt(0); /* 061 */ } /* 062 */ /* 063 */ return mutableRow; /* 064 */ } ``` After this PR (Line 47 is changed) ``` /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ boolean isNull = true; /* 040 */ int value = -1; /* 041 */ /* 042 */ /* 043 */ /* 044 */ isNull = false; // resultCode could change nullability. /* 045 */ /* 046 */ final int index = (int) 0; /* 047 */ if (index >= ((ArrayData) references[0]).numElements() || index < 0) { /* 048 */ isNull = true; /* 049 */ } else { /* 050 */ value = ((ArrayData) references[0]).getInt(index); /* 051 */ } /* 052 */ isNull_0 = isNull; /* 053 */ value_0 = value; /* 054 */ /* 055 */ // copy all the results into MutableRow /* 056 */ /* 057 */ if (!isNull_0) { /* 058 */ mutableRow.setInt(0, value_0); /* 059 */ } else { /* 060 */ mutableRow.setNullAt(0); /* 061 */ } /* 062 */ /* 063 */ return mutableRow; /* 064 */ } ``` ## How was this patch tested? Added test cases into `ComplexTypeSuite` Author: Kazuaki Ishizaki Closes #19598 from kiszk/SPARK-22378. 
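To make the nullability argument above concrete, here is a small sketch (simplified from the codegen helpers changed in this patch; `elementNullCheck` is an illustrative name) showing how the schema-level `containsNull` flag decides whether the `isNullAt(index)` guard is emitted at all:
```scala
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// When the element type is declared non-nullable, the isNullAt branch is provably dead,
// so the generated condition can drop it entirely.
def elementNullCheck(arrayType: ArrayType, arrayVar: String, indexVar: String): String =
  if (arrayType.containsNull) s" || $arrayVar.isNullAt($indexVar)" else ""

val nullableElems    = ArrayType(IntegerType, containsNull = true)
val nonNullableElems = ArrayType(IntegerType, containsNull = false)

elementNullCheck(nullableElems, "arr", "index")     // " || arr.isNullAt(index)"
elementNullCheck(nonNullableElems, "arr", "index")  // ""
```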
--- .../expressions/complexTypeExtractors.scala | 28 +++++++++++++++---- .../expressions/ComplexTypeSuite.scala | 11 ++++++-- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index ef88cfb543ebb..7e53ca3908905 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -186,6 +186,16 @@ case class GetArrayStructFields( val values = ctx.freshName("values") val j = ctx.freshName("j") val row = ctx.freshName("row") + val nullSafeEval = if (field.nullable) { + s""" + if ($row.isNullAt($ordinal)) { + $values[$j] = null; + } else + """ + } else { + "" + } + s""" final int $n = $eval.numElements(); final Object[] $values = new Object[$n]; @@ -194,9 +204,7 @@ case class GetArrayStructFields( $values[$j] = null; } else { final InternalRow $row = $eval.getStruct($j, $numFields); - if ($row.isNullAt($ordinal)) { - $values[$j] = null; - } else { + $nullSafeEval { $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; } } @@ -242,9 +250,14 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("index") + val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) { + s" || $eval1.isNullAt($index)" + } else { + "" + } s""" final int $index = (int) $eval2; - if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) { + if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; @@ -309,6 +322,11 @@ case class GetMapValue(child: Expression, key: Expression) val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") + val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + s" || $values.isNullAt($index)" + } else { + "" + } nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -326,7 +344,7 @@ case class GetMapValue(child: Expression, key: Expression) } } - if (!$found || $values.isNullAt($index)) { + if (!$found$nullCheck) { ${ev.isNull} = true; } else { ${ev.value} = ${ctx.getValue(values, dataType, index)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 5f8a8f44d48e6..b0eaad1c80f89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -51,6 +51,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetArrayItem(array, nullInt), null) checkEvaluation(GetArrayItem(nullArray, nullInt), null) + val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false)) + checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1) + val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } @@ -66,6 +69,9 @@ class 
ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetMapValue(nullMap, nullString), null) checkEvaluation(GetMapValue(map, nullString), null) + val nonNullMap = Literal.create(Map("a" -> 1), MapType(StringType, IntegerType, false)) + checkEvaluation(GetMapValue(nonNullMap, Literal("a")), 1) + val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM)) checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) } @@ -101,9 +107,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("GetArrayStructFields") { - val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val typeAS = ArrayType(StructType(StructField("a", IntegerType, false) :: Nil)) + val typeNullAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) - val nullArrayStruct = Literal.create(null, typeAS) + val nullArrayStruct = Literal.create(null, typeNullAS) def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { expr.dataType match { From 572284c5b08515901267a37adef4f8e55df3780e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 4 Nov 2017 23:07:24 -0700 Subject: [PATCH 179/712] [SPARK-22443][SQL] add implementation of quoteIdentifier, getTableExistsQuery and getSchemaQuery in AggregatedDialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … ## What changes were proposed in this pull request? override JDBCDialects methods quoteIdentifier, getTableExistsQuery and getSchemaQuery in AggregatedDialect ## How was this patch tested? Test the new implementation in JDBCSuite test("Aggregated dialects") Author: Huaxin Gao Closes #19658 from huaxingao/spark-22443. --- .../apache/spark/sql/jdbc/AggregatedDialect.scala | 12 ++++++++++++ .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 1419d69f983ab..f3bfea5f6bfc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -42,6 +42,18 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect dialects.flatMap(_.getJDBCType(dt)).headOption } + override def quoteIdentifier(colName: String): String = { + dialects.head.quoteIdentifier(colName) + } + + override def getTableExistsQuery(table: String): String = { + dialects.head.getTableExistsQuery(table) + } + + override def getSchemaQuery(table: String): String = { + dialects.head.getSchemaQuery(table) + } + override def isCascadingTruncateTable(): Option[Boolean] = { // If any dialect claims cascading truncate, this dialect is also cascading truncate. // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown. 
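As a usage-level illustration of the delegation above (hypothetical dialects, not part of Spark): when several registered dialects match the same JDBC URL, Spark wraps them in an `AggregatedDialect`, and with this change `quoteIdentifier`, `getTableExistsQuery` and `getSchemaQuery` come from the first dialect in that list instead of silently falling back to the defaults.
```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Two hypothetical dialects that both claim H2 URLs.
object UpperCaseQuoting extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
  override def quoteIdentifier(colName: String): String = "\"" + colName.toUpperCase + "\""
}

object PlainQuoting extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
  override def quoteIdentifier(colName: String): String = colName
}

JdbcDialects.registerDialect(UpperCaseQuoting)
JdbcDialects.registerDialect(PlainQuoting)

// Both dialects match, so JdbcDialects.get returns an AggregatedDialect; identifier
// quoting is delegated to the head of the matching list rather than the built-in default.
val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
dialect.quoteIdentifier("name")
```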
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 167b3e0190026..88a5f618d604d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -740,6 +740,15 @@ class JDBCSuite extends SparkFunSuite } else { None } + override def quoteIdentifier(colName: String): String = { + s"My $colName quoteIdentifier" + } + override def getTableExistsQuery(table: String): String = { + s"My $table Table" + } + override def getSchemaQuery(table: String): String = { + s"My $table Schema" + } override def isCascadingTruncateTable(): Option[Boolean] = Some(true) }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) @@ -747,6 +756,9 @@ class JDBCSuite extends SparkFunSuite assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) assert(agg.isCascadingTruncateTable() === Some(true)) + assert(agg.quoteIdentifier ("Dummy") === "My Dummy quoteIdentifier") + assert(agg.getTableExistsQuery ("Dummy") === "My Dummy Table") + assert(agg.getSchemaQuery ("Dummy") === "My Dummy Schema") } test("Aggregated dialects: isCascadingTruncateTable") { From fe258a7963361c1f31bc3dc3a2a2ee4a5834bb58 Mon Sep 17 00:00:00 2001 From: Tristan Stevens Date: Sun, 5 Nov 2017 09:10:40 +0000 Subject: [PATCH 180/712] [SPARK-22429][STREAMING] Streaming checkpointing code does not retry after failure ## What changes were proposed in this pull request? SPARK-14930/SPARK-13693 put in a change to set the fs object to null after a failure, however the retry loop does not include initialization. Moved fs initialization inside the retry while loop to aid recoverability. ## How was this patch tested? Passes all existing unit tests. Author: Tristan Stevens Closes #19645 from tmgstevens/SPARK-22429. --- .../main/scala/org/apache/spark/streaming/Checkpoint.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 9ebb91b8cab3c..3cfbcedd519d6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -211,9 +211,6 @@ class CheckpointWriter( if (latestCheckpointTime == null || latestCheckpointTime < checkpointTime) { latestCheckpointTime = checkpointTime } - if (fs == null) { - fs = new Path(checkpointDir).getFileSystem(hadoopConf) - } var attempts = 0 val startTime = System.currentTimeMillis() val tempFile = new Path(checkpointDir, "temp") @@ -233,7 +230,9 @@ class CheckpointWriter( attempts += 1 try { logInfo(s"Saving checkpoint for time $checkpointTime to file '$checkpointFile'") - + if (fs == null) { + fs = new Path(checkpointDir).getFileSystem(hadoopConf) + } // Write checkpoint to temp file fs.delete(tempFile, true) // just in case it exists val fos = fs.create(tempFile) From db389f71972754697ebaa89b731d117d23367dd1 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 5 Nov 2017 20:10:15 -0800 Subject: [PATCH 181/712] [SPARK-21625][DOC] Add incompatible Hive UDF describe to DOC ## What changes were proposed in this pull request? Add incompatible Hive UDF describe to DOC. ## How was this patch tested? N/A Author: Yuming Wang Closes #18833 from wangyum/SPARK-21625. 
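The documentation change below enumerates three math functions where Hive and Spark SQL disagree on out-of-range inputs. A quick way to observe Spark's side of the behaviour from a Spark shell (where `spark` is the ambient session; the result shown in the comment is approximate):
```scala
// Hive returns NULL for these inputs; Spark SQL returns NaN.
spark.sql("SELECT sqrt(-1), acos(2), asin(-2)").collect()
// => Array([NaN,NaN,NaN])
```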
--- docs/sql-programming-guide.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ce377875ff2b1..686fcb159d09d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1958,6 +1958,14 @@ Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are Spark SQL currently does not support the reuse of aggregation. * `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating an aggregate over a fixed window. + +### Incompatible Hive UDF + +Below are the scenarios in which Hive and Spark generate different results: + +* `SQRT(n)` If n < 0, Hive returns null, Spark SQL returns NaN. +* `ACOS(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN. +* `ASIN(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN. # Reference From 4bacddb602e19fcd4e1ec75a7b10bed524e6989a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 5 Nov 2017 21:21:12 -0800 Subject: [PATCH 182/712] [SPARK-7146][ML] Expose the common params as a DeveloperAPI for other ML developers ## What changes were proposed in this pull request? Expose the common params from Spark ML as a Developer API. ## How was this patch tested? Existing tests. Author: Holden Karau Author: Holden Karau Closes #18699 from holdenk/SPARK-7146-ml-shared-params-developer-api. --- .../ml/param/shared/SharedParamsCodeGen.scala | 7 +- .../spark/ml/param/shared/sharedParams.scala | 145 ++++++++++++------ 2 files changed, 102 insertions(+), 50 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1860fe8361749..a932d28fadbd8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -177,9 +177,11 @@ private[shared] object SharedParamsCodeGen { s""" |/** - | * Trait for shared param $name$defaultValueDoc. + | * Trait for shared param $name$defaultValueDoc. This trait may be changed or + | * removed between minor versions. | */ - |private[ml] trait Has$Name extends Params { + |@DeveloperApi + |trait Has$Name extends Params { | | /** | * Param for $htmlCompliantDoc. @@ -215,6 +217,7 @@ private[shared] object SharedParamsCodeGen { | |package org.apache.spark.ml.param.shared | + |import org.apache.spark.annotation.DeveloperApi |import org.apache.spark.ml.param._ | |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 6061d9ca0a084..e6bdf5236e72d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.param.shared +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -24,9 +25,11 @@ import org.apache.spark.ml.param._ // scalastyle:off /** - * Trait for shared param regParam. + * Trait for shared param regParam. This trait may be changed or + * removed between minor versions. 
*/ -private[ml] trait HasRegParam extends Params { +@DeveloperApi +trait HasRegParam extends Params { /** * Param for regularization parameter (>= 0). @@ -39,9 +42,11 @@ private[ml] trait HasRegParam extends Params { } /** - * Trait for shared param maxIter. + * Trait for shared param maxIter. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasMaxIter extends Params { +@DeveloperApi +trait HasMaxIter extends Params { /** * Param for maximum number of iterations (>= 0). @@ -54,9 +59,11 @@ private[ml] trait HasMaxIter extends Params { } /** - * Trait for shared param featuresCol (default: "features"). + * Trait for shared param featuresCol (default: "features"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasFeaturesCol extends Params { +@DeveloperApi +trait HasFeaturesCol extends Params { /** * Param for features column name. @@ -71,9 +78,11 @@ private[ml] trait HasFeaturesCol extends Params { } /** - * Trait for shared param labelCol (default: "label"). + * Trait for shared param labelCol (default: "label"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasLabelCol extends Params { +@DeveloperApi +trait HasLabelCol extends Params { /** * Param for label column name. @@ -88,9 +97,11 @@ private[ml] trait HasLabelCol extends Params { } /** - * Trait for shared param predictionCol (default: "prediction"). + * Trait for shared param predictionCol (default: "prediction"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasPredictionCol extends Params { +@DeveloperApi +trait HasPredictionCol extends Params { /** * Param for prediction column name. @@ -105,9 +116,11 @@ private[ml] trait HasPredictionCol extends Params { } /** - * Trait for shared param rawPredictionCol (default: "rawPrediction"). + * Trait for shared param rawPredictionCol (default: "rawPrediction"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasRawPredictionCol extends Params { +@DeveloperApi +trait HasRawPredictionCol extends Params { /** * Param for raw prediction (a.k.a. confidence) column name. @@ -122,9 +135,11 @@ private[ml] trait HasRawPredictionCol extends Params { } /** - * Trait for shared param probabilityCol (default: "probability"). + * Trait for shared param probabilityCol (default: "probability"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasProbabilityCol extends Params { +@DeveloperApi +trait HasProbabilityCol extends Params { /** * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. @@ -139,9 +154,11 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param varianceCol. + * Trait for shared param varianceCol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasVarianceCol extends Params { +@DeveloperApi +trait HasVarianceCol extends Params { /** * Param for Column name for the biased sample variance of prediction. @@ -154,9 +171,11 @@ private[ml] trait HasVarianceCol extends Params { } /** - * Trait for shared param threshold. + * Trait for shared param threshold. This trait may be changed or + * removed between minor versions. 
*/ -private[ml] trait HasThreshold extends Params { +@DeveloperApi +trait HasThreshold extends Params { /** * Param for threshold in binary classification prediction, in range [0, 1]. @@ -169,9 +188,11 @@ private[ml] trait HasThreshold extends Params { } /** - * Trait for shared param thresholds. + * Trait for shared param thresholds. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasThresholds extends Params { +@DeveloperApi +trait HasThresholds extends Params { /** * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. @@ -184,9 +205,11 @@ private[ml] trait HasThresholds extends Params { } /** - * Trait for shared param inputCol. + * Trait for shared param inputCol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasInputCol extends Params { +@DeveloperApi +trait HasInputCol extends Params { /** * Param for input column name. @@ -199,9 +222,11 @@ private[ml] trait HasInputCol extends Params { } /** - * Trait for shared param inputCols. + * Trait for shared param inputCols. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasInputCols extends Params { +@DeveloperApi +trait HasInputCols extends Params { /** * Param for input column names. @@ -214,9 +239,11 @@ private[ml] trait HasInputCols extends Params { } /** - * Trait for shared param outputCol (default: uid + "__output"). + * Trait for shared param outputCol (default: uid + "__output"). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasOutputCol extends Params { +@DeveloperApi +trait HasOutputCol extends Params { /** * Param for output column name. @@ -231,9 +258,11 @@ private[ml] trait HasOutputCol extends Params { } /** - * Trait for shared param checkpointInterval. + * Trait for shared param checkpointInterval. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasCheckpointInterval extends Params { +@DeveloperApi +trait HasCheckpointInterval extends Params { /** * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. @@ -246,9 +275,11 @@ private[ml] trait HasCheckpointInterval extends Params { } /** - * Trait for shared param fitIntercept (default: true). + * Trait for shared param fitIntercept (default: true). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasFitIntercept extends Params { +@DeveloperApi +trait HasFitIntercept extends Params { /** * Param for whether to fit an intercept term. @@ -263,9 +294,11 @@ private[ml] trait HasFitIntercept extends Params { } /** - * Trait for shared param handleInvalid. + * Trait for shared param handleInvalid. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasHandleInvalid extends Params { +@DeveloperApi +trait HasHandleInvalid extends Params { /** * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later. 
@@ -278,9 +311,11 @@ private[ml] trait HasHandleInvalid extends Params { } /** - * Trait for shared param standardization (default: true). + * Trait for shared param standardization (default: true). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasStandardization extends Params { +@DeveloperApi +trait HasStandardization extends Params { /** * Param for whether to standardize the training features before fitting the model. @@ -295,9 +330,11 @@ private[ml] trait HasStandardization extends Params { } /** - * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). + * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasSeed extends Params { +@DeveloperApi +trait HasSeed extends Params { /** * Param for random seed. @@ -312,9 +349,11 @@ private[ml] trait HasSeed extends Params { } /** - * Trait for shared param elasticNetParam. + * Trait for shared param elasticNetParam. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasElasticNetParam extends Params { +@DeveloperApi +trait HasElasticNetParam extends Params { /** * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. @@ -327,9 +366,11 @@ private[ml] trait HasElasticNetParam extends Params { } /** - * Trait for shared param tol. + * Trait for shared param tol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasTol extends Params { +@DeveloperApi +trait HasTol extends Params { /** * Param for the convergence tolerance for iterative algorithms (>= 0). @@ -342,9 +383,11 @@ private[ml] trait HasTol extends Params { } /** - * Trait for shared param stepSize. + * Trait for shared param stepSize. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasStepSize extends Params { +@DeveloperApi +trait HasStepSize extends Params { /** * Param for Step size to be used for each iteration of optimization (> 0). @@ -357,9 +400,11 @@ private[ml] trait HasStepSize extends Params { } /** - * Trait for shared param weightCol. + * Trait for shared param weightCol. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasWeightCol extends Params { +@DeveloperApi +trait HasWeightCol extends Params { /** * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. @@ -372,9 +417,11 @@ private[ml] trait HasWeightCol extends Params { } /** - * Trait for shared param solver. + * Trait for shared param solver. This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasSolver extends Params { +@DeveloperApi +trait HasSolver extends Params { /** * Param for the solver algorithm for optimization. @@ -387,9 +434,11 @@ private[ml] trait HasSolver extends Params { } /** - * Trait for shared param aggregationDepth (default: 2). + * Trait for shared param aggregationDepth (default: 2). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasAggregationDepth extends Params { +@DeveloperApi +trait HasAggregationDepth extends Params { /** * Param for suggested depth for treeAggregate (>= 2). 
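With the shared param traits above promoted to `DeveloperApi` (SPARK-7146), an ML developer outside Spark can mix them into a custom stage instead of redeclaring the params by hand. A minimal sketch, assuming only spark-mllib on the classpath; `UpperCaser` is an invented example, not part of Spark:
```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, upper}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// A third-party transformer reusing the now-public HasInputCol / HasOutputCol traits.
class UpperCaser(override val uid: String) extends Transformer
    with HasInputCol with HasOutputCol {

  def this() = this(Identifiable.randomUID("upperCaser"))

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn($(outputCol), upper(col($(inputCol))))

  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields :+ StructField($(outputCol), StringType, nullable = true))

  override def copy(extra: ParamMap): UpperCaser = defaultCopy(extra)
}
```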
From 472db58cb19bbd3025eabbd185d920aab0ebb4da Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Nov 2017 15:10:44 +0100 Subject: [PATCH 183/712] [SPARK-22445][SQL] move CodegenContext.copyResult to CodegenSupport ## What changes were proposed in this pull request? `CodegenContext.copyResult` is kind of a global status for whole stage codegen. But the tricky part is, it is only used to transfer an information from child to parent when calling the `consume` chain. We have to be super careful in `produce`/`consume`, to set it to true when producing multiple result rows, and set it to false in operators that start new pipeline(like sort). This PR moves the `copyResult` to `CodegenSupport`, and call it at `WholeStageCodegenExec`. This is much easier to reason about. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19656 from cloud-fan/whole-sage. --- .../expressions/codegen/CodeGenerator.scala | 10 ------ .../sql/execution/ColumnarBatchScan.scala | 2 +- .../spark/sql/execution/ExpandExec.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 3 +- .../apache/spark/sql/execution/SortExec.scala | 14 ++++---- .../sql/execution/WholeStageCodegenExec.scala | 35 ++++++++++++++----- .../aggregate/HashAggregateExec.scala | 14 ++++---- .../execution/basicPhysicalOperators.scala | 5 +-- .../joins/BroadcastHashJoinExec.scala | 16 +++++++-- .../execution/joins/SortMergeJoinExec.scala | 3 +- 10 files changed, 66 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 58738b52b299f..98eda2a1ba92c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -139,16 +139,6 @@ class CodegenContext { */ var currentVars: Seq[ExprCode] = null - /** - * Whether should we copy the result rows or not. - * - * If any operator inside WholeStageCodegen generate multiple rows from a single row (for - * example, Join), this should be true. - * - * If an operator starts a new pipeline, this should be reset to false before calling `consume()`. - */ - var copyResult: Boolean = false - /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a * 3-tuple: java type, variable name, code to init it. 
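To reason about the new contract in isolation, the following toy model (illustrative `Mini*` names, not the real `CodegenSupport`) captures the behaviour described above: leaves default to not copying, unary operators inherit the flag from their child, operators that can emit several rows per input row opt in, and operators that start a new pipeline opt back out.
```scala
trait MiniCodegenSupport {
  def children: Seq[MiniCodegenSupport]
  def needCopyResult: Boolean = children match {
    case Seq()      => false                   // leaf: produces fresh rows
    case Seq(child) => child.needCopyResult    // unary: inherit from the pipeline below
    case _          => throw new UnsupportedOperationException("multi-child nodes must override")
  }
}

case object ScanLike extends MiniCodegenSupport {
  val children: Seq[MiniCodegenSupport] = Nil
}
case class FilterLike(child: MiniCodegenSupport) extends MiniCodegenSupport {
  val children: Seq[MiniCodegenSupport] = Seq(child)
}
case class ExpandLike(child: MiniCodegenSupport) extends MiniCodegenSupport {
  val children: Seq[MiniCodegenSupport] = Seq(child)
  override def needCopyResult: Boolean = true   // may emit many rows per input row
}
case class SortLike(child: MiniCodegenSupport) extends MiniCodegenSupport {
  val children: Seq[MiniCodegenSupport] = Seq(child)
  override def needCopyResult: Boolean = false  // starts a new pipeline from its own buffer
}

FilterLike(ExpandLike(ScanLike)).needCopyResult  // true
SortLike(ExpandLike(ScanLike)).needCopyResult    // false
```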
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index eb01e126bcbef..1925bad8c3545 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -115,7 +115,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val numRows = ctx.freshName("numRows") - val shouldStop = if (isShouldStopRequired) { + val shouldStop = if (parent.needStopCheck) { s"if (shouldStop()) { $idx = $rowidx + 1; return; }" } else { "// shouldStop check is eliminated" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index d5603b3b00914..33849f4389b92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -93,6 +93,8 @@ case class ExpandExec( child.asInstanceOf[CodegenSupport].produce(ctx, this) } + override def needCopyResult: Boolean = true + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { /* * When the projections list looks like: @@ -187,7 +189,6 @@ case class ExpandExec( val i = ctx.freshName("i") // these column have to declared before the loop. val evaluate = evaluateVariables(outputColumns) - ctx.copyResult = true s""" |$evaluate |for (int $i = 0; $i < ${projections.length}; $i ++) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 65ca37491b6a1..c142d3b5ed4f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -132,9 +132,10 @@ case class GenerateExec( child.asInstanceOf[CodegenSupport].produce(ctx, this) } + override def needCopyResult: Boolean = true + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { ctx.currentVars = input - ctx.copyResult = true // Add input rows to the values when we are joining val values = if (join) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ff71fd4dc7bb7..21765cdbd94cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -124,6 +124,14 @@ case class SortExec( // Name of sorter variable used in codegen. private var sorterVariable: String = _ + // The result rows come from the sort buffer, so this operator doesn't need to copy its result + // even if its child does. + override def needCopyResult: Boolean = false + + // Sort operator always consumes all the input rows before outputting any result, so we don't need + // a stop check before sorting. 
+ override def needStopCheck: Boolean = false + override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") @@ -148,10 +156,6 @@ case class SortExec( | } """.stripMargin.trim) - // The child could change `copyResult` to true, but we had already consumed all the rows, - // so `copyResult` should be reset to `false`. - ctx.copyResult = false - val outputRow = ctx.freshName("outputRow") val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") @@ -177,8 +181,6 @@ case class SortExec( """.stripMargin.trim } - protected override val shouldStopRequired = false - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { s""" |${row.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 286cb3bb0767c..16b5706c03bf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -213,19 +213,32 @@ trait CodegenSupport extends SparkPlan { } /** - * For optimization to suppress shouldStop() in a loop of WholeStageCodegen. - * Returning true means we need to insert shouldStop() into the loop producing rows, if any. + * Whether or not the result rows of this operator should be copied before putting into a buffer. + * + * If any operator inside WholeStageCodegen generate multiple rows from a single row (for + * example, Join), this should be true. + * + * If an operator starts a new pipeline, this should be false. */ - def isShouldStopRequired: Boolean = { - return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired) + def needCopyResult: Boolean = { + if (children.isEmpty) { + false + } else if (children.length == 1) { + children.head.asInstanceOf[CodegenSupport].needCopyResult + } else { + throw new UnsupportedOperationException + } } /** - * Set to false if this plan consumes all rows produced by children but doesn't output row - * to buffer by calling append(), so the children don't require shouldStop() - * in the loop of producing rows. + * Whether or not the children of this operator should generate a stop check when consuming input + * rows. This is used to suppress shouldStop() in a loop of WholeStageCodegen. + * + * This should be false if an operator starts a new pipeline, which means it consumes all rows + * produced by children but doesn't output row to buffer by calling append(), so the children + * don't require shouldStop() in the loop of producing rows. 
*/ - protected def shouldStopRequired: Boolean = true + def needStopCheck: Boolean = parent.needStopCheck } @@ -278,6 +291,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "") } + + override def needCopyResult: Boolean = false } object WholeStageCodegenExec { @@ -467,7 +482,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val doCopy = if (ctx.copyResult) { + val doCopy = if (needCopyResult) { ".copy()" } else { "" @@ -487,6 +502,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co addSuffix: Boolean = false): StringBuilder = { child.generateTreeString(depth, lastChildren, builder, verbose, "*") } + + override def needStopCheck: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 43e5ff89afee6..2a208a2722550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -149,6 +149,14 @@ case class HashAggregateExec( child.asInstanceOf[CodegenSupport].inputRDDs() } + // The result rows come from the aggregate buffer, or a single row(no grouping keys), so this + // operator doesn't need to copy its result even if its child does. + override def needCopyResult: Boolean = false + + // Aggregate operator always consumes all the input rows before outputting any result, so we + // don't need a stop check before aggregating. + override def needStopCheck: Boolean = false + protected override def doProduce(ctx: CodegenContext): String = { if (groupingExpressions.isEmpty) { doProduceWithoutKeys(ctx) @@ -246,8 +254,6 @@ case class HashAggregateExec( """.stripMargin } - protected override val shouldStopRequired = false - private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) @@ -651,10 +657,6 @@ case class HashAggregateExec( val outputFunc = generateResultFunction(ctx) val numOutput = metricTerm(ctx, "numOutputRows") - // The child could change `copyResult` to true, but we had already consumed all the rows, - // so `copyResult` should be reset to `false`. 
- ctx.copyResult = false - def outputFromGeneratedMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index e58c3cec2df15..3c7daa0a45844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -279,6 +279,8 @@ case class SampleExec( child.asInstanceOf[CodegenSupport].produce(ctx, this) } + override def needCopyResult: Boolean = withReplacement + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") val sampler = ctx.freshName("sampler") @@ -286,7 +288,6 @@ case class SampleExec( if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") - ctx.copyResult = true val initSamplerFuncName = ctx.addNewFunction(initSampler, s""" @@ -450,7 +451,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val range = ctx.freshName("range") - val shouldStop = if (isShouldStopRequired) { + val shouldStop = if (parent.needStopCheck) { s"if (shouldStop()) { $number = $value + ${step}L; return; }" } else { "// shouldStop check is eliminated" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index b09da9bdacb99..837b8525fed55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -76,6 +76,20 @@ case class BroadcastHashJoinExec( streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() } + override def needCopyResult: Boolean = joinType match { + case _: InnerLike | LeftOuter | RightOuter => + // For inner and outer joins, one row from the streamed side may produce multiple result rows, + // if the build side has duplicated keys. Then we need to copy the result rows before putting + // them in a buffer, because these result rows share one UnsafeRow instance. Note that here + // we wait for the broadcast to be finished, which is a no-op because it's already finished + // when we wait it in `doProduce`. + !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique + + // Other joins types(semi, anti, existence) can at most produce one result row for one input + // row from the streamed side, so no need to copy the result rows. 
+ case _ => false + } + override def doProduce(ctx: CodegenContext): String = { streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) } @@ -237,7 +251,6 @@ case class BroadcastHashJoinExec( """.stripMargin } else { - ctx.copyResult = true val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName s""" @@ -310,7 +323,6 @@ case class BroadcastHashJoinExec( """.stripMargin } else { - ctx.copyResult = true val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName val found = ctx.freshName("found") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 4e02803552e82..cf7885f80d9fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -569,8 +569,9 @@ case class SortMergeJoinExec( } } + override def needCopyResult: Boolean = true + override def doProduce(ctx: CodegenContext): String = { - ctx.copyResult = true val leftInput = ctx.freshName("leftInput") ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") val rightInput = ctx.freshName("rightInput") From c7f38e5adb88d43ef60662c5d6ff4e7a95bff580 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 6 Nov 2017 08:45:40 -0600 Subject: [PATCH 184/712] [SPARK-20644][core] Initial ground work for kvstore UI backend. There are two somewhat unrelated things going on in this patch, but both are meant to make integration of individual UI pages later on much easier. The first part is some tweaking of the code in the listener so that it does less updates of the kvstore for data that changes fast; for example, it avoids writing changes down to the store for every task-related event, since those can arrive very quickly at times. Instead, for these kinds of events, it chooses to only flush things if a certain interval has passed. The interval is based on how often the current spark-shell code updates the progress bar for jobs, so that users can get reasonably accurate data. The code also delays as much as possible hitting the underlying kvstore when replaying apps in the history server. This is to avoid unnecessary writes to disk. The second set of changes prepare the history server and SparkUI for integrating with the kvstore. A new class, AppStatusStore, is used for translating between the stored data and the types used in the UI / API. The SHS now populates a kvstore with data loaded from event logs when an application UI is requested. Because this store can hold references to disk-based resources, the code was modified to retrieve data from the store under a read lock. This allows the SHS to detect when the store is still being used, and only update it (e.g. because an updated event log was detected) when there is no other thread using the store. This change ended up creating a lot of churn in the ApplicationCache code, which was cleaned up a lot in the process. I also removed some metrics which don't make too much sense with the new code. Tested with existing and added unit tests, and by making sure the SHS still works on a real cluster. Author: Marcelo Vanzin Closes #19582 from vanzin/SPARK-20644. 
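Since the locking protocol around the loaded UI is the subtlest part of this change, here is a stripped-down sketch of the pattern (invented `LoadedEntry`/`withEntry` names, not the actual `ApplicationCache` code): a request holds a read lock on the loaded entry so the history server cannot close its backing store mid-request, and if the entry was invalidated between lookup and lock acquisition the request simply reloads and tries again.
```scala
import java.util.concurrent.locks.ReentrantReadWriteLock

class LoadedEntry[T](val value: T) {
  val lock = new ReentrantReadWriteLock()
  @volatile var valid = true   // flipped to false (under the write lock) before closing the store
}

object ReadLocked {
  def withEntry[T, R](lookup: () => LoadedEntry[T])(fn: T => R): R = {
    val entry = lookup()
    entry.lock.readLock().lock()
    if (entry.valid) {
      try fn(entry.value) finally entry.lock.readLock().unlock()
    } else {
      entry.lock.readLock().unlock()   // stale entry: release it and load a fresh one
      withEntry(lookup)(fn)
    }
  }
}
```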
--- .../scala/org/apache/spark/SparkContext.scala | 17 +- .../deploy/history/ApplicationCache.scala | 441 ++++-------------- .../history/ApplicationHistoryProvider.scala | 65 ++- .../deploy/history/FsHistoryProvider.scala | 275 +++++++---- .../spark/deploy/history/HistoryServer.scala | 18 +- .../scheduler/ApplicationEventListener.scala | 67 --- .../spark/status/AppStatusListener.scala | 113 +++-- .../apache/spark/status/AppStatusStore.scala | 239 ++++++++++ .../org/apache/spark/status/LiveEntity.scala | 7 +- .../spark/status/api/v1/ApiRootResource.scala | 19 +- .../org/apache/spark/status/config.scala | 30 ++ .../org/apache/spark/status/storeTypes.scala | 26 +- .../scala/org/apache/spark/ui/SparkUI.scala | 73 +-- .../history/ApplicationCacheSuite.scala | 194 ++------ .../history/FsHistoryProviderSuite.scala | 40 +- .../deploy/history/HistoryServerSuite.scala | 28 +- .../spark/status/AppStatusListenerSuite.scala | 19 +- project/MimaExcludes.scala | 2 + 18 files changed, 878 insertions(+), 795 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusStore.scala create mode 100644 core/src/main/scala/org/apache/spark/status/config.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c7dd635ad4c96..e5aaaf6c155eb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.status.AppStatusStore import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -213,6 +214,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _jars: Seq[String] = _ private var _files: Seq[String] = _ private var _shutdownHookRef: AnyRef = _ + private var _statusStore: AppStatusStore = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -422,6 +424,10 @@ class SparkContext(config: SparkConf) extends Logging { _jobProgressListener = new JobProgressListener(_conf) listenerBus.addToStatusQueue(jobProgressListener) + // Initialize the app status store and listener before SparkEnv is created so that it gets + // all events. 
+ _statusStore = AppStatusStore.createLiveStore(conf, listenerBus) + // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) @@ -443,8 +449,12 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, _conf, _jobProgressListener, - _env.securityManager, appName, startTime = startTime)) + Some(SparkUI.create(Some(this), _statusStore, _conf, + l => listenerBus.addToStatusQueue(l), + _env.securityManager, + appName, + "", + startTime)) } else { // For tests, do not enable the UI None @@ -1940,6 +1950,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + if (_statusStore != null) { + _statusStore.close() + } // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this // `SparkContext` is stopped. localProperties.remove() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala index a370526c46f3d..8c63fa65b40fd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala @@ -18,14 +18,15 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException +import java.util.concurrent.ExecutionException import javax.servlet.{DispatcherType, Filter, FilterChain, FilterConfig, ServletException, ServletRequest, ServletResponse} import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.collection.JavaConverters._ -import scala.util.control.NonFatal import com.codahale.metrics.{Counter, MetricRegistry, Timer} import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalListener, RemovalNotification} +import com.google.common.util.concurrent.UncheckedExecutionException import org.eclipse.jetty.servlet.FilterHolder import org.apache.spark.internal.Logging @@ -34,18 +35,11 @@ import org.apache.spark.ui.SparkUI import org.apache.spark.util.Clock /** - * Cache for applications. + * Cache for application UIs. * - * Completed applications are cached for as long as there is capacity for them. - * Incompleted applications have their update time checked on every - * retrieval; if the cached entry is out of date, it is refreshed. + * Applications are cached for as long as there is capacity for them. See [[LoadedAppUI]] for a + * discussion of the UI lifecycle. * - * @note there must be only one instance of [[ApplicationCache]] in a - * JVM at a time. This is because a static field in [[ApplicationCacheCheckFilterRelay]] - * keeps a reference to the cache so that HTTP requests on the attempt-specific web UIs - * can probe the current cache to see if the attempts have changed. - * - * Creating multiple instances will break this routing. * @param operations implementation of record access operations * @param retainedApplications number of retained applications * @param clock time source @@ -55,9 +49,6 @@ private[history] class ApplicationCache( val retainedApplications: Int, val clock: Clock) extends Logging { - /** - * Services the load request from the cache. - */ private val appLoader = new CacheLoader[CacheKey, CacheEntry] { /** the cache key doesn't match a cached entry, or the entry is out-of-date, so load it. 
*/ @@ -67,9 +58,6 @@ private[history] class ApplicationCache( } - /** - * Handler for callbacks from the cache of entry removal. - */ private val removalListener = new RemovalListener[CacheKey, CacheEntry] { /** @@ -80,16 +68,11 @@ private[history] class ApplicationCache( metrics.evictionCount.inc() val key = rm.getKey logDebug(s"Evicting entry ${key}") - operations.detachSparkUI(key.appId, key.attemptId, rm.getValue().ui) + operations.detachSparkUI(key.appId, key.attemptId, rm.getValue().loadedUI.ui) } } - /** - * The cache of applications. - * - * Tagged as `protected` so as to allow subclasses in tests to access it directly - */ - protected val appCache: LoadingCache[CacheKey, CacheEntry] = { + private val appCache: LoadingCache[CacheKey, CacheEntry] = { CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(removalListener) @@ -101,151 +84,51 @@ private[history] class ApplicationCache( */ val metrics = new CacheMetrics("history.cache") - init() - - /** - * Perform any startup operations. - * - * This includes declaring this instance as the cache to use in the - * [[ApplicationCacheCheckFilterRelay]]. - */ - private def init(): Unit = { - ApplicationCacheCheckFilterRelay.setApplicationCache(this) - } - - /** - * Stop the cache. - * This will reset the relay in [[ApplicationCacheCheckFilterRelay]]. - */ - def stop(): Unit = { - ApplicationCacheCheckFilterRelay.resetApplicationCache() - } - - /** - * Get an entry. - * - * Cache fetch/refresh will have taken place by the time this method returns. - * @param appAndAttempt application to look up in the format needed by the history server web UI, - * `appId/attemptId` or `appId`. - * @return the entry - */ - def get(appAndAttempt: String): SparkUI = { - val parts = splitAppAndAttemptKey(appAndAttempt) - get(parts._1, parts._2) - } - - /** - * Get the Spark UI, converting a lookup failure from an exception to `None`. - * @param appAndAttempt application to look up in the format needed by the history server web UI, - * `appId/attemptId` or `appId`. - * @return the entry - */ - def getSparkUI(appAndAttempt: String): Option[SparkUI] = { + def get(appId: String, attemptId: Option[String] = None): CacheEntry = { try { - val ui = get(appAndAttempt) - Some(ui) + appCache.get(new CacheKey(appId, attemptId)) } catch { - case NonFatal(e) => e.getCause() match { - case nsee: NoSuchElementException => - None - case cause: Exception => throw cause - } + case e @ (_: ExecutionException | _: UncheckedExecutionException) => + throw Option(e.getCause()).getOrElse(e) } } /** - * Get the associated spark UI. - * - * Cache fetch/refresh will have taken place by the time this method returns. - * @param appId application ID - * @param attemptId optional attempt ID - * @return the entry + * Run a closure while holding an application's UI read lock. This prevents the history server + * from closing the UI data store while it's being used. */ - def get(appId: String, attemptId: Option[String]): SparkUI = { - lookupAndUpdate(appId, attemptId)._1.ui - } + def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + var entry = get(appId, attemptId) - /** - * Look up the entry; update it if needed. 
- * @param appId application ID - * @param attemptId optional attempt ID - * @return the underlying cache entry -which can have its timestamp changed, and a flag to - * indicate that the entry has changed - */ - private def lookupAndUpdate(appId: String, attemptId: Option[String]): (CacheEntry, Boolean) = { - metrics.lookupCount.inc() - val cacheKey = CacheKey(appId, attemptId) - var entry = appCache.getIfPresent(cacheKey) - var updated = false - if (entry == null) { - // no entry, so fetch without any post-fetch probes for out-of-dateness - // this will trigger a callback to loadApplicationEntry() - entry = appCache.get(cacheKey) - } else if (!entry.completed) { - val now = clock.getTimeMillis() - log.debug(s"Probing at time $now for updated application $cacheKey -> $entry") - metrics.updateProbeCount.inc() - updated = time(metrics.updateProbeTimer) { - entry.updateProbe() + // If the entry exists, we need to make sure we run the closure with a valid entry. So + // we need to re-try until we can lock a valid entry for read. + entry.loadedUI.lock.readLock().lock() + try { + while (!entry.loadedUI.valid) { + entry.loadedUI.lock.readLock().unlock() + entry = null + try { + invalidate(new CacheKey(appId, attemptId)) + entry = get(appId, attemptId) + metrics.loadCount.inc() + } finally { + if (entry != null) { + entry.loadedUI.lock.readLock().lock() + } + } } - if (updated) { - logDebug(s"refreshing $cacheKey") - metrics.updateTriggeredCount.inc() - appCache.refresh(cacheKey) - // and repeat the lookup - entry = appCache.get(cacheKey) - } else { - // update the probe timestamp to the current time - entry.probeTime = now + + fn(entry.loadedUI.ui) + } finally { + if (entry != null) { + entry.loadedUI.lock.readLock().unlock() } } - (entry, updated) } - /** - * This method is visible for testing. - * - * It looks up the cached entry *and returns a clone of it*. - * This ensures that the cached entries never leak - * @param appId application ID - * @param attemptId optional attempt ID - * @return a new entry with shared SparkUI, but copies of the other fields. - */ - def lookupCacheEntry(appId: String, attemptId: Option[String]): CacheEntry = { - val entry = lookupAndUpdate(appId, attemptId)._1 - new CacheEntry(entry.ui, entry.completed, entry.updateProbe, entry.probeTime) - } - - /** - * Probe for an application being updated. - * @param appId application ID - * @param attemptId attempt ID - * @return true if an update has been triggered - */ - def checkForUpdates(appId: String, attemptId: Option[String]): Boolean = { - val (entry, updated) = lookupAndUpdate(appId, attemptId) - updated - } - - /** - * Size probe, primarily for testing. - * @return size - */ + /** @return Number of cached UIs. */ def size(): Long = appCache.size() - /** - * Emptiness predicate, primarily for testing. - * @return true if the cache is empty - */ - def isEmpty: Boolean = appCache.size() == 0 - - /** - * Time a closure, returning its output. - * @param t timer - * @param f function - * @tparam T type of return value of time - * @return the result of the function. 
- */ private def time[T](t: Timer)(f: => T): T = { val timeCtx = t.time() try { @@ -272,27 +155,15 @@ private[history] class ApplicationCache( * @throws NoSuchElementException if there is no matching element */ @throws[NoSuchElementException] - def loadApplicationEntry(appId: String, attemptId: Option[String]): CacheEntry = { - + private def loadApplicationEntry(appId: String, attemptId: Option[String]): CacheEntry = { logDebug(s"Loading application Entry $appId/$attemptId") metrics.loadCount.inc() - time(metrics.loadTimer) { + val loadedUI = time(metrics.loadTimer) { + metrics.lookupCount.inc() operations.getAppUI(appId, attemptId) match { - case Some(LoadedAppUI(ui, updateState)) => - val completed = ui.getApplicationInfoList.exists(_.attempts.last.completed) - if (completed) { - // completed spark UIs are attached directly - operations.attachSparkUI(appId, attemptId, ui, completed) - } else { - // incomplete UIs have the cache-check filter put in front of them. - ApplicationCacheCheckFilterRelay.registerFilter(ui, appId, attemptId) - operations.attachSparkUI(appId, attemptId, ui, completed) - } - // build the cache entry - val now = clock.getTimeMillis() - val entry = new CacheEntry(ui, completed, updateState, now) - logDebug(s"Loaded application $appId/$attemptId -> $entry") - entry + case Some(loadedUI) => + logDebug(s"Loaded application $appId/$attemptId") + loadedUI case None => metrics.lookupFailureCount.inc() // guava's cache logs via java.util log, so is of limited use. Hence: our own message @@ -301,32 +172,20 @@ private[history] class ApplicationCache( attemptId.map { id => s" attemptId '$id'" }.getOrElse(" and no attempt Id")) } } - } - - /** - * Split up an `applicationId/attemptId` or `applicationId` key into the separate pieces. - * - * @param appAndAttempt combined key - * @return a tuple of the application ID and, if present, the attemptID - */ - def splitAppAndAttemptKey(appAndAttempt: String): (String, Option[String]) = { - val parts = appAndAttempt.split("/") - require(parts.length == 1 || parts.length == 2, s"Invalid app key $appAndAttempt") - val appId = parts(0) - val attemptId = if (parts.length > 1) Some(parts(1)) else None - (appId, attemptId) - } - - /** - * Merge an appId and optional attempt Id into a key of the form `applicationId/attemptId`. - * - * If there is an `attemptId`; `applicationId` if not. - * @param appId application ID - * @param attemptId optional attempt ID - * @return a unified string - */ - def mergeAppAndAttemptToKey(appId: String, attemptId: Option[String]): String = { - appId + attemptId.map { id => s"/$id" }.getOrElse("") + try { + val completed = loadedUI.ui.getApplicationInfoList.exists(_.attempts.last.completed) + if (!completed) { + // incomplete UIs have the cache-check filter put in front of them. 
+ registerFilter(new CacheKey(appId, attemptId), loadedUI) + } + operations.attachSparkUI(appId, attemptId, loadedUI.ui, completed) + new CacheEntry(loadedUI, completed) + } catch { + case e: Exception => + logWarning(s"Failed to initialize application UI for $appId/$attemptId", e) + operations.detachSparkUI(appId, attemptId, loadedUI.ui) + throw e + } } /** @@ -347,6 +206,26 @@ private[history] class ApplicationCache( sb.append("----\n") sb.toString() } + + /** + * Register a filter for the web UI which checks for updates to the given app/attempt + * @param ui Spark UI to attach filters to + * @param appId application ID + * @param attemptId attempt ID + */ + private def registerFilter(key: CacheKey, loadedUI: LoadedAppUI): Unit = { + require(loadedUI != null) + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.REQUEST) + val filter = new ApplicationCacheCheckFilter(key, loadedUI, this) + val holder = new FilterHolder(filter) + require(loadedUI.ui.getHandlers != null, "null handlers") + loadedUI.ui.getHandlers.foreach { handler => + handler.addFilter(holder, "/*", enumDispatcher) + } + } + + def invalidate(key: CacheKey): Unit = appCache.invalidate(key) + } /** @@ -355,19 +234,14 @@ private[history] class ApplicationCache( * @param ui Spark UI * @param completed Flag to indicated that the application has completed (and so * does not need refreshing). - * @param updateProbe function to call to see if the application has been updated and - * therefore that the cached value needs to be refreshed. - * @param probeTime Times in milliseconds when the probe was last executed. */ private[history] final class CacheEntry( - val ui: SparkUI, - val completed: Boolean, - val updateProbe: () => Boolean, - var probeTime: Long) { + val loadedUI: LoadedAppUI, + val completed: Boolean) { /** string value is for test assertions */ override def toString: String = { - s"UI $ui, completed=$completed, probeTime=$probeTime" + s"UI ${loadedUI.ui}, completed=$completed" } } @@ -396,23 +270,17 @@ private[history] class CacheMetrics(prefix: String) extends Source { val evictionCount = new Counter() val loadCount = new Counter() val loadTimer = new Timer() - val updateProbeCount = new Counter() - val updateProbeTimer = new Timer() - val updateTriggeredCount = new Counter() /** all the counters: for registration and string conversion. */ private val counters = Seq( ("lookup.count", lookupCount), ("lookup.failure.count", lookupFailureCount), ("eviction.count", evictionCount), - ("load.count", loadCount), - ("update.probe.count", updateProbeCount), - ("update.triggered.count", updateTriggeredCount)) + ("load.count", loadCount)) /** all metrics, including timers */ private val allMetrics = counters ++ Seq( - ("load.timer", loadTimer), - ("update.probe.timer", updateProbeTimer)) + ("load.timer", loadTimer)) /** * Name of metric source @@ -498,23 +366,11 @@ private[history] trait ApplicationCacheOperations { * Implementation note: there's some abuse of a shared global entry here because * the configuration data passed to the servlet is just a string:string map. */ -private[history] class ApplicationCacheCheckFilter() extends Filter with Logging { - - import ApplicationCacheCheckFilterRelay._ - var appId: String = _ - var attemptId: Option[String] = _ - - /** - * Bind the app and attempt ID, throwing an exception if no application ID was provided. 
- * @param filterConfig configuration - */ - override def init(filterConfig: FilterConfig): Unit = { - - appId = Option(filterConfig.getInitParameter(APP_ID)) - .getOrElse(throw new ServletException(s"Missing Parameter $APP_ID")) - attemptId = Option(filterConfig.getInitParameter(ATTEMPT_ID)) - logDebug(s"initializing filter $this") - } +private[history] class ApplicationCacheCheckFilter( + key: CacheKey, + loadedUI: LoadedAppUI, + cache: ApplicationCache) + extends Filter with Logging { /** * Filter the request. @@ -543,123 +399,24 @@ private[history] class ApplicationCacheCheckFilter() extends Filter with Logging // if the request is for an attempt, check to see if it is in need of delete/refresh // and have the cache update the UI if so - if (operation=="HEAD" || operation=="GET" - && checkForUpdates(requestURI, appId, attemptId)) { - // send a redirect back to the same location. This will be routed - // to the *new* UI - logInfo(s"Application Attempt $appId/$attemptId updated; refreshing") + loadedUI.lock.readLock().lock() + if (loadedUI.valid) { + try { + chain.doFilter(request, response) + } finally { + loadedUI.lock.readLock.unlock() + } + } else { + loadedUI.lock.readLock.unlock() + cache.invalidate(key) val queryStr = Option(httpRequest.getQueryString).map("?" + _).getOrElse("") val redirectUrl = httpResponse.encodeRedirectURL(requestURI + queryStr) httpResponse.sendRedirect(redirectUrl) - } else { - chain.doFilter(request, response) } } - override def destroy(): Unit = { - } + override def init(config: FilterConfig): Unit = { } - override def toString: String = s"ApplicationCacheCheckFilter for $appId/$attemptId" -} + override def destroy(): Unit = { } -/** - * Global state for the [[ApplicationCacheCheckFilter]] instances, so that they can relay cache - * probes to the cache. - * - * This is an ugly workaround for the limitation of servlets and filters in the Java servlet - * API; they are still configured on the model of a list of classnames and configuration - * strings in a `web.xml` field, rather than a chain of instances wired up by hand or - * via an injection framework. There is no way to directly configure a servlet filter instance - * with a reference to the application cache which is must use: some global state is needed. - * - * Here, [[ApplicationCacheCheckFilter]] is that global state; it relays all requests - * to the singleton [[ApplicationCache]] - * - * The field `applicationCache` must be set for the filters to work - - * this is done during the construction of [[ApplicationCache]], which requires that there - * is only one cache serving requests through the WebUI. - * - * *Important* In test runs, if there is more than one [[ApplicationCache]], the relay logic - * will break: filters may not find instances. Tests must not do that. - * - */ -private[history] object ApplicationCacheCheckFilterRelay extends Logging { - // name of the app ID entry in the filter configuration. Mandatory. - val APP_ID = "appId" - - // name of the attempt ID entry in the filter configuration. Optional. - val ATTEMPT_ID = "attemptId" - - // name of the filter to register - val FILTER_NAME = "org.apache.spark.deploy.history.ApplicationCacheCheckFilter" - - /** the application cache to relay requests to */ - @volatile - private var applicationCache: Option[ApplicationCache] = None - - /** - * Set the application cache. 
Logs a warning if it is overwriting an existing value - * @param cache new cache - */ - def setApplicationCache(cache: ApplicationCache): Unit = { - applicationCache.foreach( c => logWarning(s"Overwriting application cache $c")) - applicationCache = Some(cache) - } - - /** - * Reset the application cache - */ - def resetApplicationCache(): Unit = { - applicationCache = None - } - - /** - * Check to see if there has been an update - * @param requestURI URI the request came in on - * @param appId application ID - * @param attemptId attempt ID - * @return true if an update was loaded for the app/attempt - */ - def checkForUpdates(requestURI: String, appId: String, attemptId: Option[String]): Boolean = { - - logDebug(s"Checking $appId/$attemptId from $requestURI") - applicationCache match { - case Some(cache) => - try { - cache.checkForUpdates(appId, attemptId) - } catch { - case ex: Exception => - // something went wrong. Keep going with the existing UI - logWarning(s"When checking for $appId/$attemptId from $requestURI", ex) - false - } - - case None => - logWarning("No application cache instance defined") - false - } - } - - - /** - * Register a filter for the web UI which checks for updates to the given app/attempt - * @param ui Spark UI to attach filters to - * @param appId application ID - * @param attemptId attempt ID - */ - def registerFilter( - ui: SparkUI, - appId: String, - attemptId: Option[String] ): Unit = { - require(ui != null) - val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.REQUEST) - val holder = new FilterHolder() - holder.setClassName(FILTER_NAME) - holder.setInitParameter(APP_ID, appId) - attemptId.foreach( id => holder.setInitParameter(ATTEMPT_ID, id)) - require(ui.getHandlers != null, "null handlers") - ui.getHandlers.foreach { handler => - handler.addFilter(holder, "/*", enumDispatcher) - } - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 5cb48ca3e60b0..38f0d6f2afa5e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.history +import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.zip.ZipOutputStream import scala.xml.Node @@ -48,30 +49,44 @@ private[spark] case class ApplicationHistoryInfo( } /** - * A probe which can be invoked to see if a loaded Web UI has been updated. - * The probe is expected to be relative purely to that of the UI returned - * in the same [[LoadedAppUI]] instance. That is, whenever a new UI is loaded, - * the probe returned with it is the one that must be used to check for it - * being out of date; previous probes must be discarded. - */ -private[history] abstract class HistoryUpdateProbe { - /** - * Return true if the history provider has a later version of the application - * attempt than the one against this probe was constructed. - * @return - */ - def isUpdated(): Boolean -} - -/** - * All the information returned from a call to `getAppUI()`: the new UI - * and any required update state. + * A loaded UI for a Spark application. + * + * Loaded UIs are valid once created, and can be invalidated once the history provider detects + * changes in the underlying app data (e.g. an updated event log). 
Invalidating a UI does not + * unload it; it just signals the [[ApplicationCache]] that the UI should not be used to serve + * new requests. + * + * Reloading of the UI with new data requires collaboration between the cache and the provider; + * the provider invalidates the UI when it detects updated information, and the cache invalidates + * the cache entry when it detects the UI has been invalidated. That will trigger a callback + * on the provider to finally clean up any UI state. The cache should hold read locks when + * using the UI, and the provider should grab the UI's write lock before making destructive + * operations. + * + * Note that all this means that an invalidated UI will still stay in-memory, and any resources it + * references will remain open, until the cache either sees that it's invalidated, or evicts it to + * make room for another UI. + * * @param ui Spark UI - * @param updateProbe probe to call to check on the update state of this application attempt */ -private[history] case class LoadedAppUI( - ui: SparkUI, - updateProbe: () => Boolean) +private[history] case class LoadedAppUI(ui: SparkUI) { + + val lock = new ReentrantReadWriteLock() + + @volatile private var _valid = true + + def valid: Boolean = _valid + + def invalidate(): Unit = { + lock.writeLock().lock() + try { + _valid = false + } finally { + lock.writeLock().unlock() + } + } + +} private[history] abstract class ApplicationHistoryProvider { @@ -145,4 +160,10 @@ private[history] abstract class ApplicationHistoryProvider { * @return html text to display when the application list is empty */ def getEmptyListingHtml(): Seq[Node] = Seq.empty + + /** + * Called when an application UI is unloaded from the history server. + */ + def onUIDetached(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index cf97597b484d8..f16dddea9f784 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} -import java.util.{Date, UUID} +import java.util.{Date, ServiceLoader, UUID} import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -26,8 +26,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.xml.Node -import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude} -import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, Path} @@ -42,6 +41,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ +import org.apache.spark.status.{AppStatusListener, AppStatusStore, AppStatusStoreMetadata, KVUtils} import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI @@ -61,9 +61,6 @@ import org.apache.spark.util.kvstore._ * and update or create a matching application info element in the list of applications. 
* - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the * attempt is replaced by another one with a larger log size. - * - When [[updateProbe()]] is invoked to check if a loaded [[SparkUI]] - * instance is out of date, the log size of the cached instance is checked against the app last - * loaded by [[checkForLogs]]. * * The use of log size, rather than simply relying on modification times, is needed to * address the following issues @@ -125,23 +122,30 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) - private val storePath = conf.get(LOCAL_STORE_DIR) + private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => + require(path.isDirectory(), s"Configured store directory ($path) does not exist.") val dbPath = new File(path, "listing.ldb") - val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, logDir.toString()) + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, + AppStatusStore.CURRENT_VERSION, logDir.toString()) try { open(new File(path, "listing.ldb"), metadata) } catch { + // If there's an error, remove the listing database and any existing UI database + // from the store directory, since it's extremely likely that they'll all contain + // incompatible information. case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") - Utils.deleteRecursively(dbPath) + path.listFiles().foreach(Utils.deleteRecursively) open(new File(path, "listing.ldb"), metadata) } }.getOrElse(new InMemoryStore()) + private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -165,7 +169,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - // Conf option used for testing the initialization code. 
val initThread = initialize() private[history] def initialize(): Thread = { @@ -268,42 +271,100 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getLastUpdatedTime(): Long = lastScanTime.get() override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { - try { - val appInfo = load(appId) - appInfo.attempts - .find(_.info.attemptId == attemptId) - .map { attempt => - val replayBus = new ReplayListenerBus() - val ui = { - val conf = this.conf.clone() - val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.info.name, - HistoryServer.getAttemptURI(appId, attempt.info.attemptId), - Some(attempt.info.lastUpdated.getTime()), attempt.info.startTime.getTime()) - // Do not call ui.bind() to avoid creating a new server for each application - } + val app = try { + load(appId) + } catch { + case _: NoSuchElementException => + return None + } + + val attempt = app.attempts.find(_.info.attemptId == attemptId).orNull + if (attempt == null) { + return None + } - val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) - - val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) - assert(appListener.appId.isDefined) - ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") - ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) - // make sure to set admin acls before view acls so they are properly picked up - val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") - ui.getSecurityManager.setAdminAcls(adminAcls) - ui.getSecurityManager.setViewAcls(attempt.info.sparkUser, - appListener.viewAcls.getOrElse("")) - val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + - appListener.adminAclsGroups.getOrElse("") - ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) - ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) - LoadedAppUI(ui, () => updateProbe(appId, attemptId, attempt.fileSize)) + val conf = this.conf.clone() + val secManager = new SecurityManager(conf) + + secManager.setAcls(HISTORY_UI_ACLS_ENABLE) + // make sure to set admin acls before view acls so they are properly picked up + secManager.setAdminAcls(HISTORY_UI_ADMIN_ACLS + "," + attempt.adminAcls.getOrElse("")) + secManager.setViewAcls(attempt.info.sparkUser, attempt.viewAcls.getOrElse("")) + secManager.setAdminAclsGroups(HISTORY_UI_ADMIN_ACLS_GROUPS + "," + + attempt.adminAclsGroups.getOrElse("")) + secManager.setViewAclsGroups(attempt.viewAclsGroups.getOrElse("")) + + val replayBus = new ReplayListenerBus() + + val uiStorePath = storePath.map { path => getStorePath(path, appId, attemptId) } + + val (kvstore, needReplay) = uiStorePath match { + case Some(path) => + try { + val _replay = !path.isDirectory() + (createDiskStore(path, conf), _replay) + } catch { + case e: Exception => + // Get rid of the old data and re-create it. The store is either old or corrupted. 
+ logWarning(s"Failed to load disk store $uiStorePath for $appId.", e) + Utils.deleteRecursively(path) + (createDiskStore(path, conf), true) } + + case _ => + (new InMemoryStore(), true) + } + + val listener = if (needReplay) { + val _listener = new AppStatusListener(kvstore, conf, false) + replayBus.addListener(_listener) + Some(_listener) + } else { + None + } + + val loadedUI = { + val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, + l => replayBus.addListener(l), + secManager, + app.info.name, + HistoryServer.getAttemptURI(appId, attempt.info.attemptId), + attempt.info.startTime.getTime(), + appSparkVersion = attempt.info.appSparkVersion) + LoadedAppUI(ui) + } + + try { + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, loadedUI.ui) + listeners.foreach(replayBus.addListener) + } + + val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) + replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) + listener.foreach(_.flush()) } catch { - case _: FileNotFoundException => None - case _: NoSuchElementException => None + case e: Exception => + try { + kvstore.close() + } catch { + case _e: Exception => logInfo("Error closing store.", _e) + } + uiStorePath.foreach(Utils.deleteRecursively) + if (e.isInstanceOf[FileNotFoundException]) { + return None + } else { + throw e + } + } + + synchronized { + activeUIs((appId, attemptId)) = loadedUI } + + Some(loadedUI) } override def getEmptyListingHtml(): Seq[Node] = { @@ -332,11 +393,40 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) initThread.interrupt() initThread.join() } + Seq(pool, replayExecutor).foreach { executor => + executor.shutdown() + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + } } finally { + activeUIs.foreach { case (_, loadedUI) => loadedUI.ui.store.close() } + activeUIs.clear() listing.close() } } + override def onUIDetached(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { + val uiOption = synchronized { + activeUIs.remove((appId, attemptId)) + } + uiOption.foreach { loadedUI => + loadedUI.lock.writeLock().lock() + try { + loadedUI.ui.store.close() + } finally { + loadedUI.lock.writeLock().unlock() + } + + // If the UI is not valid, delete its files from disk, if any. This relies on the fact that + // ApplicationCache will never call this method concurrently with getAppUI() for the same + // appId / attemptId. + if (!loadedUI.valid && storePath.isDefined) { + Utils.deleteRecursively(getStorePath(storePath.get, appId, attemptId)) + } + } + } + /** * Builds the application list based on the current contents of the log directory. 
* Tries to reuse as much of the data already in memory as possible, by not reading @@ -475,7 +565,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || - eventString.startsWith(LOG_START_EVENT_PREFIX) + eventString.startsWith(LOG_START_EVENT_PREFIX) || + eventString.startsWith(ENV_UPDATE_EVENT_PREFIX) } val logPath = fileStatus.getPath() @@ -486,8 +577,19 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.addListener(listener) replay(fileStatus, isApplicationCompleted(fileStatus), bus, eventsFilter) - listener.applicationInfo.foreach(addListing) - listing.write(LogInfo(logPath.toString(), fileStatus.getLen())) + listener.applicationInfo.foreach { app => + // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a + // discussion on the UI lifecycle. + synchronized { + activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } + } + + addListing(app) + } + listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) } /** @@ -546,16 +648,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. Returns - * an `ApplicationEventListener` instance with event data captured from the replay. - * `ReplayEventsFilter` determines what events are replayed and can therefore limit the - * data captured in the returned `ApplicationEventListener` instance. + * Replays the events in the specified log file on the supplied `ReplayListenerBus`. + * `ReplayEventsFilter` determines what events are replayed. */ private def replay( eventLog: FileStatus, appCompleted: Boolean, bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): ApplicationEventListener = { + eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, @@ -566,11 +666,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // after it's created, so we get a file size that is no bigger than what is actually read. val logInput = EventLoggingListener.openEventLog(logPath, fs) try { - val appListener = new ApplicationEventListener - bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted, eventsFilter) logInfo(s"Finished replaying $logPath") - appListener } finally { logInput.close() } @@ -613,32 +710,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) | application count=$count}""".stripMargin } - /** - * Return true iff a newer version of the UI is available. The check is based on whether the - * fileSize for the currently loaded UI is smaller than the file size the last time - * the logs were loaded. - * - * This is a very cheap operation -- the work of loading the new attempt was already done - * by [[checkForLogs]]. 
- * @param appId application to probe - * @param attemptId attempt to probe - * @param prevFileSize the file size of the logs for the currently displayed UI - */ - private def updateProbe( - appId: String, - attemptId: Option[String], - prevFileSize: Long)(): Boolean = { - try { - val attempt = getAttempt(appId, attemptId) - val logPath = fs.makeQualified(new Path(logDir, attempt.logPath)) - recordedFileSize(logPath) > prevFileSize - } catch { - case _: NoSuchElementException => - logDebug(s"Application Attempt $appId/$attemptId not found") - false - } - } - /** * Return the last known size of the given event log, recorded the last time the file * system scanner detected a change in the file. @@ -682,6 +753,16 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) listing.write(newAppInfo) } + private def createDiskStore(path: File, conf: SparkConf): KVStore = { + val metadata = new AppStatusStoreMetadata(AppStatusStore.CURRENT_VERSION) + KVUtils.open(path, metadata) + } + + private def getStorePath(path: File, appId: String, attemptId: Option[String]): File = { + val fileName = appId + attemptId.map("_" + _).getOrElse("") + ".ldb" + new File(path, fileName) + } + /** For testing. Returns internal data about a single attempt. */ private[history] def getAttempt(appId: String, attemptId: Option[String]): AttemptInfoWrapper = { load(appId).attempts.find(_.info.attemptId == attemptId).getOrElse( @@ -699,6 +780,8 @@ private[history] object FsHistoryProvider { private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" + private val ENV_UPDATE_EVENT_PREFIX = "{\"Event\":\"SparkListenerEnvironmentUpdate\"," + /** * Current version of the data written to the listing database. When opening an existing * db, if the version does not match this value, the FsHistoryProvider will throw away @@ -708,17 +791,22 @@ private[history] object FsHistoryProvider { } private[history] case class FsHistoryProviderMetadata( - version: Long, - logDir: String) + version: Long, + uiVersion: Long, + logDir: String) private[history] case class LogInfo( - @KVIndexParam logPath: String, - fileSize: Long) + @KVIndexParam logPath: String, + fileSize: Long) private[history] class AttemptInfoWrapper( val info: v1.ApplicationAttemptInfo, val logPath: String, - val fileSize: Long) { + val fileSize: Long, + val adminAcls: Option[String], + val viewAcls: Option[String], + val adminAclsGroups: Option[String], + val viewAclsGroups: Option[String]) { def toAppAttemptInfo(): ApplicationAttemptInfo = { ApplicationAttemptInfo(info.attemptId, info.startTime.getTime(), @@ -769,6 +857,14 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends attempt.completed = true } + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + } + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { case SparkListenerLogStart(sparkVersion) => attempt.appSparkVersion = sparkVersion @@ -809,6 +905,11 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends var completed = false var appSparkVersion = "" + var adminAcls: Option[String] = None + var viewAcls: Option[String] = None + var 
adminAclsGroups: Option[String] = None + var viewAclsGroups: Option[String] = None + def toView(): AttemptInfoWrapper = { val apiInfo = new v1.ApplicationAttemptInfo( attemptId, @@ -822,7 +923,11 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends new AttemptInfoWrapper( apiInfo, logPath, - fileSize) + fileSize, + adminAcls, + viewAcls, + adminAclsGroups, + viewAclsGroups) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d9c8fda99ef97..b822a48e98e91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -106,8 +106,8 @@ class HistoryServer( } } - def getSparkUI(appKey: String): Option[SparkUI] = { - appCache.getSparkUI(appKey) + override def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + appCache.withSparkUI(appId, attemptId)(fn) } initialize() @@ -140,7 +140,6 @@ class HistoryServer( override def stop() { super.stop() provider.stop() - appCache.stop() } /** Attach a reconstructed UI to this server. Only valid after bind(). */ @@ -158,6 +157,7 @@ class HistoryServer( override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") ui.getHandlers.foreach(detachHandler) + provider.onUIDetached(appId, attemptId, ui) } /** @@ -224,15 +224,13 @@ class HistoryServer( */ private def loadAppUi(appId: String, attemptId: Option[String]): Boolean = { try { - appCache.get(appId, attemptId) + appCache.withSparkUI(appId, attemptId) { _ => + // Do nothing, just force the UI to load. + } true } catch { - case NonFatal(e) => e.getCause() match { - case nsee: NoSuchElementException => - false - - case cause: Exception => throw cause - } + case NonFatal(e: NoSuchElementException) => + false } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala deleted file mode 100644 index 6da8865cd10d3..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -/** - * A simple listener for application events. - * - * This listener expects to hear events from a single application only. If events - * from multiple applications are seen, the behavior is unspecified. 
- */ -private[spark] class ApplicationEventListener extends SparkListener { - var appName: Option[String] = None - var appId: Option[String] = None - var appAttemptId: Option[String] = None - var sparkUser: Option[String] = None - var startTime: Option[Long] = None - var endTime: Option[Long] = None - var viewAcls: Option[String] = None - var adminAcls: Option[String] = None - var viewAclsGroups: Option[String] = None - var adminAclsGroups: Option[String] = None - var appSparkVersion: Option[String] = None - - override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { - appName = Some(applicationStart.appName) - appId = applicationStart.appId - appAttemptId = applicationStart.appAttemptId - startTime = Some(applicationStart.time) - sparkUser = Some(applicationStart.sparkUser) - } - - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { - endTime = Some(applicationEnd.time) - } - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - val environmentDetails = environmentUpdate.environmentDetails - val allProperties = environmentDetails("Spark Properties").toMap - viewAcls = allProperties.get("spark.ui.view.acls") - adminAcls = allProperties.get("spark.admin.acls") - viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - adminAclsGroups = allProperties.get("spark.admin.acls.groups") - } - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerLogStart(sparkVersion) => - appSparkVersion = Some(sparkVersion) - case _ => - } -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index f120685c941df..cd43612fae357 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -34,12 +34,22 @@ import org.apache.spark.util.kvstore.KVStore * A Spark listener that writes application information to a data store. The types written to the * store are defined in the `storeTypes.scala` file and are based on the public REST API. */ -private class AppStatusListener(kvstore: KVStore) extends SparkListener with Logging { +private[spark] class AppStatusListener( + kvstore: KVStore, + conf: SparkConf, + live: Boolean) extends SparkListener with Logging { + + import config._ private var sparkVersion = SPARK_VERSION private var appInfo: v1.ApplicationInfo = null private var coresPerTask: Int = 1 + // How often to update live entities. -1 means "never update" when replaying applications, + // meaning only the last write will happen. For live applications, this avoids a few + // operations that we can live without when rapidly processing incoming task events. + private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + // Keep track of live entities, so that task metrics can be efficiently updated (without // causing too many writes to the underlying store, and other expensive operations). 
private val liveStages = new HashMap[(Int, Int), LiveStage]() @@ -110,13 +120,13 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log exec.totalCores = event.executorInfo.totalCores exec.maxTasks = event.executorInfo.totalCores / coresPerTask exec.executorLogs = event.executorInfo.logUrlMap - update(exec) + liveUpdate(exec, System.nanoTime()) } override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { liveExecutors.remove(event.executorId).foreach { exec => exec.isActive = false - update(exec) + update(exec, System.nanoTime()) } } @@ -139,21 +149,25 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log private def updateBlackListStatus(execId: String, blacklisted: Boolean): Unit = { liveExecutors.get(execId).foreach { exec => exec.isBlacklisted = blacklisted - update(exec) + liveUpdate(exec, System.nanoTime()) } } private def updateNodeBlackList(host: String, blacklisted: Boolean): Unit = { + val now = System.nanoTime() + // Implicitly (un)blacklist every executor associated with the node. liveExecutors.values.foreach { exec => if (exec.hostname == host) { exec.isBlacklisted = blacklisted - update(exec) + liveUpdate(exec, now) } } } override def onJobStart(event: SparkListenerJobStart): Unit = { + val now = System.nanoTime() + // Compute (a potential over-estimate of) the number of tasks that will be run by this job. // This may be an over-estimate because the job start event references all of the result // stages' transitive stage dependencies, but some of these stages might be skipped if their @@ -178,7 +192,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log jobGroup, numTasks) liveJobs.put(event.jobId, job) - update(job) + liveUpdate(job, now) event.stageInfos.foreach { stageInfo => // A new job submission may re-use an existing stage, so this code needs to do an update @@ -186,7 +200,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log val stage = getOrCreateStage(stageInfo) stage.jobs :+= job stage.jobIds += event.jobId - update(stage) + liveUpdate(stage, now) } } @@ -198,11 +212,12 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } job.completionTime = Some(new Date(event.time)) - update(job) + update(job, System.nanoTime()) } } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + val now = System.nanoTime() val stage = getOrCreateStage(event.stageInfo) stage.status = v1.StageStatus.ACTIVE stage.schedulingPool = Option(event.properties).flatMap { p => @@ -218,38 +233,39 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log stage.jobs.foreach { job => job.completedStages = job.completedStages - event.stageInfo.stageId job.activeStages += 1 - update(job) + liveUpdate(job, now) } event.stageInfo.rddInfos.foreach { info => if (info.storageLevel.isValid) { - update(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info))) + liveUpdate(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info)), now) } } - update(stage) + liveUpdate(stage, now) } override def onTaskStart(event: SparkListenerTaskStart): Unit = { + val now = System.nanoTime() val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) liveTasks.put(event.taskInfo.taskId, task) - update(task) + liveUpdate(task, now) liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, 
event.taskInfo.launchTime) - update(stage) + maybeUpdate(stage, now) stage.jobs.foreach { job => job.activeTasks += 1 - update(job) + maybeUpdate(job, now) } } liveExecutors.get(event.taskInfo.executorId).foreach { exec => exec.activeTasks += 1 exec.totalTasks += 1 - update(exec) + maybeUpdate(exec, now) } } @@ -257,7 +273,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log // Call update on the task so that the "getting result" time is written to the store; the // value is part of the mutable TaskInfo state that the live entity already references. liveTasks.get(event.taskInfo.taskId).foreach { task => - update(task) + maybeUpdate(task, System.nanoTime()) } } @@ -267,6 +283,8 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log return } + val now = System.nanoTime() + val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => val errorMessage = event.reason match { case Success => @@ -283,7 +301,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } task.errorMessage = errorMessage val delta = task.updateMetrics(event.taskMetrics) - update(task) + update(task, now) delta }.orNull @@ -301,13 +319,13 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log stage.activeTasks -= 1 stage.completedTasks += completedDelta stage.failedTasks += failedDelta - update(stage) + maybeUpdate(stage, now) stage.jobs.foreach { job => job.activeTasks -= 1 job.completedTasks += completedDelta job.failedTasks += failedDelta - update(job) + maybeUpdate(job, now) } val esummary = stage.executorSummary(event.taskInfo.executorId) @@ -317,7 +335,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log if (metricsDelta != null) { esummary.metrics.update(metricsDelta) } - update(esummary) + maybeUpdate(esummary, now) } liveExecutors.get(event.taskInfo.executorId).foreach { exec => @@ -333,12 +351,13 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log exec.completedTasks += completedDelta exec.failedTasks += failedDelta exec.totalDuration += event.taskInfo.duration - update(exec) + maybeUpdate(exec, now) } } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + val now = System.nanoTime() stage.info = event.stageInfo // Because of SPARK-20205, old event logs may contain valid stages without a submission time @@ -349,7 +368,6 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log case _ if event.stageInfo.submissionTime.isDefined => v1.StageStatus.COMPLETE case _ => v1.StageStatus.SKIPPED } - update(stage) stage.jobs.foreach { job => stage.status match { @@ -362,11 +380,11 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log job.failedStages += 1 } job.activeStages -= 1 - update(job) + liveUpdate(job, now) } - stage.executorSummaries.values.foreach(update) - update(stage) + stage.executorSummaries.values.foreach(update(_, now)) + update(stage, now) } } @@ -381,7 +399,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } exec.isActive = true exec.maxMemory = event.maxMem - update(exec) + liveUpdate(exec, System.nanoTime()) } override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { @@ -394,19 +412,21 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } 
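// A minimal, self-contained sketch (separate from the patch) of the throttled-write pattern the
// handlers above rely on: update() always writes, maybeUpdate() writes only when the entity has
// not been persisted within the configured live-update period, and liveUpdate() writes only for
// live applications (replayed event logs rely on a final flush instead, see flush() below).
// The names Entity, persist() and ThrottledWriter are illustrative assumptions, not patch code.
abstract class Entity {
  var lastWriteTime: Long = 0L
  def persist(): Unit
}

class ThrottledWriter(liveUpdatePeriodNs: Long, liveApp: Boolean) {
  // Unconditional write, used at important transitions (job/stage completion, task end).
  def update(e: Entity, now: Long): Unit = {
    e.persist()
    e.lastWriteTime = now
  }

  // Skip the write if the entity was persisted within the configured period; a negative
  // period (used when replaying event logs) means "never write here, only at flush time".
  def maybeUpdate(e: Entity, now: Long): Unit = {
    if (liveUpdatePeriodNs >= 0 && now - e.lastWriteTime > liveUpdatePeriodNs) {
      update(e, now)
    }
  }

  // Only live applications need intermediate writes; replays persist everything at the end.
  def liveUpdate(e: Entity, now: Long): Unit = {
    if (liveApp) {
      update(e, now)
    }
  }
}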
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + val now = System.nanoTime() + event.accumUpdates.foreach { case (taskId, sid, sAttempt, accumUpdates) => liveTasks.get(taskId).foreach { task => val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) val delta = task.updateMetrics(metrics) - update(task) + maybeUpdate(task, now) liveStages.get((sid, sAttempt)).foreach { stage => stage.metrics.update(delta) - update(stage) + maybeUpdate(stage, now) val esummary = stage.executorSummary(event.execId) esummary.metrics.update(delta) - update(esummary) + maybeUpdate(esummary, now) } } } @@ -419,7 +439,18 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } } + /** Flush all live entities' data to the underlying store. */ + def flush(): Unit = { + val now = System.nanoTime() + liveStages.values.foreach(update(_, now)) + liveJobs.values.foreach(update(_, now)) + liveExecutors.values.foreach(update(_, now)) + liveTasks.values.foreach(update(_, now)) + liveRDDs.values.foreach(update(_, now)) + } + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { + val now = System.nanoTime() val executorId = event.blockUpdatedInfo.blockManagerId.executorId // Whether values are being added to or removed from the existing accounting. @@ -494,7 +525,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log } rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) - update(rdd) + update(rdd, now) } maybeExec.foreach { exec => @@ -508,9 +539,7 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) exec.diskUsed = newValue(exec.diskUsed, diskDelta) exec.rddBlocks += rddBlocksDelta - if (exec.hasMemoryInfo || rddBlocksDelta != 0) { - update(exec) - } + maybeUpdate(exec, now) } } @@ -524,8 +553,22 @@ private class AppStatusListener(kvstore: KVStore) extends SparkListener with Log stage } - private def update(entity: LiveEntity): Unit = { - entity.write(kvstore) + private def update(entity: LiveEntity, now: Long): Unit = { + entity.write(kvstore, now) + } + + /** Update a live entity only if it hasn't been updated in the last configured period. */ + private def maybeUpdate(entity: LiveEntity, now: Long): Unit = { + if (liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { + update(entity, now) + } + } + + /** Update an entity only if in a live app; avoids redundant writes when replaying logs. */ + private def liveUpdate(entity: LiveEntity, now: Long): Unit = { + if (live) { + update(entity, now) + } } } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala new file mode 100644 index 0000000000000..2927a3227cbef --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File +import java.util.{Arrays, List => JList} + +import scala.collection.JavaConverters._ + +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.status.api.v1 +import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} + +/** + * A wrapper around a KVStore that provides methods for accessing the API data stored within. + */ +private[spark] class AppStatusStore(store: KVStore) { + + def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { + val it = store.view(classOf[JobDataWrapper]).asScala.map(_.info) + if (!statuses.isEmpty()) { + it.filter { job => statuses.contains(job.status) }.toSeq + } else { + it.toSeq + } + } + + def job(jobId: Int): v1.JobData = { + store.read(classOf[JobDataWrapper], jobId).info + } + + def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { + store.view(classOf[ExecutorSummaryWrapper]).index("active").reverse().first(true) + .last(true).asScala.map(_.info).toSeq + } + + def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { + val it = store.view(classOf[StageDataWrapper]).asScala.map(_.info) + if (!statuses.isEmpty) { + it.filter { s => statuses.contains(s.status) }.toSeq + } else { + it.toSeq + } + } + + def stageData(stageId: Int): Seq[v1.StageData] = { + store.view(classOf[StageDataWrapper]).index("stageId").first(stageId).last(stageId) + .asScala.map(_.info).toSeq + } + + def stageAttempt(stageId: Int, stageAttemptId: Int): v1.StageData = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).info + } + + def taskSummary( + stageId: Int, + stageAttemptId: Int, + quantiles: Array[Double]): v1.TaskMetricDistributions = { + + val stage = Array(stageId, stageAttemptId) + + val rawMetrics = store.view(classOf[TaskDataWrapper]) + .index("stage") + .first(stage) + .last(stage) + .asScala + .flatMap(_.info.taskMetrics) + .toList + .view + + def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = + Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) + + // We need to do a lot of similar munging to nested metrics here. For each one, + // we want (a) extract the values for nested metrics (b) make a distribution for each metric + // (c) shove the distribution into the right field in our return type and (d) only return + // a result if the option is defined for any of the tasks. MetricHelper is a little util + // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just + // implement one "build" method, which just builds the quantiles for each field. 
+ + val inputMetrics = + new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics + + def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( + bytesRead = submetricQuantiles(_.bytesRead), + recordsRead = submetricQuantiles(_.recordsRead) + ) + }.build + + val outputMetrics = + new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics + + def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( + bytesWritten = submetricQuantiles(_.bytesWritten), + recordsWritten = submetricQuantiles(_.recordsWritten) + ) + }.build + + val shuffleReadMetrics = + new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = + raw.shuffleReadMetrics + + def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( + readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, + readRecords = submetricQuantiles(_.recordsRead), + remoteBytesRead = submetricQuantiles(_.remoteBytesRead), + remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), + remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), + localBlocksFetched = submetricQuantiles(_.localBlocksFetched), + totalBlocksFetched = submetricQuantiles { s => + s.localBlocksFetched + s.remoteBlocksFetched + }, + fetchWaitTime = submetricQuantiles(_.fetchWaitTime) + ) + }.build + + val shuffleWriteMetrics = + new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = + raw.shuffleWriteMetrics + + def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( + writeBytes = submetricQuantiles(_.bytesWritten), + writeRecords = submetricQuantiles(_.recordsWritten), + writeTime = submetricQuantiles(_.writeTime) + ) + }.build + + new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), + executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), + executorRunTime = metricQuantiles(_.executorRunTime), + executorCpuTime = metricQuantiles(_.executorCpuTime), + resultSize = metricQuantiles(_.resultSize), + jvmGcTime = metricQuantiles(_.jvmGcTime), + resultSerializationTime = metricQuantiles(_.resultSerializationTime), + memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), + diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), + inputMetrics = inputMetrics, + outputMetrics = outputMetrics, + shuffleReadMetrics = shuffleReadMetrics, + shuffleWriteMetrics = shuffleWriteMetrics + ) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val stageKey = Array(stageId, stageAttemptId) + val base = store.view(classOf[TaskDataWrapper]) + val indexed = sortBy match { + case v1.TaskSorting.ID => + base.index("stage").first(stageKey).last(stageKey) + case v1.TaskSorting.INCREASING_RUNTIME => + base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) + case v1.TaskSorting.DECREASING_RUNTIME => + base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) + .reverse() + } + 
indexed.skip(offset).max(length).asScala.map(_.info).toSeq + } + + def rddList(): Seq[v1.RDDStorageInfo] = { + store.view(classOf[RDDStorageInfoWrapper]).asScala.map(_.info).toSeq + } + + def rdd(rddId: Int): v1.RDDStorageInfo = { + store.read(classOf[RDDStorageInfoWrapper], rddId).info + } + + def close(): Unit = { + store.close() + } + +} + +private[spark] object AppStatusStore { + + val CURRENT_VERSION = 1L + + /** + * Create an in-memory store for a live application. + * + * @param conf Configuration. + * @param bus Where to attach the listener to populate the store. + */ + def createLiveStore(conf: SparkConf, bus: LiveListenerBus): AppStatusStore = { + val store = new InMemoryStore() + val stateStore = new AppStatusStore(store) + bus.addToStatusQueue(new AppStatusListener(store, conf, true)) + stateStore + } + +} + +/** + * Helper for getting distributions from nested metric types. + */ +private abstract class MetricHelper[I, O]( + rawMetrics: Seq[v1.TaskMetrics], + quantiles: Array[Double]) { + + def getSubmetrics(raw: v1.TaskMetrics): I + + def build: O + + val data: Seq[I] = rawMetrics.map(getSubmetrics) + + /** applies the given function to all input metrics, and returns the quantiles */ + def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { + Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) + } +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 63fa36580bc7d..041dfe1ef915e 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -37,8 +37,11 @@ import org.apache.spark.util.kvstore.KVStore */ private[spark] abstract class LiveEntity { - def write(store: KVStore): Unit = { + var lastWriteTime = 0L + + def write(store: KVStore, now: Long): Unit = { store.write(doUpdate()) + lastWriteTime = now } /** @@ -204,7 +207,7 @@ private class LiveTask( newAccumulatorInfos(info.accumulables), errorMessage, Option(recordedMetrics)) - new TaskDataWrapper(task) + new TaskDataWrapper(task, stageId, stageAttemptId) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index f17b637754826..9d3833086172f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -248,7 +248,13 @@ private[spark] object ApiRootResource { * interface needed for them all to expose application info as json. */ private[spark] trait UIRoot { - def getSparkUI(appKey: String): Option[SparkUI] + /** + * Runs some code with the current SparkUI instance for the app / attempt. + * + * @throws NoSuchElementException If the app / attempt pair does not exist. + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T + def getApplicationInfoList: Iterator[ApplicationInfo] def getApplicationInfo(appId: String): Option[ApplicationInfo] @@ -293,15 +299,18 @@ private[v1] trait ApiRequestContext { * to it. 
If there is no such app, throw an appropriate exception */ def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - uiRoot.getSparkUI(appKey) match { - case Some(ui) => + try { + uiRoot.withSparkUI(appId, attemptId) { ui => val user = httpRequest.getRemoteUser() if (!ui.securityManager.checkUIViewPermissions(user)) { throw new ForbiddenException(raw"""user "$user" is not authorized""") } f(ui) - case None => throw new NotFoundException("no such app: " + appId) + } + } catch { + case _: NoSuchElementException => + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + throw new NotFoundException(s"no such app: $appKey") } } } diff --git a/core/src/main/scala/org/apache/spark/status/config.scala b/core/src/main/scala/org/apache/spark/status/config.scala new file mode 100644 index 0000000000000..49144fc883e69 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/config.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.concurrent.TimeUnit + +import org.apache.spark.internal.config._ + +private[spark] object config { + + val LIVE_ENTITY_UPDATE_PERIOD = ConfigBuilder("spark.ui.liveUpdate.period") + .timeConf(TimeUnit.NANOSECONDS) + .createWithDefaultString("100ms") + +} diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 9579accd2cba7..a445435809f3a 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -17,12 +17,16 @@ package org.apache.spark.status +import java.lang.{Integer => JInteger, Long => JLong} + import com.fasterxml.jackson.annotation.JsonIgnore import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ import org.apache.spark.util.kvstore.KVIndex +private[spark] case class AppStatusStoreMetadata(version: Long) + private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { @JsonIgnore @KVIndex @@ -64,13 +68,33 @@ private[spark] class StageDataWrapper( @JsonIgnore @KVIndex def id: Array[Int] = Array(info.stageId, info.attemptId) + @JsonIgnore @KVIndex("stageId") + def stageId: Int = info.stageId + } -private[spark] class TaskDataWrapper(val info: TaskData) { +/** + * The task information is always indexed with the stage ID, since that is how the UI and API + * consume it. That means every indexed value has the stage ID and attempt ID included, aside + * from the actual data being indexed. 
+ */ +private[spark] class TaskDataWrapper( + val info: TaskData, + val stageId: Int, + val stageAttemptId: Int) { @JsonIgnore @KVIndex def id: Long = info.taskId + @JsonIgnore @KVIndex("stage") + def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore @KVIndex("runtime") + def runtime: Array[AnyRef] = { + val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) + Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + } + } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6e94073238a56..ee645f6bf8a7a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ +import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} import org.apache.spark.storage.StorageStatusListener @@ -39,6 +40,7 @@ import org.apache.spark.util.Utils * Top level user interface for a Spark application. */ private[spark] class SparkUI private ( + val store: AppStatusStore, val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, @@ -51,7 +53,8 @@ private[spark] class SparkUI private ( var appName: String, val basePath: String, val lastUpdateTime: Option[Long] = None, - val startTime: Long) + val startTime: Long, + val appSparkVersion: String) extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging @@ -61,8 +64,6 @@ private[spark] class SparkUI private ( var appId: String = _ - var appSparkVersion = org.apache.spark.SPARK_VERSION - private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. 
*/ @@ -104,8 +105,12 @@ private[spark] class SparkUI private ( logInfo(s"Stopped Spark web UI at $webUrl") } - def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == this.appId) Some(this) else None + override def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + if (appId == this.appId) { + fn(this) + } else { + throw new NoSuchElementException() + } } def getApplicationInfoList: Iterator[ApplicationInfo] = { @@ -159,63 +164,26 @@ private[spark] object SparkUI { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) } - def createLiveUI( - sc: SparkContext, - conf: SparkConf, - jobProgressListener: JobProgressListener, - securityManager: SecurityManager, - appName: String, - startTime: Long): SparkUI = { - create(Some(sc), conf, - sc.listenerBus.addToStatusQueue, - securityManager, appName, jobProgressListener = Some(jobProgressListener), - startTime = startTime) - } - - def createHistoryUI( - conf: SparkConf, - listenerBus: SparkListenerBus, - securityManager: SecurityManager, - appName: String, - basePath: String, - lastUpdateTime: Option[Long], - startTime: Long): SparkUI = { - val sparkUI = create(None, conf, listenerBus.addListener, securityManager, appName, basePath, - lastUpdateTime = lastUpdateTime, startTime = startTime) - - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, sparkUI) - listeners.foreach(listenerBus.addListener) - } - sparkUI - } - /** - * Create a new Spark UI. - * - * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs. - * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the - * web UI will create and register its own JobProgressListener. + * Create a new UI backed by an AppStatusStore. 
*/ - private def create( + def create( sc: Option[SparkContext], + store: AppStatusStore, conf: SparkConf, addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, - basePath: String = "", - jobProgressListener: Option[JobProgressListener] = None, + basePath: String, + startTime: Long, lastUpdateTime: Option[Long] = None, - startTime: Long): SparkUI = { + appSparkVersion: String = org.apache.spark.SPARK_VERSION): SparkUI = { - val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { + val jobProgressListener = sc.map(_.jobProgressListener).getOrElse { val listener = new JobProgressListener(conf) addListenerFn(listener) listener } - val environmentListener = new EnvironmentListener val storageStatusListener = new StorageStatusListener(conf) val executorsListener = new ExecutorsListener(storageStatusListener, conf) @@ -228,8 +196,9 @@ private[spark] object SparkUI { addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, - executorsListener, _jobProgressListener, storageListener, operationGraphListener, - appName, basePath, lastUpdateTime, startTime) + new SparkUI(store, sc, conf, securityManager, environmentListener, storageStatusListener, + executorsListener, jobProgressListener, storageListener, operationGraphListener, + appName, basePath, lastUpdateTime, startTime, appSparkVersion) } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index 6e50e84549047..44f9c566a380d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -18,15 +18,11 @@ package org.apache.spark.deploy.history import java.util.{Date, NoSuchElementException} -import javax.servlet.Filter import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.collection.mutable -import scala.collection.mutable.ListBuffer import com.codahale.metrics.Counter -import com.google.common.cache.LoadingCache -import com.google.common.util.concurrent.UncheckedExecutionException import org.eclipse.jetty.servlet.ServletContextHandler import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -39,23 +35,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.status.api.v1.{ApplicationAttemptInfo => AttemptInfo, ApplicationInfo} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.util.ManualClock class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar with Matchers { - /** - * subclass with access to the cache internals - * @param retainedApplications number of retained applications - */ - class TestApplicationCache( - operations: ApplicationCacheOperations = new StubCacheOperations(), - retainedApplications: Int, - clock: Clock = new ManualClock(0)) - extends ApplicationCache(operations, retainedApplications, clock) { - - def cache(): LoadingCache[CacheKey, CacheEntry] = appCache - } - /** * Stub cache operations. 
* The state is kept in a map of [[CacheKey]] to [[CacheEntry]], @@ -77,8 +60,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { logDebug(s"getAppUI($appId, $attemptId)") getAppUICount += 1 - instances.get(CacheKey(appId, attemptId)).map( e => - LoadedAppUI(e.ui, () => updateProbe(appId, attemptId, e.probeTime))) + instances.get(CacheKey(appId, attemptId)).map { e => e.loadedUI } } override def attachSparkUI( @@ -96,10 +78,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attemptId: Option[String], completed: Boolean, started: Long, - ended: Long, - timestamp: Long): SparkUI = { - val ui = putAppUI(appId, attemptId, completed, started, ended, timestamp) - attachSparkUI(appId, attemptId, ui, completed) + ended: Long): LoadedAppUI = { + val ui = putAppUI(appId, attemptId, completed, started, ended) + attachSparkUI(appId, attemptId, ui.ui, completed) ui } @@ -108,23 +89,12 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attemptId: Option[String], completed: Boolean, started: Long, - ended: Long, - timestamp: Long): SparkUI = { - val ui = newUI(appId, attemptId, completed, started, ended) - putInstance(appId, attemptId, ui, completed, timestamp) + ended: Long): LoadedAppUI = { + val ui = LoadedAppUI(newUI(appId, attemptId, completed, started, ended)) + instances(CacheKey(appId, attemptId)) = new CacheEntry(ui, completed) ui } - def putInstance( - appId: String, - attemptId: Option[String], - ui: SparkUI, - completed: Boolean, - timestamp: Long): Unit = { - instances += (CacheKey(appId, attemptId) -> - new CacheEntry(ui, completed, () => updateProbe(appId, attemptId, timestamp), timestamp)) - } - /** * Detach a reconstructed UI * @@ -146,23 +116,6 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attached.get(CacheKey(appId, attemptId)) } - /** - * The update probe. 
- * @param appId application to probe - * @param attemptId attempt to probe - * @param updateTime timestamp of this UI load - */ - private[history] def updateProbe( - appId: String, - attemptId: Option[String], - updateTime: Long)(): Boolean = { - updateProbeCount += 1 - logDebug(s"isUpdated($appId, $attemptId, ${updateTime})") - val entry = instances.get(CacheKey(appId, attemptId)).get - val updated = entry.probeTime > updateTime - logDebug(s"entry = $entry; updated = $updated") - updated - } } /** @@ -210,15 +163,13 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val now = clock.getTimeMillis() // add the entry - operations.putAppUI(app1, None, true, now, now, now) + operations.putAppUI(app1, None, true, now, now) // make sure its local operations.getAppUI(app1, None).get operations.getAppUICount = 0 // now expect it to be found - val cacheEntry = cache.lookupCacheEntry(app1, None) - assert(1 === cacheEntry.probeTime) - assert(cacheEntry.completed) + cache.withSparkUI(app1, None) { _ => } // assert about queries made of the operations assert(1 === operations.getAppUICount, "getAppUICount") assert(1 === operations.attachCount, "attachCount") @@ -236,8 +187,8 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar assert(0 === operations.detachCount, "attachCount") // evict the entry - operations.putAndAttach("2", None, true, time2, time2, time2) - operations.putAndAttach("3", None, true, time2, time2, time2) + operations.putAndAttach("2", None, true, time2, time2) + operations.putAndAttach("3", None, true, time2, time2) cache.get("2") cache.get("3") @@ -248,7 +199,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val appId = "app1" val attemptId = Some("_01") val time3 = clock.getTimeMillis() - operations.putAppUI(appId, attemptId, false, time3, 0, time3) + operations.putAppUI(appId, attemptId, false, time3, 0) // expect an error here assertNotFound(appId, None) } @@ -256,10 +207,11 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Test that if an attempt ID is set, it must be used in lookups") { val operations = new StubCacheOperations() val clock = new ManualClock(1) - implicit val cache = new ApplicationCache(operations, retainedApplications = 10, clock = clock) + implicit val cache = new ApplicationCache(operations, retainedApplications = 10, + clock = clock) val appId = "app1" val attemptId = Some("_01") - operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0, 0) + operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0) assertNotFound(appId, None) } @@ -271,50 +223,29 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Incomplete apps refreshed") { val operations = new StubCacheOperations() val clock = new ManualClock(50) - val window = 500 - implicit val cache = new ApplicationCache(operations, retainedApplications = 5, clock = clock) + implicit val cache = new ApplicationCache(operations, 5, clock) val metrics = cache.metrics // add the incomplete app // add the entry val started = clock.getTimeMillis() val appId = "app1" val attemptId = Some("001") - operations.putAppUI(appId, attemptId, false, started, 0, started) - val firstEntry = cache.lookupCacheEntry(appId, attemptId) - assert(started === firstEntry.probeTime, s"timestamp in $firstEntry") - assert(!firstEntry.completed, s"entry is complete: $firstEntry") - assertMetric("lookupCount", metrics.lookupCount, 1) + val 
initialUI = operations.putAndAttach(appId, attemptId, false, started, 0) + val firstUI = cache.withSparkUI(appId, attemptId) { ui => ui } + assertMetric("lookupCount", metrics.lookupCount, 1) assert(0 === operations.updateProbeCount, "expected no update probe on that first get") - val checkTime = window * 2 - clock.setTime(checkTime) - val entry3 = cache.lookupCacheEntry(appId, attemptId) - assert(firstEntry !== entry3, s"updated entry test from $cache") + // Invalidate the first entry to trigger a re-load. + initialUI.invalidate() + + // Update the UI in the stub so that a new one is provided to the cache. + operations.putAppUI(appId, attemptId, true, started, started + 10) + + val updatedUI = cache.withSparkUI(appId, attemptId) { ui => ui } + assert(firstUI !== updatedUI, s"expected updated UI") assertMetric("lookupCount", metrics.lookupCount, 2) - assertMetric("updateProbeCount", metrics.updateProbeCount, 1) - assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 0) - assert(1 === operations.updateProbeCount, s"refresh count in $cache") - assert(0 === operations.detachCount, s"detach count") - assert(entry3.probeTime === checkTime) - - val updateTime = window * 3 - // update the cached value - val updatedApp = operations.putAppUI(appId, attemptId, true, started, updateTime, updateTime) - val endTime = window * 10 - clock.setTime(endTime) - logDebug(s"Before operation = $cache") - val entry5 = cache.lookupCacheEntry(appId, attemptId) - assertMetric("lookupCount", metrics.lookupCount, 3) - assertMetric("updateProbeCount", metrics.updateProbeCount, 2) - // the update was triggered - assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 1) - assert(updatedApp === entry5.ui, s"UI {$updatedApp} did not match entry {$entry5} in $cache") - - // at which point, the refreshes stop - clock.setTime(window * 20) - assertCacheEntryEquals(appId, attemptId, entry5) - assertMetric("updateProbeCount", metrics.updateProbeCount, 2) + assert(1 === operations.detachCount, s"detach count") } /** @@ -337,27 +268,6 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar } } - /** - * Look up the cache entry and assert that it matches in the expected value. - * This assertion works if the two CacheEntries are different -it looks at the fields. - * UI are compared on object equality; the timestamp and completed flags directly. - * @param appId application ID - * @param attemptId attempt ID - * @param expected expected value - * @param cache app cache - */ - def assertCacheEntryEquals( - appId: String, - attemptId: Option[String], - expected: CacheEntry) - (implicit cache: ApplicationCache): Unit = { - val actual = cache.lookupCacheEntry(appId, attemptId) - val errorText = s"Expected get($appId, $attemptId) -> $expected, but got $actual from $cache" - assert(expected.ui === actual.ui, errorText + " SparkUI reference") - assert(expected.completed === actual.completed, errorText + " -completed flag") - assert(expected.probeTime === actual.probeTime, errorText + " -timestamp") - } - /** * Assert that a key wasn't found in cache or loaded. 
* @@ -370,14 +280,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar appId: String, attemptId: Option[String]) (implicit cache: ApplicationCache): Unit = { - val ex = intercept[UncheckedExecutionException] { + val ex = intercept[NoSuchElementException] { cache.get(appId, attemptId) } - var cause = ex.getCause - assert(cause !== null) - if (!cause.isInstanceOf[NoSuchElementException]) { - throw cause - } } test("Large Scale Application Eviction") { @@ -385,12 +290,12 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val clock = new ManualClock(0) val size = 5 // only two entries are retained, so we expect evictions to occur on lookups - implicit val cache: ApplicationCache = new TestApplicationCache(operations, - retainedApplications = size, clock = clock) + implicit val cache = new ApplicationCache(operations, retainedApplications = size, + clock = clock) val attempt1 = Some("01") - val ids = new ListBuffer[String]() + val ids = new mutable.ListBuffer[String]() // build a list of applications val count = 100 for (i <- 1 to count ) { @@ -398,7 +303,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar ids += appId clock.advance(10) val t = clock.getTimeMillis() - operations.putAppUI(appId, attempt1, true, t, t, t) + operations.putAppUI(appId, attempt1, true, t, t) } // now go through them in sequence reading them, expect evictions ids.foreach { id => @@ -413,20 +318,19 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Attempts are Evicted") { val operations = new StubCacheOperations() - implicit val cache: ApplicationCache = new TestApplicationCache(operations, - retainedApplications = 4) + implicit val cache = new ApplicationCache(operations, 4, new ManualClock()) val metrics = cache.metrics val appId = "app1" val attempt1 = Some("01") val attempt2 = Some("02") val attempt3 = Some("03") - operations.putAppUI(appId, attempt1, true, 100, 110, 110) - operations.putAppUI(appId, attempt2, true, 200, 210, 210) - operations.putAppUI(appId, attempt3, true, 300, 310, 310) + operations.putAppUI(appId, attempt1, true, 100, 110) + operations.putAppUI(appId, attempt2, true, 200, 210) + operations.putAppUI(appId, attempt3, true, 300, 310) val attempt4 = Some("04") - operations.putAppUI(appId, attempt4, true, 400, 410, 410) + operations.putAppUI(appId, attempt4, true, 400, 410) val attempt5 = Some("05") - operations.putAppUI(appId, attempt5, true, 500, 510, 510) + operations.putAppUI(appId, attempt5, true, 500, 510) def expectLoadAndEvictionCounts(expectedLoad: Int, expectedEvictionCount: Int): Unit = { assertMetric("loadCount", metrics.loadCount, expectedLoad) @@ -457,20 +361,14 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar } - test("Instantiate Filter") { - // this is a regression test on the filter being constructable - val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) - val instance = clazz.newInstance() - instance shouldBe a [Filter] - } - test("redirect includes query params") { - val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) - val filter = clazz.newInstance().asInstanceOf[ApplicationCacheCheckFilter] - filter.appId = "local-123" + val operations = new StubCacheOperations() + val ui = operations.putAndAttach("foo", None, true, 0, 10) val cache = mock[ApplicationCache] - when(cache.checkForUpdates(any(), any())).thenReturn(true) - 
ApplicationCacheCheckFilterRelay.setApplicationCache(cache) + when(cache.operations).thenReturn(operations) + val filter = new ApplicationCacheCheckFilter(new CacheKey("foo", None), ui, cache) + ui.invalidate() + val request = mock[HttpServletRequest] when(request.getMethod()).thenReturn("GET") when(request.getRequestURI()).thenReturn("http://localhost:18080/history/local-123/jobs/job/") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 03bd3eaf579f3..86c8cdf43258c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.security.GroupMappingServiceProvider +import org.apache.spark.status.AppStatusStore import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -612,7 +613,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Manually overwrite the version in the listing db; this should cause the new provider to // discard all data because the versions don't match. val meta = new FsHistoryProviderMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, - conf.get(LOCAL_STORE_DIR).get) + AppStatusStore.CURRENT_VERSION, conf.get(LOCAL_STORE_DIR).get) oldProvider.listing.setMetadata(meta) oldProvider.stop() @@ -620,6 +621,43 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(mistatchedVersionProvider.listing.count(classOf[ApplicationInfoWrapper]) === 0) } + test("invalidate cached UI") { + val provider = new FsHistoryProvider(createTestConf()) + val appId = "new1" + + // Write an incomplete app log. + val appLog = newLogFile(appId, None, inProgress = true) + writeFile(appLog, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None) + ) + provider.checkForLogs() + + // Load the app UI. + val oldUI = provider.getAppUI(appId, None) + assert(oldUI.isDefined) + intercept[NoSuchElementException] { + oldUI.get.ui.store.job(0) + } + + // Add more info to the app log, and trigger the provider to update things. + writeFile(appLog, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + provider.checkForLogs() + + // Manually detach the old UI; ApplicationCache would do this automatically in a real SHS + // when the app's UI was requested. + provider.onUIDetached(appId, None, oldUI.get.ui) + + // Load the UI again and make sure we can get the new info added to the logs. + val freshUI = provider.getAppUI(appId, None) + assert(freshUI.isDefined) + assert(freshUI != oldUI) + freshUI.get.ui.store.job(0) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. 
Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index c11543a4b3ba2..010a8dd004d4f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -72,6 +72,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private var port: Int = -1 def init(extraConf: (String, String)*): Unit = { + Utils.deleteRecursively(storeDir) + assert(storeDir.mkdir()) val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") @@ -292,21 +294,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) - server.stop() - - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", logDir) - .set("spark.history.fs.update.interval", "0") - .set("spark.testing", "true") - .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - - provider = new FsHistoryProvider(conf) - provider.checkForLogs() - val securityManager = HistoryServer.createSecurityManager(conf) - - server = new HistoryServer(conf, provider, securityManager, 18080) - server.initialize() - server.bind() + stop() + init() val port = server.boundPort @@ -375,8 +364,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } test("incomplete apps get refreshed") { - server.stop() - implicit val webDriver: WebDriver = new HtmlUnitDriver implicit val formats = org.json4s.DefaultFormats @@ -386,6 +373,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers // a new conf is used with the background thread set and running at its fastest // allowed refresh rate (1Hz) + stop() val myConf = new SparkConf() .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) .set("spark.eventLog.dir", logDir.getAbsolutePath) @@ -418,7 +406,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } - server = new HistoryServer(myConf, provider, securityManager, 18080) + server = new HistoryServer(myConf, provider, securityManager, 0) server.initialize() server.bind() val port = server.boundPort @@ -464,7 +452,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers rootAppPage should not be empty def getAppUI: SparkUI = { - provider.getAppUI(appId, None).get.ui + server.withSparkUI(appId, None) { ui => ui } } // selenium isn't that useful on failures...add our own reporting @@ -519,7 +507,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getNumJobs("") should be (1) getNumJobs("/jobs") should be (1) getNumJobsRestful() should be (1) - assert(metrics.lookupCount.getCount > 1, s"lookup count too low in $metrics") + assert(metrics.lookupCount.getCount > 0, s"lookup count too low in $metrics") // dump state before the next bit of test, which is where update // checking really gets stressed diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 6f7a0c14dd684..7ac1ce19f8ddf 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -18,7 +18,8 @@ package 
org.apache.spark.status import java.io.File -import java.util.{Date, Properties} +import java.lang.{Integer => JInteger, Long => JLong} +import java.util.{Arrays, Date, Properties} import scala.collection.JavaConverters._ import scala.reflect.{classTag, ClassTag} @@ -36,6 +37,10 @@ import org.apache.spark.util.kvstore._ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { + import config._ + + private val conf = new SparkConf().set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + private var time: Long = _ private var testDir: File = _ private var store: KVStore = _ @@ -52,7 +57,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } test("scheduler events") { - val listener = new AppStatusListener(store) + val listener = new AppStatusListener(store, conf, true) // Start the application. time += 1 @@ -174,6 +179,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s1Tasks.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.info.taskId === task.taskId) + assert(wrapper.stageId === stages.head.stageId) + assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) + + val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, + -1L: JLong) + assert(Arrays.equals(wrapper.runtime, runtime)) + assert(wrapper.info.index === task.index) assert(wrapper.info.attempt === task.attemptNumber) assert(wrapper.info.launchTime === new Date(task.launchTime)) @@ -510,7 +523,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } test("storage events") { - val listener = new AppStatusListener(store) + val listener = new AppStatusListener(store, conf, true) val maxMemory = 42L // Register a couple of block managers. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 45b8870f3b62f..99cac34c85ebc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,8 @@ object MimaExcludes { lazy val v23excludes = v22excludes ++ Seq( // SPARK-18085: Better History Server scalability for many / large applications ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), + // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 65a8bf6036fe41a53b4b1e4298fa35d7fa4e9970 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 6 Nov 2017 08:58:42 -0800 Subject: [PATCH 185/712] [SPARK-22315][SPARKR] Warn if SparkR package version doesn't match SparkContext ## What changes were proposed in this pull request? This PR adds a check between the R package version used and the version reported by SparkContext running in the JVM. The goal here is to warn users when they have a R package downloaded from CRAN and are using that to connect to an existing Spark cluster. This is raised as a warning rather than an error as users might want to use patch versions interchangeably (e.g., 2.1.3 with 2.1.2 etc.) ## How was this patch tested? Manually by changing the `DESCRIPTION` file Author: Shivaram Venkataraman Closes #19624 from shivaram/sparkr-version-check. 
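As a rough illustration of the check this patch adds (a self-contained sketch in base R; the version strings below are made up, while the real code obtains them from `callJMethod(sparkSession, "version")` and `packageVersion("SparkR")`):

```r
# Hypothetical values for illustration only.
jvmVersion <- "2.3.1-SNAPSHOT"
rPackageVersion <- "2.3.0"

# Drop any -SNAPSHOT suffix from the JVM-reported version before comparing.
jvmVersionStrip <- gsub("-SNAPSHOT", "", jvmVersion)

if (jvmVersionStrip != rPackageVersion) {
  warning(paste("Version mismatch between Spark JVM and SparkR package. JVM version was",
                jvmVersion, ", while R package version was", rPackageVersion))
}
```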
--- R/pkg/R/sparkR.R | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 81507ea7186af..fb5f1d21fc723 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -420,6 +420,18 @@ sparkR.session <- function( enableHiveSupport) assign(".sparkRsession", sparkSession, envir = .sparkREnv) } + + # Check if version number of SparkSession matches version number of SparkR package + jvmVersion <- callJMethod(sparkSession, "version") + # Remove -SNAPSHOT from jvm versions + jvmVersionStrip <- gsub("-SNAPSHOT", "", jvmVersion) + rPackageVersion <- paste0(packageVersion("SparkR")) + + if (jvmVersionStrip != rPackageVersion) { + warning(paste("Version mismatch between Spark JVM and SparkR package. JVM version was", + jvmVersion, ", while R package version was", rPackageVersion)) + } + sparkSession } From 5014d6e2568021e6958eddc9cfb4c512ed7424a2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Nov 2017 22:25:11 +0100 Subject: [PATCH 186/712] [SPARK-22078][SQL] clarify exception behaviors for all data source v2 interfaces ## What changes were proposed in this pull request? clarify exception behaviors for all data source v2 interfaces. ## How was this patch tested? document change only Author: Wenchen Fan Closes #19623 from cloud-fan/data-source-exception. --- .../spark/sql/sources/v2/ReadSupport.java | 3 ++ .../sql/sources/v2/ReadSupportWithSchema.java | 3 ++ .../spark/sql/sources/v2/WriteSupport.java | 3 ++ .../sql/sources/v2/reader/DataReader.java | 11 +++++++- .../sources/v2/reader/DataSourceV2Reader.java | 9 ++++++ .../spark/sql/sources/v2/reader/ReadTask.java | 12 +++++++- .../SupportsPushDownCatalystFilters.java | 1 - .../v2/reader/SupportsScanUnsafeRow.java | 1 - .../sources/v2/writer/DataSourceV2Writer.java | 28 ++++++++++++------- .../sql/sources/v2/writer/DataWriter.java | 20 +++++++++---- .../sources/v2/writer/DataWriterFactory.java | 3 ++ .../v2/writer/SupportsWriteInternalRow.java | 3 -- 12 files changed, 74 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index ee489ad0f608f..948e20bacf4a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,6 +30,9 @@ public interface ReadSupport { /** * Creates a {@link DataSourceV2Reader} to scan the data from this data source. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index 74e81a2c84d68..b69c6bed8d1b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -35,6 +35,9 @@ public interface ReadSupportWithSchema { /** * Create a {@link DataSourceV2Reader} to scan the data from this data source. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param schema the full schema of this data source reader. 
Full schema usually maps to the * physical schema of the underlying storage of this data source reader, e.g. * CSV files, JSON files, etc, while this reader may not read data with full diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 8fdfdfd19ea1e..1e3b644d8c4ae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -35,6 +35,9 @@ public interface WriteSupport { * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param jobId A unique string for the writing job. It's possible that there are many writing * jobs running at the same time, and the returned {@link DataSourceV2Writer} can * use this job id to distinguish itself from other jobs. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 52bb138673fc9..8f58c865b6201 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import java.io.Closeable; +import java.io.IOException; import org.apache.spark.annotation.InterfaceStability; @@ -34,11 +35,19 @@ public interface DataReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + * + * @throws IOException if failure happens during disk/network IO like reading files. */ - boolean next(); + boolean next() throws IOException; /** * Return the current record. This method should return same value until `next` is called. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. */ T get(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 88c3219a75c1d..95ee4a8278322 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -40,6 +40,9 @@ * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. * Names of these interfaces start with `SupportsScan`. * + * If an exception was throw when applying any of these query optimizations, the action would fail + * and no Spark job was submitted. + * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark * issues the scan request and does the actual data reading. 
@@ -50,6 +53,9 @@ public interface DataSourceV2Reader { /** * Returns the actual schema of this data source reader, which may be different from the physical * schema of the underlying storage, as column pruning or other optimizations may happen. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ StructType readSchema(); @@ -61,6 +67,9 @@ public interface DataSourceV2Reader { * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before * Spark issues the scan request. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ List> createReadTasks(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 44786db419a32..fa161cdb8b347 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -36,7 +36,14 @@ public interface ReadTask extends Serializable { /** * The preferred locations where this read task can run faster, but Spark does not guarantee that * this task will always run on these locations. The implementations should make sure that it can - * be run on any location. The location is a string representing the host name of an executor. + * be run on any location. The location is a string representing the host name. + * + * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in + * the returned locations. By default this method returns empty string array, which means this + * task has no location preference. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ default String[] preferredLocations() { return new String[0]; @@ -44,6 +51,9 @@ default String[] preferredLocations() { /** * Returns a data reader to do the actual reading work for this read task. + * + * If this method fails (by throwing an exception), the corresponding Spark task would fail and + * get retried until hitting the maximum retry times. 
*/ DataReader createDataReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index efc42242f4421..f76c687f450c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.expressions.Expression; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 6008fb5f71cc1..b90ec880dc85e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -19,7 +19,6 @@ import java.util.List; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 37bb15f87c59a..fc37b9a516f82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -30,6 +30,9 @@ * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * + * If an exception was throw when applying any of these writing optimizations, the action would fail + * and no Spark job was submitted. + * * The writing procedure is: * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the * partitions of the input data(RDD). @@ -50,28 +53,33 @@ public interface DataSourceV2Writer { /** * Creates a writer factory which will be serialized and sent to executors. + * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. */ DataWriterFactory createWriterFactory(); /** * Commits this writing job with a list of commit messages. The commit messages are collected from - * successful data writers and are produced by {@link DataWriter#commit()}. If this method - * fails(throw exception), this writing job is considered to be failed, and - * {@link #abort(WriterCommitMessage[])} will be called. The written data should only be visible - * to data source readers if this method succeeds. + * successful data writers and are produced by {@link DataWriter#commit()}. + * + * If this method fails (by throwing an exception), this writing job is considered to to have been + * failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination + * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * * Note that, one partition may have multiple committed data writers because of speculative tasks. * Spark will pick the first successful one and get its commit message. 
Implementations should be - * aware of this and handle it correctly, e.g., have a mechanism to make sure only one data writer - * can commit successfully, or have a way to clean up the data of already-committed writers. + * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data + * writer can commit, or have a way to clean up the data of already-committed writers. */ void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed to write the records and aborted, - * or the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} - * fails. If this method fails(throw exception), the underlying data source may have garbage that - * need to be cleaned manually, but these garbage should not be visible to data source readers. + * Aborts this writing job because some data writers are failed and keep failing when retry, or + * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * + * If this method fails (by throwing an exception), the underlying data source may require manual + * cleanup. * * Unless the abort is triggered by the failure of commit, the given messages should have some * null slots as there maybe only a few data writers that are committed before the abort diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index dc1aab33bdcef..04b03e63de500 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources.v2.writer; +import java.io.IOException; + import org.apache.spark.annotation.InterfaceStability; /** @@ -59,8 +61,10 @@ public interface DataWriter { * * If this method fails (by throwing an exception), {@link #abort()} will be called and this * data writer is considered to have been failed. + * + * @throws IOException if failure happens during disk/network IO like writing files. */ - void write(T record); + void write(T record) throws IOException; /** * Commits this writer after all records are written successfully, returns a commit message which @@ -74,8 +78,10 @@ public interface DataWriter { * * If this method fails (by throwing an exception), {@link #abort()} will be called and this * data writer is considered to have been failed. + * + * @throws IOException if failure happens during disk/network IO like writing files. */ - WriterCommitMessage commit(); + WriterCommitMessage commit() throws IOException; /** * Aborts this writer if it is failed. Implementations should clean up the data for already @@ -84,9 +90,11 @@ public interface DataWriter { * This method will only be called if there is one record failed to write, or {@link #commit()} * failed. * - * If this method fails(throw exception), the underlying data source may have garbage that need - * to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, but - * these garbage should not be visible to data source readers. + * If this method fails(by throwing an exception), the underlying data source may have garbage + * that need to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, + * but these garbage should not be visible to data source readers. 
+ * + * @throws IOException if failure happens during disk/network IO like writing files. */ - void abort(); + void abort() throws IOException; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index fe56cc00d1c7a..18ec792f5a2c9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -35,6 +35,9 @@ public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. * + * If this method fails (by throwing an exception), the action would fail and no Spark job was + * submitted. + * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java index a8e95901f3b07..3e0518814f458 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources.v2.writer; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; @@ -29,8 +28,6 @@ * changed in the future Spark versions. */ -@InterfaceStability.Evolving -@Experimental @InterfaceStability.Unstable public interface SupportsWriteInternalRow extends DataSourceV2Writer { From 14a32a647a019ed45793784069fbf077b9c45e60 Mon Sep 17 00:00:00 2001 From: Alexander Istomin Date: Tue, 7 Nov 2017 00:47:16 +0100 Subject: [PATCH 187/712] [SPARK-22330][CORE] Linear containsKey operation for serialized maps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …alization. ## What changes were proposed in this pull request? Use non-linear containsKey operation for serialized maps, lookup into underlying map. ## How was this patch tested? unit tests Author: Alexander Istomin Closes #19553 from Whoosh/SPARK-22330. 
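For reference, a condensed, REPL-style sketch of the behavior the new `JavaUtilsSuite` exercises (it assumes visibility of the `private[spark]` `JavaUtils` object, i.e. code compiled under the `org.apache.spark` package; the sample key and value are arbitrary):

```scala
import org.apache.spark.api.java.JavaUtils

val src = new scala.collection.mutable.HashMap[Double, String]
src.put(42.5, "42")

val jmap: java.util.Map[Double, String] = JavaUtils.mapAsSerializableJavaMap(src)

// containsKey now delegates to the underlying Scala map instead of iterating entrySet,
// so the lookup is a plain hash lookup.
assert(jmap.containsKey(42.5))

// A key of an incompatible type is reported as absent; the ClassCastException is swallowed.
assert(!jmap.containsKey("some string key"))
assert(jmap.get("some string key") == null)
```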
--- .../org/apache/spark/api/java/JavaUtils.scala | 9 +++- .../spark/api/java/JavaUtilsSuite.scala | 49 +++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index d6506231b8d74..fd96052f95d3f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -43,10 +43,17 @@ private[spark] object JavaUtils { override def size: Int = underlying.size + // Delegate to implementation because AbstractMap implementation iterates over whole key set + override def containsKey(key: AnyRef): Boolean = try { + underlying.contains(key.asInstanceOf[A]) + } catch { + case _: ClassCastException => false + } + override def get(key: AnyRef): B = try { underlying.getOrElse(key.asInstanceOf[A], null.asInstanceOf[B]) } catch { - case ex: ClassCastException => null.asInstanceOf[B] + case _: ClassCastException => null.asInstanceOf[B] } override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] { diff --git a/core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala new file mode 100644 index 0000000000000..8e6e3e0968617 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/java/JavaUtilsSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import java.io.Serializable + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite + + +class JavaUtilsSuite extends SparkFunSuite { + + test("containsKey implementation without iteratively entrySet call") { + val src = new scala.collection.mutable.HashMap[Double, String] + val key: Double = 42.5 + val key2 = "key" + + src.put(key, "42") + + val map: java.util.Map[Double, String] = spy(JavaUtils.mapAsSerializableJavaMap(src)) + + assert(map.containsKey(key)) + + // ClassCast checking, shouldn't throw exception + assert(!map.containsKey(key2)) + assert(map.get(key2) == null) + + assert(map.get(key).eq("42")) + assert(map.isInstanceOf[Serializable]) + + verify(map, never()).entrySet() + } +} From 9df08e218cfd4dd91bc407b98528b74f452f34f8 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 7 Nov 2017 08:30:58 +0000 Subject: [PATCH 188/712] [SPARK-22454][CORE] ExternalShuffleClient.close() should check clientFactory null ## What changes were proposed in this pull request? `ExternalShuffleClient.close()` should check `clientFactory` null. 
otherwise it will throw NPE sometimes: ``` 17/11/06 20:08:05 ERROR Utils: Uncaught exception in thread main java.lang.NullPointerException at org.apache.spark.network.shuffle.ExternalShuffleClient.close(ExternalShuffleClient.java:152) at org.apache.spark.storage.BlockManager.stop(BlockManager.scala:1407) at org.apache.spark.SparkEnv.stop(SparkEnv.scala:89) at org.apache.spark.SparkContext$$anonfun$stop$11.apply$mcV$sp(SparkContext.scala:1849) ``` ## How was this patch tested? manual tests Author: Yuming Wang Closes #19670 from wangyum/SPARK-22454. --- .../apache/spark/network/shuffle/ExternalShuffleClient.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 510017fee2db5..7ed0b6e93a7a8 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -148,6 +148,9 @@ public void registerWithShuffleServer( @Override public void close() { checkInit(); - clientFactory.close(); + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } } } From ed1478cfe173a44876b586943d61cdc93b0869a2 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 7 Nov 2017 17:57:47 +0900 Subject: [PATCH 189/712] [BUILD] Close stale PRs Closes #11494 Closes #14158 Closes #16803 Closes #16864 Closes #17455 Closes #17936 Closes #19377 Added: Closes #19380 Closes #18642 Closes #18377 Closes #19632 Added: Closes #14471 Closes #17402 Closes #17953 Closes #18607 Also cc srowen vanzin HyukjinKwon gatorsmile cloud-fan to see if you have other PRs to close. Author: Xingbo Jiang Closes #19669 from jiangxb1987/stale-prs. From 160a540610051ee3233ea24102533b08b69f03fc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 7 Nov 2017 19:45:34 +0900 Subject: [PATCH 190/712] [SPARK-22376][TESTS] Makes dev/run-tests.py script compatible with Python 3 ## What changes were proposed in this pull request? This PR proposes to fix `dev/run-tests.py` script to support Python 3. Here are some backgrounds. Up to my knowledge, In Python 2, - `unicode` is NOT `str` in Python 2 (`type("foo") != type(u"foo")`). - `str` has an alias, `bytes` in Python 2 (`type("foo") == type(b"foo")`). In Python 3, - `unicode` was (roughly) replaced by `str` in Python 3 (`type("foo") == type(u"foo")`). - `str` is NOT `bytes` in Python 3 (`type("foo") != type(b"foo")`). So, this PR fixes: 1. Use `b''` instead of `''` so that both `str` in Python 2 and `bytes` in Python 3 can be hanlded. `sbt_proc.stdout.readline()` returns `str` (which has an alias, `bytes`) in Python 2 and `bytes` in Python 3 2. Similarily, use `b''` instead of `''` so that both `str` in Python 2 and `bytes` in Python 3 can be hanlded. `re.compile` with `str` pattern does not seem supporting to match `bytes` in Python 3: Actually, this change is recommended up to my knowledge - https://docs.python.org/3/howto/pyporting.html#text-versus-binary-data: > Mark all binary literals with a b prefix, textual literals with a u prefix ## How was this patch tested? I manually tested this via Python 3 with few additional changes to reduce the elapsed time. Author: hyukjinkwon Closes #19665 from HyukjinKwon/SPARK-22376. 
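To illustrate the `str`/`bytes` point above: a subprocess pipe yields `str` (an alias of `bytes`) under Python 2 but `bytes` under Python 3, so both the compiled pattern and the `readline()` sentinel need to be byte literals. A minimal standalone sketch (illustrative command and pattern, not the `dev/run-tests.py` code itself; assumes a Unix `echo` on the PATH):

```
from __future__ import print_function  # so print(..., end='') also works on Python 2

import re
import subprocess

# Bytes pattern so it can match the bytes lines read from the pipe on Python 3.
output_filter = re.compile(b"Resolving|Merging|Including")

# Illustrative command only.
proc = subprocess.Popen(["echo", "[error] compilation failed"],
                        stdout=subprocess.PIPE)

# b'' sentinel: readline() returns b'' at EOF on Python 3 and '' (same thing) on Python 2.
for line in iter(proc.stdout.readline, b''):
    if not output_filter.search(line):
        print(line.decode("utf-8"), end='')
proc.wait()
```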
--- dev/run-tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 72d148d7ea0fb..ef0e788a91606 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -276,9 +276,9 @@ def exec_sbt(sbt_args=()): sbt_cmd = [os.path.join(SPARK_HOME, "build", "sbt")] + sbt_args - sbt_output_filter = re.compile("^.*[info].*Resolving" + "|" + - "^.*[warn].*Merging" + "|" + - "^.*[info].*Including") + sbt_output_filter = re.compile(b"^.*[info].*Resolving" + b"|" + + b"^.*[warn].*Merging" + b"|" + + b"^.*[info].*Including") # NOTE: echo "q" is needed because sbt on encountering a build file # with failure (either resolution or compilation) prompts the user for @@ -289,7 +289,7 @@ def exec_sbt(sbt_args=()): stdin=echo_proc.stdout, stdout=subprocess.PIPE) echo_proc.wait() - for line in iter(sbt_proc.stdout.readline, ''): + for line in iter(sbt_proc.stdout.readline, b''): if not sbt_output_filter.match(line): print(line, end='') retcode = sbt_proc.wait() From d5202259d9aa9ad95d572af253bf4a722b7b437a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 7 Nov 2017 09:33:52 -0800 Subject: [PATCH 191/712] [SPARK-21127][SQL][FOLLOWUP] fix a config name typo ## What changes were proposed in this pull request? `spark.sql.statistics.autoUpdate.size` should be `spark.sql.statistics.size.autoUpdate.enabled`. The previous name is confusing as users may treat it as a size config. This config is in master branch only, no backward compatibility issue. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19667 from cloud-fan/minor. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- .../spark/sql/execution/command/CommandUtils.scala | 2 +- .../org/apache/spark/sql/execution/command/ddl.scala | 2 +- .../apache/spark/sql/StatisticsCollectionSuite.scala | 10 +++++----- .../org/apache/spark/sql/hive/StatisticsSuite.scala | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ede116e964a03..a04f8778079de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -812,8 +812,8 @@ object SQLConf { .doubleConf .createWithDefault(0.05) - val AUTO_UPDATE_SIZE = - buildConf("spark.sql.statistics.autoUpdate.size") + val AUTO_SIZE_UPDATE_ENABLED = + buildConf("spark.sql.statistics.size.autoUpdate.enabled") .doc("Enables automatic update for table size once table's data is changed. 
Note that if " + "the total number of files of the table is very large, this can be expensive and slow " + "down data change commands.") @@ -1206,7 +1206,7 @@ class SQLConf extends Serializable with Logging { def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) - def autoUpdateSize: Boolean = getConf(SQLConf.AUTO_UPDATE_SIZE) + def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED) def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index b22958d59336c..1a0d67fc71fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -36,7 +36,7 @@ object CommandUtils extends Logging { def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = { if (table.stats.nonEmpty) { val catalog = sparkSession.sessionState.catalog - if (sparkSession.sessionState.conf.autoUpdateSize) { + if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) { val newTable = catalog.getTableMetadata(table.identifier) val newSize = CommandUtils.calculateTotalSize(sparkSession.sessionState, newTable) val newStats = CatalogStatistics(sizeInBytes = newSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index a9cd65e3242c9..568567aa8ea88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -442,7 +442,7 @@ case class AlterTableAddPartitionCommand( catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) if (table.stats.nonEmpty) { - if (sparkSession.sessionState.conf.autoUpdateSize) { + if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) { val addedSize = parts.map { part => CommandUtils.calculateLocationSize(sparkSession.sessionState, table.identifier, part.storage.locationUri) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 2fc92f4aff92e..7247c3a876df3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -216,7 +216,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) // analyze to get initial stats @@ -252,7 +252,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after insert command for datasource table") { val table = "change_stats_insert_datasource_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i int, j string) 
USING PARQUET") // analyze to get initial stats @@ -285,7 +285,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("invalidation of tableRelationCache after inserts") { val table = "invalidate_catalog_cache_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { spark.range(100).write.saveAsTable(table) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") @@ -302,7 +302,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("invalidation of tableRelationCache after table truncation") { val table = "invalidate_catalog_cache_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { spark.range(100).write.saveAsTable(table) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") @@ -318,7 +318,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("invalidation of tableRelationCache after alter table add partition") { val table = "invalidate_catalog_cache_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTempDir { dir => withTable(table) { val path = dir.getCanonicalPath diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index b9a5ad7657134..9e8fc32a05471 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -755,7 +755,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after insert command for hive table") { val table = s"change_stats_insert_hive_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i int, j string)") // analyze to get initial stats @@ -783,7 +783,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after load data command") { val table = "change_stats_load_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") // analyze to get initial stats @@ -817,7 +817,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("change stats after add/drop partition command") { val table = "change_stats_part_table" Seq(false, true).foreach { autoUpdate => - withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") // table has two partitions initially From 1d341042d6948e636643183da9bf532268592c6a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 7 Nov 2017 21:32:37 +0100 Subject: [PATCH 192/712] 
[SPARK-22417][PYTHON] Fix for createDataFrame from pandas.DataFrame with timestamp ## What changes were proposed in this pull request? Currently, a pandas.DataFrame that contains a timestamp of type 'datetime64[ns]' when converted to a Spark DataFrame with `createDataFrame` will interpret the values as LongType. This fix will check for a timestamp type and convert it to microseconds which will allow Spark to read as TimestampType. ## How was this patch tested? Added unit test to verify Spark schema is expected for TimestampType and DateType when created from pandas Author: Bryan Cutler Closes #19646 from BryanCutler/pyspark-non-arrow-createDataFrame-ts-fix-SPARK-22417. --- python/pyspark/sql/session.py | 49 ++++++++++++++++++++++++++++++++--- python/pyspark/sql/tests.py | 15 +++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index c3dc1a46fd3c1..d1d0b8b8fe5d9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -23,6 +23,7 @@ if sys.version >= '3': basestring = unicode = str + xrange = range else: from itertools import imap as map @@ -416,6 +417,50 @@ def _createFromLocal(self, data, schema): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema + def _get_numpy_record_dtypes(self, rec): + """ + Used when converting a pandas.DataFrame to Spark using to_records(), this will correct + the dtypes of records so they can be properly loaded into Spark. + :param rec: a numpy record to check dtypes + :return corrected dtypes for a numpy.record or None if no correction needed + """ + import numpy as np + cur_dtypes = rec.dtype + col_names = cur_dtypes.names + record_type_list = [] + has_rec_fix = False + for i in xrange(len(cur_dtypes)): + curr_type = cur_dtypes[i] + # If type is a datetime64 timestamp, convert to microseconds + # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs, + # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417 + if curr_type == np.dtype('datetime64[ns]'): + curr_type = 'datetime64[us]' + has_rec_fix = True + record_type_list.append((str(col_names[i]), curr_type)) + return record_type_list if has_rec_fix else None + + def _convert_from_pandas(self, pdf, schema): + """ + Convert a pandas.DataFrame to list of records that can be used to make a DataFrame + :return tuple of list of records and schema + """ + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = [str(x) for x in pdf.columns] + + # Convert pandas.DataFrame to list of numpy records + np_records = pdf.to_records(index=False) + + # Check if any columns need to be fixed for Spark to infer properly + if len(np_records) > 0: + record_type_list = self._get_numpy_record_dtypes(np_records[0]) + if record_type_list is not None: + return [r.astype(record_type_list).tolist() for r in np_records], schema + + # Convert list of numpy records to python lists + return [r.tolist() for r in np_records], schema + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): @@ -512,9 +557,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] + data, schema = 
self._convert_from_pandas(data, schema) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 483f39aeef66a..eb0d4e29a5978 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2592,6 +2592,21 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) + @unittest.skipIf(not _have_pandas, "Pandas not installed") + def test_create_dataframe_from_pandas_with_timestamp(self): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) + # test types are inferred correctly without specifying schema + df = self.spark.createDataFrame(pdf) + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + # test with schema will accept pdf as input + df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp") + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + class HiveSparkSubmitTests(SparkSubmitTests): From 0846a44736d9a71ba3234ad5de4c8de9e7fe9f6c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 7 Nov 2017 21:57:43 +0100 Subject: [PATCH 193/712] [SPARK-22464][SQL] No pushdown for Hive metastore partition predicates containing null-safe equality ## What changes were proposed in this pull request? `<=>` is not supported by Hive metastore partition predicate pushdown. We should not push down it to Hive metastore when they are be using in partition predicates. ## How was this patch tested? Added a test case Author: gatorsmile Closes #19682 from gatorsmile/fixLimitPushDown. --- .../spark/sql/hive/client/HiveShim.scala | 29 ++++++++++++++----- .../sql/hive/client/HiveClientSuite.scala | 9 ++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 5c1ff2b76fdaa..bd1b300416990 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -592,6 +592,19 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } } + + /** + * An extractor that matches all binary comparison operators except null-safe equality. + * + * Null-safe equality is not supported by Hive metastore partition predicate pushdown + */ + object SpecialBinaryComparison { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { + case _: EqualNullSafe => None + case _ => Some((e.left, e.right)) + } + } + private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
lazy val varcharKeys = table.getPartitionKeys.asScala @@ -600,14 +613,14 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { .map(col => col.getName).toSet filters.collect { - case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => s"${a.name} ${op.symbol} $v" - case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + case op @ SpecialBinaryComparison(Literal(v, _: IntegralType), a: Attribute) => s"$v ${op.symbol} ${a.name}" - case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: StringType)) if !varcharKeys.contains(a.name) => s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" - case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + case op @ SpecialBinaryComparison(Literal(v, _: StringType), a: Attribute) if !varcharKeys.contains(a.name) => s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" }.mkString(" and ") @@ -666,16 +679,16 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { case InSet(a: Attribute, ExtractableValues(values)) if !varcharKeys.contains(a.name) && values.nonEmpty => convertInToOr(a, values) - case op @ BinaryComparison(a: Attribute, ExtractableLiteral(value)) + case op @ SpecialBinaryComparison(a: Attribute, ExtractableLiteral(value)) if !varcharKeys.contains(a.name) => s"${a.name} ${op.symbol} $value" - case op @ BinaryComparison(ExtractableLiteral(value), a: Attribute) + case op @ SpecialBinaryComparison(ExtractableLiteral(value), a: Attribute) if !varcharKeys.contains(a.name) => s"$value ${op.symbol} ${a.name}" - case op @ And(expr1, expr2) + case And(expr1, expr2) if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) => (convert.lift(expr1) ++ convert.lift(expr2)).mkString("(", " and ", ")") - case op @ Or(expr1, expr2) + case Or(expr1, expr2) if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) => s"(${convert(expr1)} or ${convert(expr2)})" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 3eedcf7e0874e..ce53acef51503 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -78,6 +78,15 @@ class HiveClientSuite(version: String) assert(filteredPartitions.size == testPartitionCount) } + test("getPartitionsByFilter: ds<=>20170101") { + // Should return all partitions where <=> is not supported + testMetastorePartitionFiltering( + "ds<=>20170101", + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( "ds=20170101", From 7475a9655cd23dbf667e9543cf0818243e59f998 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 7 Nov 2017 16:03:24 -0600 Subject: [PATCH 194/712] [SPARK-20645][CORE] Port environment page to new UI backend. This change modifies the status listener to collect the information needed to render the envionment page, and populates that page and the API with information collected by the listener. Tested with existing and added unit tests. Author: Marcelo Vanzin Closes #19677 from vanzin/SPARK-20645. 
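To sketch the shape of the change: the listener converts each environment-update event into an immutable snapshot and writes it to a store under a fixed key, and the page/REST layer reads that snapshot instead of reaching into mutable listener state. A much-simplified stand-in (plain Scala, not the real `AppStatusListener`/`KVStore`/`AppStatusStore` API):

```
import scala.collection.concurrent.TrieMap

// Immutable snapshot the listener publishes (stand-in for ApplicationEnvironmentInfo).
case class EnvSnapshot(
    javaVersion: String,
    sparkProperties: Seq[(String, String)],
    systemProperties: Seq[(String, String)])

// Stand-in for the key-value store: written by the listener, read by the UI and API.
class SnapshotStore {
  private val data = TrieMap.empty[String, AnyRef]
  def write(key: String, value: AnyRef): Unit = data.put(key, value)
  def read[T](key: String): T = data(key).asInstanceOf[T]
}

// Stand-in for the status listener: no mutable page state, it just converts the
// event payload into a snapshot and writes it to the store.
class EnvListener(store: SnapshotStore) {
  def onEnvironmentUpdate(details: Map[String, Seq[(String, String)]]): Unit = {
    val jvm = details.getOrElse("JVM Information", Nil).toMap
    val snapshot = EnvSnapshot(
      jvm.getOrElse("Java Version", ""),
      details.getOrElse("Spark Properties", Nil),
      details.getOrElse("System Properties", Nil))
    store.write("environment", snapshot) // one snapshot per application, fixed key
  }
}

object EnvPage {
  // The page / REST layer reads the snapshot from the store, never from the listener.
  def render(store: SnapshotStore): String = {
    val env = store.read[EnvSnapshot]("environment")
    s"Java ${env.javaVersion}, ${env.sparkProperties.size} Spark properties"
  }
}
```

The sketch only captures the write-once/read-many division of labor; the actual wiring through `AppStatusListener`, `AppStatusStore`, and the environment page is in the diff below.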
--- .../spark/status/AppStatusListener.scala | 18 ++ .../apache/spark/status/AppStatusStore.scala | 9 + .../v1/ApplicationEnvironmentResource.scala | 15 +- .../org/apache/spark/status/storeTypes.scala | 11 + .../scala/org/apache/spark/ui/SparkUI.scala | 28 +- .../apache/spark/ui/env/EnvironmentPage.scala | 31 +- .../apache/spark/ui/env/EnvironmentTab.scala | 56 ---- .../app_environment_expectation.json | 281 ++++++++++++++++++ .../deploy/history/HistoryServerSuite.scala | 4 +- .../spark/status/AppStatusListenerSuite.scala | 40 +++ project/MimaExcludes.scala | 1 + 11 files changed, 402 insertions(+), 92 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala create mode 100644 core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index cd43612fae357..424a1159a875c 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -88,6 +88,24 @@ private[spark] class AppStatusListener( kvstore.write(new ApplicationInfoWrapper(appInfo)) } + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { + val details = event.environmentDetails + + val jvmInfo = Map(details("JVM Information"): _*) + val runtime = new v1.RuntimeInfo( + jvmInfo.get("Java Version").orNull, + jvmInfo.get("Java Home").orNull, + jvmInfo.get("Scala Version").orNull) + + val envInfo = new v1.ApplicationEnvironmentInfo( + runtime, + details.getOrElse("Spark Properties", Nil), + details.getOrElse("System Properties", Nil), + details.getOrElse("Classpath Entries", Nil)) + + kvstore.write(new ApplicationEnvironmentInfoWrapper(envInfo)) + } + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { val old = appInfo.attempts.head val attempt = new v1.ApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 2927a3227cbef..d6b5d2661e8ee 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -33,6 +33,15 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} */ private[spark] class AppStatusStore(store: KVStore) { + def applicationInfo(): v1.ApplicationInfo = { + store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info + } + + def environmentInfo(): v1.ApplicationEnvironmentInfo = { + val klass = classOf[ApplicationEnvironmentInfoWrapper] + store.read(klass, klass.getName()).info + } + def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { val it = store.view(classOf[JobDataWrapper]).asScala.map(_.info) if (!statuses.isEmpty()) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala index 739a8aceae861..e702f8aa2ef2d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala @@ -26,20 +26,7 @@ private[v1] class ApplicationEnvironmentResource(ui: SparkUI) { @GET def getEnvironmentInfo(): ApplicationEnvironmentInfo = { - val listener = ui.environmentListener - 
listener.synchronized { - val jvmInfo = Map(listener.jvmInformation: _*) - val runtime = new RuntimeInfo( - jvmInfo("Java Version"), - jvmInfo("Java Home"), - jvmInfo("Scala Version")) - - new ApplicationEnvironmentInfo( - runtime, - listener.sparkProperties, - listener.systemProperties, - listener.classpathEntries) - } + ui.store.environmentInfo() } } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index a445435809f3a..23e9a360ddc02 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -34,6 +34,17 @@ private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { } +private[spark] class ApplicationEnvironmentInfoWrapper(val info: ApplicationEnvironmentInfo) { + + /** + * There's always a single ApplicationEnvironmentInfo object per application, so this + * ID doesn't need to be dynamic. But the KVStore API requires an ID. + */ + @JsonIgnore @KVIndex + def id: String = classOf[ApplicationEnvironmentInfoWrapper].getName() + +} + private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { @JsonIgnore @KVIndex diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index ee645f6bf8a7a..43b57a1630aa9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -25,11 +25,10 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, - UIRoot} +import org.apache.spark.status.api.v1._ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} +import org.apache.spark.ui.env.EnvironmentTab import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener @@ -44,7 +43,6 @@ private[spark] class SparkUI private ( val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, - val environmentListener: EnvironmentListener, val storageStatusListener: StorageStatusListener, val executorsListener: ExecutorsListener, val jobProgressListener: JobProgressListener, @@ -73,7 +71,7 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) attachTab(stagesTab) attachTab(new StorageTab(this)) - attachTab(new EnvironmentTab(this)) + attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) @@ -88,9 +86,13 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.sparkUser - .orElse(environmentListener.systemProperties.toMap.get("user.name")) - .getOrElse("") + try { + Option(store.applicationInfo().attempts.head.sparkUser) + .orElse(store.environmentInfo().systemProperties.toMap.get("user.name")) + .getOrElse("") + } catch { + case _: NoSuchElementException => "" + } } def getAppName: String = appName @@ -143,6 +145,7 @@ private[spark] class SparkUI private ( def 
setStreamingJobProgressListener(sparkListener: SparkListener): Unit = { streamingJobProgressListener = Option(sparkListener) } + } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) @@ -184,21 +187,20 @@ private[spark] object SparkUI { addListenerFn(listener) listener } - val environmentListener = new EnvironmentListener + val storageStatusListener = new StorageStatusListener(conf) val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) - addListenerFn(environmentListener) addListenerFn(storageStatusListener) addListenerFn(executorsListener) addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(store, sc, conf, securityManager, environmentListener, storageStatusListener, - executorsListener, jobProgressListener, storageListener, operationGraphListener, - appName, basePath, lastUpdateTime, startTime, appSparkVersion) + new SparkUI(store, sc, conf, securityManager, storageStatusListener, executorsListener, + jobProgressListener, storageListener, operationGraphListener, appName, basePath, + lastUpdateTime, startTime, appSparkVersion) } } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index b11f8f1555f17..43adab7a35d65 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -21,22 +21,31 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.SparkConf +import org.apache.spark.status.AppStatusStore +import org.apache.spark.ui._ import org.apache.spark.util.Utils -private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { - private val listener = parent.listener +private[ui] class EnvironmentPage( + parent: EnvironmentTab, + conf: SparkConf, + store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { + val appEnv = store.environmentInfo() + val jvmInformation = Map( + "Java Version" -> appEnv.runtime.javaVersion, + "Java Home" -> appEnv.runtime.javaHome, + "Scala Version" -> appEnv.runtime.scalaVersion) + val runtimeInformationTable = UIUtils.listingTable( - propertyHeader, jvmRow, listener.jvmInformation, fixedWidth = true) + propertyHeader, jvmRow, jvmInformation, fixedWidth = true) val sparkPropertiesTable = UIUtils.listingTable(propertyHeader, propertyRow, - Utils.redact(parent.conf, listener.sparkProperties), fixedWidth = true) - + Utils.redact(conf, appEnv.sparkProperties.toSeq), fixedWidth = true) val systemPropertiesTable = UIUtils.listingTable( - propertyHeader, propertyRow, listener.systemProperties, fixedWidth = true) + propertyHeader, propertyRow, appEnv.systemProperties, fixedWidth = true) val classpathEntriesTable = UIUtils.listingTable( - classPathHeaders, classPathRow, listener.classpathEntries, fixedWidth = true) + classPathHeaders, classPathRow, appEnv.classpathEntries, fixedWidth = true) val content =

    Runtime Information

    {runtimeInformationTable} @@ -54,3 +63,9 @@ private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") private def propertyRow(kv: (String, String)) = {kv._1}{kv._2} private def classPathRow(data: (String, String)) = {data._1}{data._2} } + +private[ui] class EnvironmentTab( + parent: SparkUI, + store: AppStatusStore) extends SparkUITab(parent, "environment") { + attachPage(new EnvironmentPage(this, parent.conf, store)) +} diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala deleted file mode 100644 index 61b12aaa32bb6..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.env - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.ui._ - -private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") { - val listener = parent.environmentListener - val conf = parent.conf - attachPage(new EnvironmentPage(this)) -} - -/** - * :: DeveloperApi :: - * A SparkListener that prepares information to be displayed on the EnvironmentTab - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class EnvironmentListener extends SparkListener { - var sparkUser: Option[String] = None - var jvmInformation = Seq[(String, String)]() - var sparkProperties = Seq[(String, String)]() - var systemProperties = Seq[(String, String)]() - var classpathEntries = Seq[(String, String)]() - - override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { - sparkUser = Some(event.sparkUser) - } - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - val environmentDetails = environmentUpdate.environmentDetails - jvmInformation = environmentDetails("JVM Information") - sparkProperties = environmentDetails("Spark Properties") - systemProperties = environmentDetails("System Properties") - classpathEntries = environmentDetails("Classpath Entries") - } - } -} diff --git a/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json b/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json new file mode 100644 index 0000000000000..4ed053899ee6c --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/app_environment_expectation.json @@ -0,0 +1,281 @@ +{ + "runtime" : { + "javaVersion" : "1.8.0_92 (Oracle Corporation)", + "javaHome" : "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre", + "scalaVersion" : "version 2.11.8" + }, + "sparkProperties" : [ + [ 
"spark.blacklist.task.maxTaskAttemptsPerExecutor", "3" ], + [ "spark.blacklist.enabled", "TRUE" ], + [ "spark.driver.host", "172.22.0.167" ], + [ "spark.blacklist.task.maxTaskAttemptsPerNode", "3" ], + [ "spark.eventLog.enabled", "TRUE" ], + [ "spark.driver.port", "51459" ], + [ "spark.repl.class.uri", "spark://172.22.0.167:51459/classes" ], + [ "spark.jars", "" ], + [ "spark.repl.class.outputDir", "/private/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/spark-1cbc97d0-7fe6-4c9f-8c2c-f6fe51ee3cf2/repl-39929169-ac4c-4c6d-b116-f648e4dd62ed" ], + [ "spark.app.name", "Spark shell" ], + [ "spark.blacklist.stage.maxFailedExecutorsPerNode", "3" ], + [ "spark.scheduler.mode", "FIFO" ], + [ "spark.eventLog.overwrite", "TRUE" ], + [ "spark.blacklist.stage.maxFailedTasksPerExecutor", "3" ], + [ "spark.executor.id", "driver" ], + [ "spark.blacklist.application.maxFailedExecutorsPerNode", "2" ], + [ "spark.submit.deployMode", "client" ], + [ "spark.master", "local-cluster[4,4,1024]" ], + [ "spark.home", "/Users/Jose/IdeaProjects/spark" ], + [ "spark.eventLog.dir", "/Users/jose/logs" ], + [ "spark.sql.catalogImplementation", "in-memory" ], + [ "spark.eventLog.compress", "FALSE" ], + [ "spark.blacklist.application.maxFailedTasksPerExecutor", "1" ], + [ "spark.blacklist.timeout", "1000000" ], + [ "spark.app.id", "app-20161116163331-0000" ], + [ "spark.task.maxFailures", "4" ] + ], + "systemProperties" : [ + [ "java.io.tmpdir", "/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/" ], + [ "line.separator", "\n" ], + [ "path.separator", ":" ], + [ "sun.management.compiler", "HotSpot 64-Bit Tiered Compilers" ], + [ "SPARK_SUBMIT", "true" ], + [ "sun.cpu.endian", "little" ], + [ "java.specification.version", "1.8" ], + [ "java.vm.specification.name", "Java Virtual Machine Specification" ], + [ "java.vendor", "Oracle Corporation" ], + [ "java.vm.specification.version", "1.8" ], + [ "user.home", "/Users/Jose" ], + [ "file.encoding.pkg", "sun.io" ], + [ "sun.nio.ch.bugLevel", "" ], + [ "ftp.nonProxyHosts", "local|*.local|169.254/16|*.169.254/16" ], + [ "sun.arch.data.model", "64" ], + [ "sun.boot.library.path", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib" ], + [ "user.dir", "/Users/Jose/IdeaProjects/spark" ], + [ "java.library.path", "/Users/Jose/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:." 
], + [ "sun.cpu.isalist", "" ], + [ "os.arch", "x86_64" ], + [ "java.vm.version", "25.92-b14" ], + [ "java.endorsed.dirs", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/endorsed" ], + [ "java.runtime.version", "1.8.0_92-b14" ], + [ "java.vm.info", "mixed mode" ], + [ "java.ext.dirs", "/Users/Jose/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java" ], + [ "java.runtime.name", "Java(TM) SE Runtime Environment" ], + [ "file.separator", "/" ], + [ "io.netty.maxDirectMemory", "0" ], + [ "java.class.version", "52.0" ], + [ "scala.usejavacp", "true" ], + [ "java.specification.name", "Java Platform API Specification" ], + [ "sun.boot.class.path", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/classes" ], + [ "file.encoding", "UTF-8" ], + [ "user.timezone", "America/Chicago" ], + [ "java.specification.vendor", "Oracle Corporation" ], + [ "sun.java.launcher", "SUN_STANDARD" ], + [ "os.version", "10.11.6" ], + [ "sun.os.patch.level", "unknown" ], + [ "gopherProxySet", "false" ], + [ "java.vm.specification.vendor", "Oracle Corporation" ], + [ "user.country", "US" ], + [ "sun.jnu.encoding", "UTF-8" ], + [ "http.nonProxyHosts", "local|*.local|169.254/16|*.169.254/16" ], + [ "user.language", "en" ], + [ "socksNonProxyHosts", "local|*.local|169.254/16|*.169.254/16" ], + [ "java.vendor.url", "http://java.oracle.com/" ], + [ "java.awt.printerjob", "sun.lwawt.macosx.CPrinterJob" ], + [ "java.awt.graphicsenv", "sun.awt.CGraphicsEnvironment" ], + [ "awt.toolkit", "sun.lwawt.macosx.LWCToolkit" ], + [ "os.name", "Mac OS X" ], + [ "java.vm.vendor", "Oracle Corporation" ], + [ "java.vendor.url.bug", "http://bugreport.sun.com/bugreport/" ], + [ "user.name", "jose" ], + [ "java.vm.name", "Java HotSpot(TM) 64-Bit Server VM" ], + [ "sun.java.command", "org.apache.spark.deploy.SparkSubmit --master local-cluster[4,4,1024] --conf spark.blacklist.enabled=TRUE --conf spark.blacklist.timeout=1000000 --conf spark.blacklist.application.maxFailedTasksPerExecutor=1 --conf spark.eventLog.overwrite=TRUE --conf spark.blacklist.task.maxTaskAttemptsPerNode=3 --conf spark.blacklist.stage.maxFailedTasksPerExecutor=3 --conf spark.blacklist.task.maxTaskAttemptsPerExecutor=3 --conf spark.eventLog.compress=FALSE --conf spark.blacklist.stage.maxFailedExecutorsPerNode=3 --conf spark.eventLog.enabled=TRUE --conf spark.eventLog.dir=/Users/jose/logs --conf spark.blacklist.application.maxFailedExecutorsPerNode=2 --conf spark.task.maxFailures=4 --class org.apache.spark.repl.Main --name Spark shell spark-shell -i /Users/Jose/dev/jose-utils/blacklist/test-blacklist.scala" ], + [ "java.home", "/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre" ], + [ "java.version", "1.8.0_92" ], + [ "sun.io.unicode.encoding", "UnicodeBig" ] + ], + "classpathEntries" : [ + [ 
"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlet-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-schema-1.2.15.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-assembly_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.7.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/core/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-framework-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-client-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/network-common/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.5.jar", "System 
Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/repl/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-io-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/catalyst/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-continuation-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/hive-thriftserver/target/scala-2.11/classes", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/streaming/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-net-3.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-proxy-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/lz4-1.3.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/network-yarn/target/scala-2.11/classes", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/conf/", "System Classpath" ], + [ 
"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/unused-1.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/tags/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/cglib-2.2.1-v20090111.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/py4j-0.10.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/core/target/jars/*", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/network-shuffle/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.0-incubating.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/core/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill-java-0.8.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/sql/hive/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xz-1.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.2.jar", "System Classpath" ], + [ 
"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalap-2.11.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/yarn/target/scala-2.11/classes", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-plus-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/unsafe/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/common/sketch/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.12.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang-2.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-common-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math-2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.2.0.jar", "System Classpath" ], + [ 
"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-3.8.0.Final.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-webapp-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-io-2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-xml-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/mllib/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalatest_2.11-2.2.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-client-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-jndi-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/graphx/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/examples/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jets3t-0.7.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-recipes-2.4.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar", "System 
Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.5.3.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.2.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jline-2.12.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/launcher/target/scala-2.11/classes/", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.12.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.5.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlets-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/paranamer-2.6.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-security-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7-tests.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire_2.11-0.7.4.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-client-2.2.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.2.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-all-4.0.41.Final.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/janino-3.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guice-3.0.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-server-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-http-9.2.16.v20160414.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.1.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar", "System 
Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar", "System Classpath" ], + [ "/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar", "System Classpath" ] + ] +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 010a8dd004d4f..6a1abceaeb63c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -158,7 +158,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors", - "executor memory usage" -> "applications/app-20161116163331-0000/executors" + "executor memory usage" -> "applications/app-20161116163331-0000/executors", + + "app environment" -> "applications/app-20161116163331-0000/environment" // Todo: enable this test when logging the even of onBlockUpdated. See: SPARK-13845 // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" ) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 7ac1ce19f8ddf..867d35f231dc0 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -56,6 +56,46 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { Utils.deleteRecursively(testDir) } + test("environment info") { + val listener = new AppStatusListener(store, conf, true) + + val details = Map( + "JVM Information" -> Seq( + "Java Version" -> sys.props("java.version"), + "Java Home" -> sys.props("java.home"), + "Scala Version" -> scala.util.Properties.versionString + ), + "Spark Properties" -> Seq( + "spark.conf.1" -> "1", + "spark.conf.2" -> "2" + ), + "System Properties" -> Seq( + "sys.prop.1" -> "1", + "sys.prop.2" -> "2" + ), + "Classpath Entries" -> Seq( + "/jar1" -> "System", + "/jar2" -> "User" + ) + ) + + listener.onEnvironmentUpdate(SparkListenerEnvironmentUpdate(details)) + + val appEnvKey = classOf[ApplicationEnvironmentInfoWrapper].getName() + check[ApplicationEnvironmentInfoWrapper](appEnvKey) { env => + val info = env.info + + val runtimeInfo = Map(details("JVM Information"): _*) + assert(info.runtime.javaVersion == runtimeInfo("Java Version")) + assert(info.runtime.javaHome == runtimeInfo("Java Home")) + assert(info.runtime.scalaVersion == runtimeInfo("Scala Version")) + + assert(info.sparkProperties === details("Spark Properties")) + assert(info.systemProperties === details("System Properties")) + assert(info.classpathEntries === details("Classpath Entries")) + } + } + test("scheduler events") { val listener = new AppStatusListener(store, conf, true) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 99cac34c85ebc..62930e2e3a931 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -39,6 +39,7 @@ object MimaExcludes { // SPARK-18085: Better History Server scalability for many / large applications 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 3da3d76352cc471252a54088cc55208bb4ea5b3a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 7 Nov 2017 20:07:30 -0800 Subject: [PATCH 195/712] [SPARK-14516][ML][FOLLOW-UP] Move ClusteringEvaluatorSuite test data to data/mllib. ## What changes were proposed in this pull request? Move ```ClusteringEvaluatorSuite``` test data(iris) to data/mllib, to prevent from re-creating a new folder. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #19648 from yanboliang/spark-14516. --- .../iris.libsvm => data/mllib/iris_libsvm.txt | 0 .../evaluation/ClusteringEvaluatorSuite.scala | 30 +++++++------------ 2 files changed, 11 insertions(+), 19 deletions(-) rename mllib/src/test/resources/test-data/iris.libsvm => data/mllib/iris_libsvm.txt (100%) diff --git a/mllib/src/test/resources/test-data/iris.libsvm b/data/mllib/iris_libsvm.txt similarity index 100% rename from mllib/src/test/resources/test-data/iris.libsvm rename to data/mllib/iris_libsvm.txt diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index e60ebbd7c852d..677ce49a903ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -22,8 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.Dataset class ClusteringEvaluatorSuite @@ -31,6 +30,13 @@ class ClusteringEvaluatorSuite import testImplicits._ + @transient var irisDataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt") + } + test("params") { ParamsSuite.checkParams(new ClusteringEvaluator) } @@ -53,37 +59,23 @@ class ClusteringEvaluatorSuite 0.6564679231 */ test("squared euclidean Silhouette") { - val iris = ClusteringEvaluatorSuite.irisDataset(spark) val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") - assert(evaluator.evaluate(iris) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) } test("number of clusters must be greater than one") { - val iris = ClusteringEvaluatorSuite.irisDataset(spark) - .where($"label" === 0.0) + val singleClusterDataset = irisDataset.where($"label" === 0.0) val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") val e = intercept[AssertionError]{ - evaluator.evaluate(iris) + evaluator.evaluate(singleClusterDataset) } assert(e.getMessage.contains("Number of clusters must be greater than one")) } } - -object ClusteringEvaluatorSuite { - def irisDataset(spark: SparkSession): 
DataFrame = { - - val irisPath = Thread.currentThread() - .getContextClassLoader - .getResource("test-data/iris.libsvm") - .toString - - spark.read.format("libsvm").load(irisPath) - } -} From 2ca5aae47a25dc6bc9e333fb592025ff14824501 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 7 Nov 2017 21:02:14 -0800 Subject: [PATCH 196/712] [SPARK-22281][SPARKR] Handle R method breaking signature changes ## What changes were proposed in this pull request? This is to fix the code for the latest R changes in R-devel, when running CRAN check ``` checking for code/documentation mismatches ... WARNING Codoc mismatches from documentation object 'attach': attach Code: function(what, pos = 2L, name = deparse(substitute(what), backtick = FALSE), warn.conflicts = TRUE) Docs: function(what, pos = 2L, name = deparse(substitute(what)), warn.conflicts = TRUE) Mismatches in argument default values: Name: 'name' Code: deparse(substitute(what), backtick = FALSE) Docs: deparse(substitute(what)) Codoc mismatches from documentation object 'glm': glm Code: function(formula, family = gaussian, data, weights, subset, na.action, start = NULL, etastart, mustart, offset, control = list(...), model = TRUE, method = "glm.fit", x = FALSE, y = TRUE, singular.ok = TRUE, contrasts = NULL, ...) Docs: function(formula, family = gaussian, data, weights, subset, na.action, start = NULL, etastart, mustart, offset, control = list(...), model = TRUE, method = "glm.fit", x = FALSE, y = TRUE, contrasts = NULL, ...) Argument names in code not in docs: singular.ok Mismatches in argument names: Position: 16 Code: singular.ok Docs: contrasts Position: 17 Code: contrasts Docs: ... ``` With attach, it's pulling in the function definition from base::attach. We need to disable that but we would still need a function signature for roxygen2 to build with. With glm it's pulling in the function definition (ie. "usage") from the stats::glm function. Since this is "compiled in" when we build the source package into the .Rd file, when it changes at runtime or in CRAN check it won't match the latest signature. The solution is not to pull in from stats::glm since there isn't much value in doing that (none of the param we actually use, the ones we do use we have explicitly documented them) Also with attach we are changing to call dynamically. ## How was this patch tested? Manually. - [x] check documentation output - yes - [x] check help `?attach` `?glm` - yes - [x] check on other platforms, r-hub, on r-devel etc.. Author: Felix Cheung Closes #19557 from felixcheung/rattachglmdocerror. --- R/pkg/R/DataFrame.R | 11 +++++++---- R/pkg/R/generics.R | 10 ++++------ R/pkg/R/mllib_regression.R | 1 + R/run-tests.sh | 6 +++--- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index aaa3349d57506..763c8d2548580 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3236,7 +3236,7 @@ setMethod("as.data.frame", #' #' @family SparkDataFrame functions #' @rdname attach -#' @aliases attach,SparkDataFrame-method +#' @aliases attach attach,SparkDataFrame-method #' @param what (SparkDataFrame) The SparkDataFrame to attach #' @param pos (integer) Specify position in search() where to attach. #' @param name (character) Name to use for the attached SparkDataFrame. 
Names @@ -3252,9 +3252,12 @@ setMethod("as.data.frame", #' @note attach since 1.6.0 setMethod("attach", signature(what = "SparkDataFrame"), - function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { - newEnv <- assignNewEnv(what) - attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) + function(what, pos = 2L, name = deparse(substitute(what), backtick = FALSE), + warn.conflicts = TRUE) { + args <- as.list(environment()) # capture all parameters - this must be the first line + newEnv <- assignNewEnv(args$what) + args$what <- newEnv + do.call(attach, args) }) #' Evaluate a R expression in an environment constructed from a SparkDataFrame diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4e427489f6860..8312d417b99d2 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -409,7 +409,8 @@ setGeneric("as.data.frame", standardGeneric("as.data.frame") }) -#' @rdname attach +# Do not document the generic because of signature changes across R versions +#' @noRd #' @export setGeneric("attach") @@ -1569,12 +1570,9 @@ setGeneric("year", function(x) { standardGeneric("year") }) #' @export setGeneric("fitted") -#' @param x,y For \code{glm}: logical values indicating whether the response vector -#' and model matrix used in the fitting process should be returned as -#' components of the returned value. -#' @inheritParams stats::glm -#' @rdname glm +# Do not carry stats::glm usage and param here, and do not document the generic #' @export +#' @noRd setGeneric("glm") #' @param object a fitted ML model object. diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index f734a0865ec3b..545be5e1d89f0 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -210,6 +210,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm +#' @aliases glm #' @export #' @examples #' \dontrun{ diff --git a/R/run-tests.sh b/R/run-tests.sh index f38c86e3e6b1d..86bd8aad5f113 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -47,10 +47,10 @@ if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then echo -en "\033[0m" # No color exit -1 else - # We have 2 NOTEs for RoxygenNote, attach(); and one in Jenkins only "No repository set" + # We have 2 NOTEs: for RoxygenNote and one in Jenkins only "No repository set" # For non-latest version branches, one WARNING for package version - if [[ ($NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3) && - ($HAS_PACKAGE_VERSION_WARN != 1 || $NUM_CRAN_WARNING != 1 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 2) ]]; then + if [[ ($NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 2) && + ($HAS_PACKAGE_VERSION_WARN != 1 || $NUM_CRAN_WARNING != 1 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 1) ]]; then cat $CRAN_CHECK_LOG_FILE echo -en "\033[31m" # Red echo "Had CRAN check errors; see logs." From 11eea1a4ce32c9018218d4dfc9f46b744eb82991 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 7 Nov 2017 23:14:29 -0600 Subject: [PATCH 197/712] [SPARK-20646][CORE] Port executors page to new UI backend. The executors page is built on top of the REST API, so the page itself was easy to hook up to the new code. Some other pages depend on the `ExecutorListener` class that is being removed, though, so they needed to be modified to use data from the new store. Fortunately, all they seemed to need is the map of executor logs, so that was somewhat easy too. 
The executor timeline graph required adding some properties to the ExecutorSummary API type. Instead of following the previous code, which stored all the listener events in memory, the timeline is now created based on the data available from the API. I had to change some of the test golden files because the old code would return executors in "random" order (since it used a mutable Map instead of something that returns a sorted list), and the new code returns executors in id order. Tested with existing unit tests. Author: Marcelo Vanzin Closes #19678 from vanzin/SPARK-20646. --- .../spark/status/AppStatusListener.scala | 36 ++- .../apache/spark/status/AppStatusStore.scala | 18 +- .../org/apache/spark/status/LiveEntity.scala | 9 +- .../api/v1/AllExecutorListResource.scala | 15 +- .../status/api/v1/ExecutorListResource.scala | 14 +- .../org/apache/spark/status/api/v1/api.scala | 3 + .../scala/org/apache/spark/ui/SparkUI.scala | 13 +- .../ui/exec/ExecutorThreadDumpPage.scala | 9 +- .../apache/spark/ui/exec/ExecutorsPage.scala | 154 ------------- .../apache/spark/ui/exec/ExecutorsTab.scala | 206 +++--------------- .../apache/spark/ui/jobs/AllJobsPage.scala | 60 +++-- .../apache/spark/ui/jobs/ExecutorTable.scala | 11 +- .../org/apache/spark/ui/jobs/JobPage.scala | 59 +++-- .../org/apache/spark/ui/jobs/JobsTab.scala | 3 +- .../org/apache/spark/ui/jobs/StagePage.scala | 18 +- .../org/apache/spark/ui/jobs/StagesTab.scala | 8 +- .../executor_list_json_expectation.json | 1 + .../executor_memory_usage_expectation.json | 127 +++++------ ...xecutor_node_blacklisting_expectation.json | 117 +++++----- ...acklisting_unblacklisting_expectation.json | 89 ++++---- .../org/apache/spark/ui/StagePageSuite.scala | 11 +- project/MimaExcludes.scala | 1 + 22 files changed, 354 insertions(+), 628 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 424a1159a875c..0469c871362c0 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -103,6 +103,9 @@ private[spark] class AppStatusListener( details.getOrElse("System Properties", Nil), details.getOrElse("Classpath Entries", Nil)) + coresPerTask = envInfo.sparkProperties.toMap.get("spark.task.cpus").map(_.toInt) + .getOrElse(coresPerTask) + kvstore.write(new ApplicationEnvironmentInfoWrapper(envInfo)) } @@ -132,7 +135,7 @@ private[spark] class AppStatusListener( override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { // This needs to be an update in case an executor re-registers after the driver has // marked it as "dead". 
- val exec = getOrCreateExecutor(event.executorId) + val exec = getOrCreateExecutor(event.executorId, event.time) exec.host = event.executorInfo.executorHost exec.isActive = true exec.totalCores = event.executorInfo.totalCores @@ -144,6 +147,8 @@ private[spark] class AppStatusListener( override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { liveExecutors.remove(event.executorId).foreach { exec => exec.isActive = false + exec.removeTime = new Date(event.time) + exec.removeReason = event.reason update(exec, System.nanoTime()) } } @@ -357,18 +362,25 @@ private[spark] class AppStatusListener( } liveExecutors.get(event.taskInfo.executorId).foreach { exec => - if (event.taskMetrics != null) { - val readMetrics = event.taskMetrics.shuffleReadMetrics - exec.totalGcTime += event.taskMetrics.jvmGCTime - exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead - exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead - exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten - } - exec.activeTasks -= 1 exec.completedTasks += completedDelta exec.failedTasks += failedDelta exec.totalDuration += event.taskInfo.duration + + // Note: For resubmitted tasks, we continue to use the metrics that belong to the + // first attempt of this task. This may not be 100% accurate because the first attempt + // could have failed half-way through. The correct fix would be to keep track of the + // metrics added by each attempt, but this is much more complicated. + if (event.reason != Resubmitted) { + if (event.taskMetrics != null) { + val readMetrics = event.taskMetrics.shuffleReadMetrics + exec.totalGcTime += event.taskMetrics.jvmGCTime + exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead + exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead + exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten + } + } + maybeUpdate(exec, now) } } @@ -409,7 +421,7 @@ private[spark] class AppStatusListener( override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { // This needs to set fields that are already set by onExecutorAdded because the driver is // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. 
- val exec = getOrCreateExecutor(event.blockManagerId.executorId) + val exec = getOrCreateExecutor(event.blockManagerId.executorId, event.time) exec.hostPort = event.blockManagerId.hostPort event.maxOnHeapMem.foreach { _ => exec.totalOnHeap = event.maxOnHeapMem.get @@ -561,8 +573,8 @@ private[spark] class AppStatusListener( } } - private def getOrCreateExecutor(executorId: String): LiveExecutor = { - liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId)) + private def getOrCreateExecutor(executorId: String, addTime: Long): LiveExecutor = { + liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId, addTime)) } private def getOrCreateStage(info: StageInfo): LiveStage = { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index d6b5d2661e8ee..334407829f9fe 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -56,8 +56,22 @@ private[spark] class AppStatusStore(store: KVStore) { } def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { - store.view(classOf[ExecutorSummaryWrapper]).index("active").reverse().first(true) - .last(true).asScala.map(_.info).toSeq + val base = store.view(classOf[ExecutorSummaryWrapper]) + val filtered = if (activeOnly) { + base.index("active").reverse().first(true).last(true) + } else { + base + } + filtered.asScala.map(_.info).toSeq + } + + def executorSummary(executorId: String): Option[v1.ExecutorSummary] = { + try { + Some(store.read(classOf[ExecutorSummaryWrapper], executorId).info) + } catch { + case _: NoSuchElementException => + None + } } def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 041dfe1ef915e..8c48020e246b4 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -212,13 +212,17 @@ private class LiveTask( } -private class LiveExecutor(val executorId: String) extends LiveEntity { +private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveEntity { var hostPort: String = null var host: String = null var isActive = true var totalCores = 0 + val addTime = new Date(_addTime) + var removeTime: Date = null + var removeReason: String = null + var rddBlocks = 0 var memoryUsed = 0L var diskUsed = 0L @@ -276,6 +280,9 @@ private class LiveExecutor(val executorId: String) extends LiveEntity { totalShuffleWrite, isBlacklisted, maxMemory, + addTime, + Option(removeTime), + Option(removeReason), executorLogs, memoryMetrics) new ExecutorSummaryWrapper(info) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala index eb5cc1b9a3bd0..5522f4cebd773 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala @@ -20,22 +20,11 @@ import javax.ws.rs.{GET, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.exec.ExecutorsPage @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AllExecutorListResource(ui: SparkUI) { @GET - def executorList(): Seq[ExecutorSummary] = { - val 
listener = ui.executorsListener - listener.synchronized { - // The follow codes should be protected by `listener` to make sure no executors will be - // removed before we query their status. See SPARK-12784. - (0 until listener.activeStorageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId, isActive = true) - } ++ (0 until listener.deadStorageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId, isActive = false) - } - } - } + def executorList(): Seq[ExecutorSummary] = ui.store.executorList(false) + } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala index 2f3b5e984002a..975101c33c59c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -20,21 +20,11 @@ import javax.ws.rs.{GET, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.exec.ExecutorsPage @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class ExecutorListResource(ui: SparkUI) { @GET - def executorList(): Seq[ExecutorSummary] = { - val listener = ui.executorsListener - listener.synchronized { - // The follow codes should be protected by `listener` to make sure no executors will be - // removed before we query their status. See SPARK-12784. - val storageStatusList = listener.activeStorageStatusList - (0 until storageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId, isActive = true) - } - } - } + def executorList(): Seq[ExecutorSummary] = ui.store.executorList(true) + } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index bff6f90823f40..b338b1f3fd073 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -85,6 +85,9 @@ class ExecutorSummary private[spark]( val totalShuffleWrite: Long, val isBlacklisted: Boolean, val maxMemory: Long, + val addTime: Date, + val removeTime: Option[Date], + val removeReason: Option[String], val executorLogs: Map[String, String], val memoryMetrics: Option[MemoryMetrics]) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 43b57a1630aa9..79d40b6a90c3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -29,7 +29,7 @@ import org.apache.spark.status.api.v1._ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.EnvironmentTab -import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} +import org.apache.spark.ui.exec.ExecutorsTab import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener import org.apache.spark.ui.storage.{StorageListener, StorageTab} @@ -44,7 +44,6 @@ private[spark] class SparkUI private ( val conf: SparkConf, securityManager: SecurityManager, val storageStatusListener: StorageStatusListener, - val executorsListener: ExecutorsListener, val jobProgressListener: JobProgressListener, val storageListener: StorageListener, val operationGraphListener: RDDOperationGraphListener, @@ -68,7 +67,7 @@ private[spark] class SparkUI 
private ( def initialize() { val jobsTab = new JobsTab(this) attachTab(jobsTab) - val stagesTab = new StagesTab(this) + val stagesTab = new StagesTab(this, store) attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this, store)) @@ -189,18 +188,16 @@ private[spark] object SparkUI { } val storageStatusListener = new StorageStatusListener(conf) - val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) addListenerFn(storageStatusListener) - addListenerFn(executorsListener) addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(store, sc, conf, securityManager, storageStatusListener, executorsListener, - jobProgressListener, storageListener, operationGraphListener, appName, basePath, - lastUpdateTime, startTime, appSparkVersion) + new SparkUI(store, sc, conf, securityManager, storageStatusListener, jobProgressListener, + storageListener, operationGraphListener, appName, basePath, lastUpdateTime, startTime, + appSparkVersion) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 7b211ea5199c3..f4686ea3cf91f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -22,11 +22,12 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.SparkContext +import org.apache.spark.ui.{SparkUITab, UIUtils, WebUIPage} -private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { - - private val sc = parent.sc +private[ui] class ExecutorThreadDumpPage( + parent: SparkUITab, + sc: Option[SparkContext]) extends WebUIPage("threadDump") { // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala deleted file mode 100644 index 7b2767f0be3cd..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ui.exec - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics} -import org.apache.spark.ui.{UIUtils, WebUIPage} - -// This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive -private[ui] case class ExecutorSummaryInfo( - id: String, - hostPort: String, - rddBlocks: Int, - memoryUsed: Long, - diskUsed: Long, - activeTasks: Int, - failedTasks: Int, - completedTasks: Int, - totalTasks: Int, - totalDuration: Long, - totalInputBytes: Long, - totalShuffleRead: Long, - totalShuffleWrite: Long, - isBlacklisted: Int, - maxOnHeapMem: Long, - maxOffHeapMem: Long, - executorLogs: Map[String, String]) - - -private[ui] class ExecutorsPage( - parent: ExecutorsTab, - threadDumpEnabled: Boolean) - extends WebUIPage("") { - - def render(request: HttpServletRequest): Seq[Node] = { - val content = -
    - { -
    - - - Show Additional Metrics - - -
    ++ -
    ++ - ++ - ++ - - } -
    - - UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) - } -} - -private[spark] object ExecutorsPage { - private val ON_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for on heap " + - "storage of data like RDD partitions cached in memory." - private val OFF_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for off heap " + - "storage of data like RDD partitions cached in memory." - - /** Represent an executor's info as a map given a storage status index */ - def getExecInfo( - listener: ExecutorsListener, - statusId: Int, - isActive: Boolean): ExecutorSummary = { - val status = if (isActive) { - listener.activeStorageStatusList(statusId) - } else { - listener.deadStorageStatusList(statusId) - } - val execId = status.blockManagerId.executorId - val hostPort = status.blockManagerId.hostPort - val rddBlocks = status.numBlocks - val memUsed = status.memUsed - val maxMem = status.maxMem - val memoryMetrics = for { - onHeapUsed <- status.onHeapMemUsed - offHeapUsed <- status.offHeapMemUsed - maxOnHeap <- status.maxOnHeapMem - maxOffHeap <- status.maxOffHeapMem - } yield { - new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) - } - - - val diskUsed = status.diskUsed - val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) - - new ExecutorSummary( - execId, - hostPort, - isActive, - rddBlocks, - memUsed, - diskUsed, - taskSummary.totalCores, - taskSummary.tasksMax, - taskSummary.tasksActive, - taskSummary.tasksFailed, - taskSummary.tasksComplete, - taskSummary.tasksActive + taskSummary.tasksFailed + taskSummary.tasksComplete, - taskSummary.duration, - taskSummary.jvmGCTime, - taskSummary.inputBytes, - taskSummary.shuffleRead, - taskSummary.shuffleWrite, - taskSummary.isBlacklisted, - maxMem, - taskSummary.executorLogs, - memoryMetrics - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 64a1a292a3840..843486f4a70d2 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -17,192 +17,44 @@ package org.apache.spark.ui.exec -import scala.collection.mutable.{LinkedHashMap, ListBuffer} +import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Resubmitted, SparkConf, SparkContext} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.storage.{StorageStatus, StorageStatusListener} -import org.apache.spark.ui.{SparkUI, SparkUITab} +import scala.xml.Node -private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { - val listener = parent.executorsListener - val sc = parent.sc - val threadDumpEnabled = - sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - - attachPage(new ExecutorsPage(this, threadDumpEnabled)) - if (threadDumpEnabled) { - attachPage(new ExecutorThreadDumpPage(this)) - } -} - -private[ui] case class ExecutorTaskSummary( - var executorId: String, - var totalCores: Int = 0, - var tasksMax: Int = 0, - var tasksActive: Int = 0, - var tasksFailed: Int = 0, - var tasksComplete: Int = 0, - var duration: Long = 0L, - var jvmGCTime: Long = 0L, - var inputBytes: Long = 0L, - var inputRecords: Long = 0L, - var outputBytes: Long = 0L, - var outputRecords: Long = 0L, - var shuffleRead: Long = 0L, - var shuffleWrite: Long = 0L, - var executorLogs: Map[String, String] = 
Map.empty, - var isAlive: Boolean = true, - var isBlacklisted: Boolean = false -) - -/** - * :: DeveloperApi :: - * A SparkListener that prepares information to be displayed on the ExecutorsTab - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) - extends SparkListener { - val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() - var executorEvents = new ListBuffer[SparkListenerEvent]() - - private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) - private val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100) - - def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils, WebUIPage} - def deadStorageStatusList: Seq[StorageStatus] = storageStatusListener.deadStorageStatusList - - override def onExecutorAdded( - executorAdded: SparkListenerExecutorAdded): Unit = synchronized { - val eid = executorAdded.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - taskSummary.executorLogs = executorAdded.executorInfo.logUrlMap - taskSummary.totalCores = executorAdded.executorInfo.totalCores - taskSummary.tasksMax = taskSummary.totalCores / conf.getInt("spark.task.cpus", 1) - executorEvents += executorAdded - if (executorEvents.size > maxTimelineExecutors) { - executorEvents.remove(0) - } - - val deadExecutors = executorToTaskSummary.filter(e => !e._2.isAlive) - if (deadExecutors.size > retainedDeadExecutors) { - val head = deadExecutors.head - executorToTaskSummary.remove(head._1) - } - } - - override def onExecutorRemoved( - executorRemoved: SparkListenerExecutorRemoved): Unit = synchronized { - executorEvents += executorRemoved - if (executorEvents.size > maxTimelineExecutors) { - executorEvents.remove(0) - } - executorToTaskSummary.get(executorRemoved.executorId).foreach(e => e.isAlive = false) - } - - override def onApplicationStart( - applicationStart: SparkListenerApplicationStart): Unit = { - applicationStart.driverLogs.foreach { logs => - val storageStatus = activeStorageStatusList.find { s => - s.blockManagerId.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER || - s.blockManagerId.executorId == SparkContext.DRIVER_IDENTIFIER - } - storageStatus.foreach { s => - val eid = s.blockManagerId.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - taskSummary.executorLogs = logs.toMap - } - } - } - - override def onTaskStart( - taskStart: SparkListenerTaskStart): Unit = synchronized { - val eid = taskStart.taskInfo.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - taskSummary.tasksActive += 1 - } +private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { - override def onTaskEnd( - taskEnd: SparkListenerTaskEnd): Unit = synchronized { - val info = taskEnd.taskInfo - if (info != null) { - val eid = info.executorId - val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - // Note: For resubmitted tasks, we continue to use the metrics that belong to the - // first attempt of this task. This may not be 100% accurate because the first attempt - // could have failed half-way through. The correct fix would be to keep track of the - // metrics added by each attempt, but this is much more complicated. 
- if (taskEnd.reason == Resubmitted) { - return - } - if (info.successful) { - taskSummary.tasksComplete += 1 - } else { - taskSummary.tasksFailed += 1 - } - if (taskSummary.tasksActive >= 1) { - taskSummary.tasksActive -= 1 - } - taskSummary.duration += info.duration + init() - // Update shuffle read/write - val metrics = taskEnd.taskMetrics - if (metrics != null) { - taskSummary.inputBytes += metrics.inputMetrics.bytesRead - taskSummary.inputRecords += metrics.inputMetrics.recordsRead - taskSummary.outputBytes += metrics.outputMetrics.bytesWritten - taskSummary.outputRecords += metrics.outputMetrics.recordsWritten + private def init(): Unit = { + val threadDumpEnabled = + parent.sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - taskSummary.shuffleRead += metrics.shuffleReadMetrics.remoteBytesRead - taskSummary.shuffleWrite += metrics.shuffleWriteMetrics.bytesWritten - taskSummary.jvmGCTime += metrics.jvmGCTime - } + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this, parent.sc)) } } - private def updateExecutorBlacklist( - eid: String, - isBlacklisted: Boolean): Unit = { - val execTaskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) - execTaskSummary.isBlacklisted = isBlacklisted - } - - override def onExecutorBlacklisted( - executorBlacklisted: SparkListenerExecutorBlacklisted) - : Unit = synchronized { - updateExecutorBlacklist(executorBlacklisted.executorId, true) - } - - override def onExecutorUnblacklisted( - executorUnblacklisted: SparkListenerExecutorUnblacklisted) - : Unit = synchronized { - updateExecutorBlacklist(executorUnblacklisted.executorId, false) - } - - override def onNodeBlacklisted( - nodeBlacklisted: SparkListenerNodeBlacklisted) - : Unit = synchronized { - // Implicitly blacklist every executor associated with this node, and show this in the UI. - activeStorageStatusList.foreach { status => - if (status.blockManagerId.host == nodeBlacklisted.hostId) { - updateExecutorBlacklist(status.blockManagerId.executorId, true) - } - } - } +} - override def onNodeUnblacklisted( - nodeUnblacklisted: SparkListenerNodeUnblacklisted) - : Unit = synchronized { - // Implicitly unblacklist every executor associated with this node, regardless of how - // they may have been blacklisted initially (either explicitly through executor blacklisting - // or implicitly through node blacklisting). Show this in the UI. - activeStorageStatusList.foreach { status => - if (status.blockManagerId.host == nodeUnblacklisted.hostId) { - updateExecutorBlacklist(status.blockManagerId.executorId, false) - } - } +private[ui] class ExecutorsPage( + parent: SparkUITab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { + + def render(request: HttpServletRequest): Seq[Node] = { + val content = +
    + { +
    ++ + ++ + ++ + + } +
    + + UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index a7f2caafe04b8..a647a1173a8cb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1.ExecutorSummary import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData.{JobUIData, StageUIData} import org.apache.spark.util.Utils @@ -123,55 +124,53 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } } - private def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]): + private def makeExecutorEvent(executors: Seq[ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() - executorUIDatas.foreach { - case a: SparkListenerExecutorAdded => - val addedEvent = - s""" - |{ - | 'className': 'executor added', - | 'group': 'executors', - | 'start': new Date(${a.time}), - | 'content': '
    Executor ${a.executorId} added
    ' - |} - """.stripMargin - events += addedEvent - case e: SparkListenerExecutorRemoved => + executors.foreach { e => + val addedEvent = + s""" + |{ + | 'className': 'executor added', + | 'group': 'executors', + | 'start': new Date(${e.addTime.getTime()}), + | 'content': '
    Executor ${e.id} added
    ' + |} + """.stripMargin + events += addedEvent + + e.removeTime.foreach { removeTime => val removedEvent = s""" |{ | 'className': 'executor removed', | 'group': 'executors', - | 'start': new Date(${e.time}), + | 'start': new Date(${removeTime.getTime()}), | 'content': '
    ' + + | 'Removed at ${UIUtils.formatDate(removeTime)}' + | '${ - if (e.reason != null) { - s"""
    Reason: ${e.reason.replace("\n", " ")}""" - } else { - "" - } + e.removeReason.map { reason => + s"""
    Reason: ${reason.replace("\n", " ")}""" + }.getOrElse("") }"' + - | 'data-html="true">Executor ${e.executorId} removed
    ' + | 'data-html="true">Executor ${e.id} removed' |} """.stripMargin events += removedEvent - + } } events.toSeq } private def makeTimeline( jobs: Seq[JobUIData], - executors: Seq[SparkListenerEvent], + executors: Seq[ExecutorSummary], startTime: Long): Seq[Node] = { val jobEventJsonAsStrSeq = makeJobEvent(jobs) @@ -360,9 +359,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { var content = summary - val executorListener = parent.executorListener content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, - executorListener.executorEvents, startTime) + parent.parent.store.executorList(false), startTime) if (shouldShowActiveJobs) { content ++=

    Active Jobs ({activeJobs.size})

    ++ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 382a6f979f2e6..07a41d195a191 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -20,12 +20,17 @@ package org.apache.spark.ui.jobs import scala.collection.mutable import scala.xml.{Node, Unparsed} +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: StagesTab) { +private[ui] class ExecutorTable( + stageId: Int, + stageAttemptId: Int, + parent: StagesTab, + store: AppStatusStore) { private val listener = parent.progressListener def toNodeSeq: Seq[Node] = { @@ -123,9 +128,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
    {k}
    { - val logs = parent.executorsListener.executorToTaskSummary.get(k) - .map(_.executorLogs).getOrElse(Map.empty) - logs.map { + store.executorSummary(k).map(_.executorLogs).getOrElse(Map.empty).map { case (logName, logUrl) => } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 9fb011a049b7e..7ed01646f3621 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1.ExecutorSummary import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} /** Page showing statistics and stage list for a given job */ @@ -92,55 +93,52 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } } - def makeExecutorEvent(executorUIDatas: Seq[SparkListenerEvent]): Seq[String] = { + def makeExecutorEvent(executors: Seq[ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() - executorUIDatas.foreach { - case a: SparkListenerExecutorAdded => - val addedEvent = - s""" - |{ - | 'className': 'executor added', - | 'group': 'executors', - | 'start': new Date(${a.time}), - | 'content': '
    Executor ${a.executorId} added
    ' - |} - """.stripMargin - events += addedEvent + executors.foreach { e => + val addedEvent = + s""" + |{ + | 'className': 'executor added', + | 'group': 'executors', + | 'start': new Date(${e.addTime.getTime()}), + | 'content': '
    Executor ${e.id} added
    ' + |} + """.stripMargin + events += addedEvent - case e: SparkListenerExecutorRemoved => + e.removeTime.foreach { removeTime => val removedEvent = s""" |{ | 'className': 'executor removed', | 'group': 'executors', - | 'start': new Date(${e.time}), + | 'start': new Date(${removeTime.getTime()}), | 'content': '
    ' + + | 'Removed at ${UIUtils.formatDate(removeTime)}' + | '${ - if (e.reason != null) { - s"""
    Reason: ${e.reason.replace("\n", " ")}""" - } else { - "" - } + e.removeReason.map { reason => + s"""
    Reason: ${reason.replace("\n", " ")}""" + }.getOrElse("") }"' + - | 'data-html="true">Executor ${e.executorId} removed
    ' + | 'data-html="true">Executor ${e.id} removed
    ' |} """.stripMargin events += removedEvent - + } } events.toSeq } private def makeTimeline( stages: Seq[StageInfo], - executors: Seq[SparkListenerEvent], + executors: Seq[ExecutorSummary], appStartTime: Long): Seq[Node] = { val stageEventJsonAsStrSeq = makeStageEvent(stages) @@ -322,11 +320,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { var content = summary val appStartTime = listener.startTime - val executorListener = parent.executorListener val operationGraphListener = parent.operationGraphListener content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, - executorListener.executorEvents, appStartTime) + parent.parent.store.executorList(false), appStartTime) content ++= UIUtils.showDagVizForJob( jobId, operationGraphListener.getOperationGraphForJob(jobId)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index cc173381879a6..81ffe04aca49a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -23,11 +23,10 @@ import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { +private[ui] class JobsTab(val parent: SparkUI) extends SparkUITab(parent, "jobs") { val sc = parent.sc val killEnabled = parent.killEnabled val jobProgresslistener = parent.jobProgressListener - val executorListener = parent.executorsListener val operationGraphListener = parent.operationGraphListener def isFairScheduler: Boolean = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 4d80308eb0a6d..3151b8d554658 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -29,18 +29,17 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui._ -import org.apache.spark.ui.exec.ExecutorsListener import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ -private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { import StagePage._ private val progressListener = parent.progressListener private val operationGraphListener = parent.operationGraphListener - private val executorsListener = parent.executorsListener private val TIMELINE_LEGEND = {
    @@ -304,7 +303,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { pageSize = taskPageSize, sortColumn = taskSortColumn, desc = taskSortDesc, - executorsListener = executorsListener + store = store ) (_taskTable, _taskTable.table(page)) } catch { @@ -563,7 +562,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { stripeRowsWithCss = false)) } - val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) + val executorTable = new ExecutorTable(stageId, stageAttemptId, parent, store) val maybeAccumulableTable: Seq[Node] = if (hasAccumulators) {

    Accumulators

    ++ accumulableTable } else Seq.empty @@ -869,7 +868,7 @@ private[ui] class TaskDataSource( pageSize: Int, sortColumn: String, desc: Boolean, - executorsListener: ExecutorsListener) extends PagedDataSource[TaskTableRowData](pageSize) { + store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { import StagePage._ // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table @@ -1012,8 +1011,7 @@ private[ui] class TaskDataSource( None } - val logs = executorsListener.executorToTaskSummary.get(info.executorId) - .map(_.executorLogs).getOrElse(Map.empty) + val logs = store.executorSummary(info.executorId).map(_.executorLogs).getOrElse(Map.empty) new TaskTableRowData( info.index, info.taskId, @@ -1162,7 +1160,7 @@ private[ui] class TaskPagedTable( pageSize: Int, sortColumn: String, desc: Boolean, - executorsListener: ExecutorsListener) extends PagedTable[TaskTableRowData] { + store: AppStatusStore) extends PagedTable[TaskTableRowData] { override def tableId: String = "task-table" @@ -1188,7 +1186,7 @@ private[ui] class TaskPagedTable( pageSize, sortColumn, desc, - executorsListener) + store) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 0787ea6625903..65446f967ad76 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -20,20 +20,22 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all stages in the given SparkContext. 
*/ -private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { +private[ui] class StagesTab(val parent: SparkUI, store: AppStatusStore) + extends SparkUITab(parent, "stages") { + val sc = parent.sc val conf = parent.conf val killEnabled = parent.killEnabled val progressListener = parent.jobProgressListener val operationGraphListener = parent.operationGraphListener - val executorsListener = parent.executorsListener val lastUpdateTime = parent.lastUpdateTime attachPage(new AllStagesPage(this)) - attachPage(new StagePage(this)) + attachPage(new StagePage(this, store)) attachPage(new PoolPage(this)) def isFairScheduler: Boolean = progressListener.schedulingMode == Some(SchedulingMode.FAIR) diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 6b9f29e1a230e..942e6d8f04363 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -18,5 +18,6 @@ "totalShuffleWrite" : 13180, "isBlacklisted" : false, "maxMemory" : 278302556, + "addTime" : "2015-02-03T16:43:00.906GMT", "executorLogs" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index 0f94e3b255dbc..ed33c90dd39ba 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -1,6 +1,34 @@ [ { - "id" : "2", - "hostPort" : "172.22.0.167:51487", + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:31.477GMT", + "executorLogs" : { }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 + } +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, @@ -8,52 +36,57 @@ "totalCores" : 4, "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 4, - "completedTasks" : 0, - "totalTasks" : 4, - "totalDuration" : 2537, - "totalGCTime" : 88, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.320GMT", "executorLogs" : { - "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", - "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, "memoryMetrics": { - "usedOnHeapStorageMemory": 
0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } -}, { - "id" : "driver", - "hostPort" : "172.22.0.167:51475", +} ,{ + "id" : "2", + "hostPort" : "172.22.0.167:51487", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, "diskUsed" : 0, - "totalCores" : 0, - "maxTasks" : 0, + "totalCores" : 4, + "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 0, + "failedTasks" : 4, "completedTasks" : 0, - "totalTasks" : 0, - "totalDuration" : 0, - "totalGCTime" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, - "executorLogs" : { }, + "addTime" : "2016-11-16T22:33:35.393GMT", + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } }, { "id" : "1", @@ -75,6 +108,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.443GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" @@ -105,44 +139,15 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.462GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 - } -}, { - "id" : "3", - "hostPort" : "172.22.0.167:51485", - "isActive" : true, - "rddBlocks" : 0, - "memoryUsed" : 0, - "diskUsed" : 0, - "totalCores" : 4, - "maxTasks" : 4, - "activeTasks" : 0, - "failedTasks" : 0, - "completedTasks" : 12, - "totalTasks" : 12, - "totalDuration" : 2453, - "totalGCTime" : 72, - "totalInputBytes" : 0, - "totalShuffleRead" : 0, - "totalShuffleWrite" : 0, - "isBlacklisted" : true, - "maxMemory" : 908381388, - "executorLogs" : { - "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", - "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" - }, - "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 0f94e3b255dbc..73519f1d9e2e4 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -1,6 +1,34 @@ [ { - "id" : "2", - "hostPort" : "172.22.0.167:51487", + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:31.477GMT", + "executorLogs" : { }, + "memoryMetrics": { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 + } +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, @@ -8,52 +36,57 @@ "totalCores" : 4, "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 4, - "completedTasks" : 0, - "totalTasks" : 4, - "totalDuration" : 2537, - "totalGCTime" : 88, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.320GMT", "executorLogs" : { - "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", - "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } }, { - "id" : "driver", - "hostPort" : "172.22.0.167:51475", + "id" : "2", + "hostPort" : "172.22.0.167:51487", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, "diskUsed" : 0, - "totalCores" : 0, - "maxTasks" : 0, + "totalCores" : 4, + "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 0, + "failedTasks" : 4, "completedTasks" : 0, - "totalTasks" : 0, - "totalDuration" : 0, - "totalGCTime" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, - "executorLogs" : { }, + "addTime" : "2016-11-16T22:33:35.393GMT", + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 + "usedOnHeapStorageMemory" : 0, + 
"usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 524288000 } }, { "id" : "1", @@ -75,6 +108,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.443GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" @@ -105,6 +139,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : true, "maxMemory" : 908381388, + "addTime" : "2016-11-16T22:33:35.462GMT", "executorLogs" : { "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" @@ -115,34 +150,4 @@ "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 } -}, { - "id" : "3", - "hostPort" : "172.22.0.167:51485", - "isActive" : true, - "rddBlocks" : 0, - "memoryUsed" : 0, - "diskUsed" : 0, - "totalCores" : 4, - "maxTasks" : 4, - "activeTasks" : 0, - "failedTasks" : 0, - "completedTasks" : 12, - "totalTasks" : 12, - "totalDuration" : 2453, - "totalGCTime" : 72, - "totalInputBytes" : 0, - "totalShuffleRead" : 0, - "totalShuffleWrite" : 0, - "isBlacklisted" : true, - "maxMemory" : 908381388, - "executorLogs" : { - "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", - "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" - }, - "memoryMetrics": { - "usedOnHeapStorageMemory": 0, - "usedOffHeapStorageMemory": 0, - "totalOnHeapStorageMemory": 384093388, - "totalOffHeapStorageMemory": 524288000 - } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 92e249c851116..6931fead3d2ff 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -1,6 +1,28 @@ [ { - "id" : "2", - "hostPort" : "172.22.0.111:64539", + "id" : "driver", + "hostPort" : "172.22.0.111:64527", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:38.836GMT", + "executorLogs" : { } +}, { + "id" : "3", + "hostPort" : "172.22.0.111:64543", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, @@ -8,41 +30,46 @@ "totalCores" : 4, "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 6, - "completedTasks" : 0, - "totalTasks" : 6, - "totalDuration" : 2792, - "totalGCTime" : 128, + "failedTasks" : 0, + "completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3457, + "totalGCTime" : 72, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:42.711GMT", "executorLogs" : { - "stdout" : 
"http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", - "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" } }, { - "id" : "driver", - "hostPort" : "172.22.0.111:64527", + "id" : "2", + "hostPort" : "172.22.0.111:64539", "isActive" : true, "rddBlocks" : 0, "memoryUsed" : 0, "diskUsed" : 0, - "totalCores" : 0, - "maxTasks" : 0, + "totalCores" : 4, + "maxTasks" : 4, "activeTasks" : 0, - "failedTasks" : 0, + "failedTasks" : 6, "completedTasks" : 0, - "totalTasks" : 0, - "totalDuration" : 0, - "totalGCTime" : 0, + "totalTasks" : 6, + "totalDuration" : 2792, + "totalGCTime" : 128, "totalInputBytes" : 0, "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, - "executorLogs" : { } + "addTime" : "2016-11-15T23:20:42.589GMT", + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + } }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -63,6 +90,7 @@ "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:42.629GMT", "executorLogs" : { "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" @@ -87,32 +115,9 @@ "totalShuffleWrite" : 0, "isBlacklisted" : false, "maxMemory" : 384093388, + "addTime" : "2016-11-15T23:20:42.593GMT", "executorLogs" : { "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" } -}, { - "id" : "3", - "hostPort" : "172.22.0.111:64543", - "isActive" : true, - "rddBlocks" : 0, - "memoryUsed" : 0, - "diskUsed" : 0, - "totalCores" : 4, - "maxTasks" : 4, - "activeTasks" : 0, - "failedTasks" : 0, - "completedTasks" : 4, - "totalTasks" : 4, - "totalDuration" : 3457, - "totalGCTime" : 72, - "totalInputBytes" : 0, - "totalShuffleRead" : 0, - "totalShuffleWrite" : 0, - "isBlacklisted" : false, - "maxMemory" : 384093388, - "executorLogs" : { - "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", - "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" - } } ] diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 499d47b13d702..1c51c148ae61b 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -22,13 +22,13 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node +import org.mockito.Matchers.anyString import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.storage.StorageStatusListener -import org.apache.spark.ui.exec.ExecutorsListener +import 
org.apache.spark.status.AppStatusStore import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener @@ -55,20 +55,21 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * This also runs a dummy stage to populate the page with useful content. */ private def renderStagePage(conf: SparkConf): Seq[Node] = { + val store = mock(classOf[AppStatusStore]) + when(store.executorSummary(anyString())).thenReturn(None) + val jobListener = new JobProgressListener(conf) val graphListener = new RDDOperationGraphListener(conf) - val executorsListener = new ExecutorsListener(new StorageStatusListener(conf), conf) val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) val request = mock(classOf[HttpServletRequest]) when(tab.conf).thenReturn(conf) when(tab.progressListener).thenReturn(jobListener) when(tab.operationGraphListener).thenReturn(graphListener) - when(tab.executorsListener).thenReturn(executorsListener) when(tab.appName).thenReturn("testing") when(tab.headerTabs).thenReturn(Seq.empty) when(request.getParameter("id")).thenReturn("0") when(request.getParameter("attempt")).thenReturn("0") - val page = new StagePage(tab) + val page = new StagePage(tab, store) // Simulate a stage in job progress listener val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 62930e2e3a931..0c31b2b4a9402 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -40,6 +40,7 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 51debf8b1f4d479bc7f81e2759ba28e526367d70 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 8 Nov 2017 10:24:40 +0000 Subject: [PATCH 198/712] [SPARK-14540][BUILD] Support Scala 2.12 closures and Java 8 lambdas in ClosureCleaner (step 0) ## What changes were proposed in this pull request? Preliminary changes to get ClosureCleaner to work with Scala 2.12. Makes many usages just work, but not all. This does _not_ resolve the JIRA. ## How was this patch tested? Existing tests Author: Sean Owen Closes #19675 from srowen/SPARK-14540.0. --- .../apache/spark/util/ClosureCleaner.scala | 28 +++++++++++-------- .../spark/util/ClosureCleanerSuite2.scala | 10 +++++-- .../spark/graphx/lib/ShortestPaths.scala | 2 +- .../streaming/BasicOperationsSuite.scala | 4 +-- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index dfece5dd0670b..40616421b5bca 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -38,12 +38,13 @@ private[spark] object ClosureCleaner extends Logging { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. 
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + if (resourceStream == null) { + null + } else { + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + } } // Check whether a class represents a Scala closure @@ -81,11 +82,13 @@ private[spark] object ClosureCleaner extends Logging { val stack = Stack[Class[_]](obj.getClass) while (!stack.isEmpty) { val cr = getClassReader(stack.pop()) - val set = Set.empty[Class[_]] - cr.accept(new InnerClosureFinder(set), 0) - for (cls <- set -- seen) { - seen += cls - stack.push(cls) + if (cr != null) { + val set = Set.empty[Class[_]] + cr.accept(new InnerClosureFinder(set), 0) + for (cls <- set -- seen) { + seen += cls + stack.push(cls) + } } } (seen - obj.getClass).toList @@ -366,7 +369,8 @@ private[spark] class ReturnStatementInClosureException private class ReturnStatementFinder extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - if (name.contains("apply")) { + // $anonfun$ covers Java 8 lambdas + if (name.contains("apply") || name.contains("$anonfun$")) { new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 934385fbcad1b..278fada83d78c 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -117,9 +117,13 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri findTransitively: Boolean): Map[Class[_], Set[String]] = { val fields = new mutable.HashMap[Class[_], mutable.Set[String]] outerClasses.foreach { c => fields(c) = new mutable.HashSet[String] } - ClosureCleaner.getClassReader(closure.getClass) - .accept(new FieldAccessFinder(fields, findTransitively), 0) - fields.mapValues(_.toSet).toMap + val cr = ClosureCleaner.getClassReader(closure.getClass) + if (cr == null) { + Map.empty + } else { + cr.accept(new FieldAccessFinder(fields, findTransitively), 0) + fields.mapValues(_.toSet).toMap + } } // Accessors for private methods diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 4cac633aed008..aff0b932e9429 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -25,7 +25,7 @@ import org.apache.spark.graphx._ * Computes shortest paths to the given set of landmark vertices, returning a graph where each * vertex attribute is a map containing the shortest-path distance to each reachable landmark. */ -object ShortestPaths { +object ShortestPaths extends Serializable { /** Stores a map from the vertex id of a landmark to the distance to that landmark. 
*/ type SPMap = Map[VertexId, Int] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 6f62c7a88dc3c..0a764f61c0cd9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -596,8 +596,6 @@ class BasicOperationsSuite extends TestSuiteBase { ) val updateStateOperation = (s: DStream[String]) => { - class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable - // updateFunc clears a state when a StateObject is seen without new values twice in a row val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { val stateObj = state.getOrElse(new StateObject) @@ -817,3 +815,5 @@ class BasicOperationsSuite extends TestSuiteBase { } } } + +class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable From 87343e15566da870d1b8e49a78ef16e08ddfd406 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 8 Nov 2017 12:17:52 +0100 Subject: [PATCH 199/712] [SPARK-22446][SQL][ML] Declare StringIndexerModel indexer udf as nondeterministic ## What changes were proposed in this pull request? UDFs that can throw a runtime exception on invalid data are not safe to push down, because their behavior depends on their position in the query plan; pushing them down risks changing their original behavior. The example reported in the JIRA and taken as a test case shows this issue. We should declare UDFs that can throw a runtime exception on invalid data as non-deterministic. This updates the documentation of the `deterministic` property in `Expression` and states clearly that a UDF that can throw a runtime exception on some specific input should be declared as non-deterministic. ## How was this patch tested? Added tests. Manual tests. Author: Liang-Chi Hsieh Closes #19662 from viirya/SPARK-22446. --- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../spark/ml/feature/StringIndexerSuite.scala | 17 ++++++++++++++ .../ml/feature/VectorAssemblerSuite.scala | 23 ++++++++++++++++++- .../sql/catalyst/expressions/Expression.scala | 1 + 5 files changed, 42 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 2679ec310c470..1cdcdfcaeab78 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -261,7 +261,7 @@ class StringIndexerModel ( s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") } } - } + }.asNondeterministic() filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 73f27d1a423d9..b373ae921ed38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -97,7 +97,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) // Data transformation.
val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) - } + }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { case DoubleType => dataset(c) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 027b1fbc6657c..775a04d3df050 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -314,4 +314,21 @@ class StringIndexerSuite idx += 1 } } + + test("SPARK-22446: StringIndexerModel's indexer UDF should not apply on filtered data") { + val df = List( + ("A", "London", "StrA"), + ("B", "Bristol", null), + ("C", "New York", "StrC")).toDF("ID", "CITY", "CONTENT") + + val dfNoBristol = df.filter($"CONTENT".isNotNull) + + val model = new StringIndexer() + .setInputCol("CITY") + .setOutputCol("CITYIndexed") + .fit(dfNoBristol) + + val dfWithIndex = model.transform(dfNoBristol) + assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 6aef1c6837025..eca065f7e775d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -126,4 +126,25 @@ class VectorAssemblerSuite .setOutputCol("myOutputCol") testDefaultReadWrite(t) } + + test("SPARK-22446: VectorAssembler's UDF should not apply on filtered data") { + val df = Seq( + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L), + (0, 1.0, null, "b", null, 20L) + ).toDF("id", "x", "y", "name", "z", "n") + + val assembler = new VectorAssembler() + .setInputCols(Array("x", "z", "n")) + .setOutputCol("features") + + val filteredDF = df.filter($"y".isNotNull) + + val vectorUDF = udf { vector: Vector => + vector.numActives + } + + assert(assembler.transform(filteredDF).select("features") + .filter(vectorUDF($"features") > 1) + .count() == 1) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0e75ac88dc2b8..a3b722a47d688 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -75,6 +75,7 @@ abstract class Expression extends TreeNode[Expression] { * - it relies on some mutable internal state, or * - it relies on some implicit input that is not part of the children expression list. * - it has non-deterministic child or children. + * - it assumes the input satisfies some certain condition via the child operator. * * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. 
* By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. From 6447d7bc1de4ab1d99a8dfcd3fea07f5a2da363d Mon Sep 17 00:00:00 2001 From: "Li, YanKit | Wilson | RIT" Date: Wed, 8 Nov 2017 17:55:21 +0000 Subject: [PATCH 200/712] [SPARK-22133][DOCS] Documentation for Mesos Reject Offer Configurations ## What changes were proposed in this pull request? Documentation about Mesos Reject Offer Configurations ## Related PR https://github.com/apache/spark/pull/19510 for `spark.mem.max` Author: Li, YanKit | Wilson | RIT Closes #19555 from windkit/spark_22133. --- docs/running-on-mesos.md | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index b7e3e6473c338..7a443ffddc5f0 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -203,7 +203,7 @@ details and default values. Executors are brought up eagerly when the application starts, until `spark.cores.max` is reached. If you don't set `spark.cores.max`, the -Spark application will reserve all resources offered to it by Mesos, +Spark application will consume all resources offered to it by Mesos, so we of course urge you to set this variable in any sort of multi-tenant cluster, including one which runs multiple concurrent Spark applications. @@ -680,6 +680,30 @@ See the [configuration page](configuration.html) for information on Spark config driver disconnects, the master immediately tears down the framework. + + spark.mesos.rejectOfferDuration + 120s + + Time to consider unused resources refused, serves as a fallback of + `spark.mesos.rejectOfferDurationForUnmetConstraints`, + `spark.mesos.rejectOfferDurationForReachedMaxCores` + + + + spark.mesos.rejectOfferDurationForUnmetConstraints + spark.mesos.rejectOfferDuration + + Time to consider unused resources refused with unmet constraints + + + + spark.mesos.rejectOfferDurationForReachedMaxCores + spark.mesos.rejectOfferDuration + + Time to consider unused resources refused when maximum number of cores + spark.cores.max is reached + + # Troubleshooting and Debugging From ee571d79e52dc7e0ad7ae80619c12a5c0b90b5a5 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 Nov 2017 14:33:08 +0900 Subject: [PATCH 201/712] [SPARK-22466][SPARK SUBMIT] export SPARK_CONF_DIR while conf is default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? We use SPARK_CONF_DIR to switch spark conf directory and can be visited if we explicitly export it in spark-env.sh, but with default settings, it can't be done. This PR export SPARK_CONF_DIR while it is default. ### Before ``` KentKentsMacBookPro  ~/Documents/spark-packages/spark-2.3.0-SNAPSHOT-bin-master  bin/spark-shell --master local Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). 17/11/08 10:28:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 17/11/08 10:28:45 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041. Spark context Web UI available at http://169.254.168.63:4041 Spark context available as 'sc' (master = local, app id = local-1510108125770). Spark session available as 'spark'. 
Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.3.0-SNAPSHOT /_/ Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_65) Type in expressions to have them evaluated. Type :help for more information. scala> sys.env.get("SPARK_CONF_DIR") res0: Option[String] = None ``` ### After ``` scala> sys.env.get("SPARK_CONF_DIR") res0: Option[String] = Some(/Users/Kent/Documents/spark/conf) ``` ## How was this patch tested? vanzin Author: Kent Yao Closes #19688 from yaooqinn/SPARK-22466. --- bin/load-spark-env.cmd | 12 +++++------- bin/load-spark-env.sh | 9 +++------ conf/spark-env.sh.template | 3 ++- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index f946197b02d55..cefa513b6fb77 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -19,15 +19,13 @@ rem rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's -rem conf/ subdirectory. +rem conf\ subdirectory. if [%SPARK_ENV_LOADED%] == [] ( set SPARK_ENV_LOADED=1 - if not [%SPARK_CONF_DIR%] == [] ( - set user_conf_dir=%SPARK_CONF_DIR% - ) else ( - set user_conf_dir=..\conf + if [%SPARK_CONF_DIR%] == [] ( + set SPARK_CONF_DIR=%~dp0..\conf ) call :LoadSparkEnv @@ -54,6 +52,6 @@ if [%SPARK_SCALA_VERSION%] == [] ( exit /b 0 :LoadSparkEnv -if exist "%user_conf_dir%\spark-env.cmd" ( - call "%user_conf_dir%\spark-env.cmd" +if exist "%SPARK_CONF_DIR%\spark-env.cmd" ( + call "%SPARK_CONF_DIR%\spark-env.cmd" ) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index d05d94e68c81b..0b5006dbd63ac 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -29,15 +29,12 @@ fi if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 - # Returns the parent of the directory this script lives in. - parent_dir="${SPARK_HOME}" + export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}"/conf}" - user_conf_dir="${SPARK_CONF_DIR:-"$parent_dir"/conf}" - - if [ -f "${user_conf_dir}/spark-env.sh" ]; then + if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then # Promote all variable declarations to environment (exported) variables set -a - . "${user_conf_dir}/spark-env.sh" + . "${SPARK_CONF_DIR}/spark-env.sh" set +a fi fi diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index f8c895f5303b9..bc92c78f0f8f3 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -32,7 +32,8 @@ # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data # - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos -# Options read in YARN client mode +# Options read in YARN client/cluster mode +# - SPARK_CONF_DIR, Alternate conf dir. (Default: ${SPARK_HOME}/conf) # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - YARN_CONF_DIR, to point Spark towards YARN configuration files when you use YARN # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). From d01044233c7065a20ef7f8e7ab052b88eb34eaa5 Mon Sep 17 00:00:00 2001 From: ptkool Date: Thu, 9 Nov 2017 14:44:39 +0900 Subject: [PATCH 202/712] [SPARK-22456][SQL] Add support for dayofweek function ## What changes were proposed in this pull request? This PR adds support for a new function called `dayofweek` that returns the day of the week of the given argument as an integer value in the range 1-7, where 1 represents Sunday. 
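As an editorial illustration (not part of this patch), a minimal Scala sketch of the new function; the active SparkSession `spark` and the column name `dt` are assumed:

```scala
import org.apache.spark.sql.functions.{col, dayofweek}

// Assumes an active SparkSession `spark`. 2015-04-08 was a Wednesday,
// so this prints 4 (results range from 1 = Sunday to 7 = Saturday).
val df = spark.createDataFrame(Seq(Tuple1("2015-04-08"))).toDF("dt")
df.select(dayofweek(col("dt")).alias("day")).show()
```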
## How was this patch tested? Unit tests and manual tests. Author: ptkool Closes #19672 from ptkool/day_of_week_function. --- python/pyspark/sql/functions.py | 13 +++++++++++++ python/pyspark/sql/tests.py | 7 +++++++ .../main/scala/org/apache/spark/sql/functions.scala | 7 +++++++ 3 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 39815497f3956..087ce7caa89c8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -887,6 +887,19 @@ def month(col): return Column(sc._jvm.functions.month(_to_java_column(col))) +@since(2.3) +def dayofweek(col): + """ + Extract the day of the week of a given date as integer. + + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofweek('dt').alias('day')).collect() + [Row(day=4)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.dayofweek(_to_java_column(col))) + + @since(1.5) def dayofmonth(col): """ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb0d4e29a5978..4819f629c5310 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1765,6 +1765,13 @@ def test_datetime_at_epoch(self): self.assertEqual(first['date'], epoch) self.assertEqual(first['lit_date'], epoch) + def test_dayofweek(self): + from pyspark.sql.functions import dayofweek + dt = datetime.datetime(2017, 11, 6) + df = self.spark.createDataFrame([Row(date=dt)]) + row = df.select(dayofweek(df.date)).first() + self.assertEqual(row[0], 2) + def test_decimal(self): from decimal import Decimal schema = StructType([StructField("decimal", DecimalType(10, 5))]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6bbdfa3ad1893..3e4659b9eae60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2601,6 +2601,13 @@ object functions { */ def month(e: Column): Column = withExpr { Month(e.expr) } + /** + * Extracts the day of the week as an integer from a given date/timestamp/string. + * @group datetime_funcs + * @since 2.3.0 + */ + def dayofweek(e: Column): Column = withExpr { DayOfWeek(e.expr) } + /** * Extracts the day of the month as an integer from a given date/timestamp/string. * @group datetime_funcs From 695647bf2ebda56f9effb7fcdd875490132ea012 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 9 Nov 2017 15:00:31 +0900 Subject: [PATCH 203/712] [SPARK-21640][SQL][PYTHON][R][FOLLOWUP] Add errorifexists in SparkR and other documentations ## What changes were proposed in this pull request? This PR proposes to add `errorifexists` to SparkR API and fix the rest of them describing the mode, mainly, in API documentations as well. This PR also replaces `convertToJSaveMode` to `setWriteMode` so that string as is is passed to JVM and executes: https://github.com/apache/spark/blob/b034f2565f72aa73c9f0be1e49d148bb4cf05153/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala#L72-L82 and remove the duplication here: https://github.com/apache/spark/blob/3f958a99921d149fb9fdf7ba7e78957afdad1405/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala#L187-L194 ## How was this patch tested? Manually checked the built documentation. These were mainly found by `` grep -r `error` `` and `grep -r 'error'`. Also, unit tests added in `test_sparkSQL.R`. 
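As an editorial illustration (not part of this patch), a minimal Scala sketch of the save-mode strings accepted by the JVM-side `DataFrameWriter`, which the R `mode` argument is now forwarded to as-is; the DataFrame `df` and the output path are hypothetical:

```scala
// Assumes an existing DataFrame `df`; the path below is hypothetical.
df.write.mode("errorifexists").parquet("/tmp/people")  // throws if the path already exists
df.write.mode("overwrite").parquet("/tmp/people")      // replaces any existing data
df.write.mode("ignore").parquet("/tmp/people")         // no-op if the path already exists
```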
Author: hyukjinkwon Closes #19673 from HyukjinKwon/SPARK-21640-followup. --- R/pkg/R/DataFrame.R | 79 +++++++++++-------- R/pkg/R/utils.R | 9 --- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++ R/pkg/tests/fulltests/test_utils.R | 8 -- python/pyspark/sql/readwriter.py | 25 +++--- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 9 --- 7 files changed, 71 insertions(+), 69 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 763c8d2548580..b8d732a485862 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -58,14 +58,23 @@ setMethod("initialize", "SparkDataFrame", function(.Object, sdf, isCached) { #' Set options/mode and then return the write object #' @noRd setWriteOptions <- function(write, path = NULL, mode = "error", ...) { - options <- varargsToStrEnv(...) - if (!is.null(path)) { - options[["path"]] <- path - } - jmode <- convertToJSaveMode(mode) - write <- callJMethod(write, "mode", jmode) - write <- callJMethod(write, "options", options) - write + options <- varargsToStrEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + write <- setWriteMode(write, mode) + write <- callJMethod(write, "options", options) + write +} + +#' Set mode and then return the write object +#' @noRd +setWriteMode <- function(write, mode) { + if (!is.character(mode)) { + stop("mode should be character or omitted. It is 'error' by default.") + } + write <- handledCallJMethod(write, "mode", mode) + write } #' @export @@ -556,9 +565,8 @@ setMethod("registerTempTable", setMethod("insertInto", signature(x = "SparkDataFrame", tableName = "character"), function(x, tableName, overwrite = FALSE) { - jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) write <- callJMethod(x@sdf, "write") - write <- callJMethod(write, "mode", jmode) + write <- setWriteMode(write, ifelse(overwrite, "overwrite", "append")) invisible(callJMethod(write, "insertInto", tableName)) }) @@ -810,7 +818,8 @@ setMethod("toJSON", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -841,7 +850,8 @@ setMethod("write.json", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -872,7 +882,8 @@ setMethod("write.orc", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. 
#' #' @family SparkDataFrame functions @@ -917,7 +928,8 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkDataFrame #' @param path The directory where the file is saved -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2871,18 +2883,19 @@ setMethod("except", #' Additionally, mode is used to specify the behavior of the save operation when data already #' exists in the data source. There are four modes: #' \itemize{ -#' \item append: Contents of this SparkDataFrame are expected to be appended to existing data. -#' \item overwrite: Existing data is expected to be overwritten by the contents of this +#' \item 'append': Contents of this SparkDataFrame are expected to be appended to existing data. +#' \item 'overwrite': Existing data is expected to be overwritten by the contents of this #' SparkDataFrame. -#' \item error: An exception is expected to be thrown. -#' \item ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' \item 'error' or 'errorifexists': An exception is expected to be thrown. +#' \item 'ignore': The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. #' } #' #' @param df a SparkDataFrame. #' @param path a name for the table. #' @param source a name for external data source. -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2940,17 +2953,18 @@ setMethod("saveDF", #' #' Additionally, mode is used to specify the behavior of the save operation when #' data already exists in the data source. There are four modes: \cr -#' append: Contents of this SparkDataFrame are expected to be appended to existing data. \cr -#' overwrite: Existing data is expected to be overwritten by the contents of this +#' 'append': Contents of this SparkDataFrame are expected to be appended to existing data. \cr +#' 'overwrite': Existing data is expected to be overwritten by the contents of this #' SparkDataFrame. \cr -#' error: An exception is expected to be thrown. \cr -#' ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' 'error' or 'errorifexists': An exception is expected to be thrown. \cr +#' 'ignore': The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. \cr #' #' @param df a SparkDataFrame. #' @param tableName a name for the table. #' @param source a name for external data source. -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional option(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2972,12 +2986,11 @@ setMethod("saveAsTable", if (is.null(source)) { source <- getDefaultSqlSource() } - jmode <- convertToJSaveMode(mode) options <- varargsToStrEnv(...) 
write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) - write <- callJMethod(write, "mode", jmode) + write <- setWriteMode(write, mode) write <- callJMethod(write, "options", options) invisible(callJMethod(write, "saveAsTable", tableName)) }) @@ -3544,18 +3557,19 @@ setMethod("histogram", #' Also, mode is used to specify the behavior of the save operation when #' data already exists in the data source. There are four modes: #' \itemize{ -#' \item append: Contents of this SparkDataFrame are expected to be appended to existing data. -#' \item overwrite: Existing data is expected to be overwritten by the contents of this +#' \item 'append': Contents of this SparkDataFrame are expected to be appended to existing data. +#' \item 'overwrite': Existing data is expected to be overwritten by the contents of this #' SparkDataFrame. -#' \item error: An exception is expected to be thrown. -#' \item ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' \item 'error' or 'errorifexists': An exception is expected to be thrown. +#' \item 'ignore': The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. #' } #' #' @param x a SparkDataFrame. #' @param url JDBC database url of the form \code{jdbc:subprotocol:subname}. #' @param tableName yhe name of the table in the external database. -#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' +#' save mode (it is 'error' by default) #' @param ... additional JDBC database connection properties. #' @family SparkDataFrame functions #' @rdname write.jdbc @@ -3572,10 +3586,9 @@ setMethod("histogram", setMethod("write.jdbc", signature(x = "SparkDataFrame", url = "character", tableName = "character"), function(x, url, tableName, mode = "error", ...) { - jmode <- convertToJSaveMode(mode) jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") - write <- callJMethod(write, "mode", jmode) + write <- setWriteMode(write, mode) invisible(handledCallJMethod(write, "jdbc", url, tableName, jprops)) }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 4b716995f2c46..fa4099231ca8d 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -736,15 +736,6 @@ splitString <- function(input) { Filter(nzchar, unlist(strsplit(input, ",|\\s"))) } -convertToJSaveMode <- function(mode) { - allModes <- c("append", "overwrite", "error", "ignore") - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') # nolint - } - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) - jmode -} - varargsToJProperties <- function(...) { pairs <- list(...) 
props <- newJObject("java.util.Properties") diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 0c8118a7c73f3..a0dbd475f78e6 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -630,6 +630,10 @@ test_that("read/write json files", { jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") write.df(df, jsonPath2, "json", mode = "overwrite") + # Test errorifexists + expect_error(write.df(df, jsonPath2, "json", mode = "errorifexists"), + "analysis error - path file:.*already exists") + # Test write.json jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") write.json(df, jsonPath3) @@ -1371,6 +1375,9 @@ test_that("test HiveContext", { expect_equal(count(df5), 3) unlink(parquetDataPath) + # Invalid mode + expect_error(saveAsTable(df, "parquetest", "parquet", mode = "abc", path = parquetDataPath), + "illegal argument - Unknown save mode: abc") unsetHiveContext() } }) @@ -3303,6 +3310,7 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume "Error in orc : analysis error - path file:.*already exists") expect_error(write.parquet(df, jsonPath), "Error in parquet : analysis error - path file:.*already exists") + expect_error(write.parquet(df, jsonPath, mode = 123), "mode should be character or omitted.") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index af81423aa8dd0..fb394b8069c1c 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -158,14 +158,6 @@ test_that("varargsToJProperties", { expect_equal(callJMethod(jprops, "size"), 0L) }) -test_that("convertToJSaveMode", { - s <- convertToJSaveMode("error") - expect_true(class(s) == "jobj") - expect_match(capture.output(print.jobj(s)), "Java ref type org.apache.spark.sql.SaveMode id ") - expect_error(convertToJSaveMode("foo"), - 'mode should be one of "append", "overwrite", "error", "ignore"') #nolint -}) - test_that("captureJVMException", { method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3d87567ab673d..a75bdf8078dd5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -540,7 +540,7 @@ def mode(self, saveMode): * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. + * `error` or `errorifexists`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) @@ -675,7 +675,8 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. 
:param partitionBy: names of partitioning columns :param options: all other string options @@ -713,12 +714,13 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. + * `error` or `errorifexists`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. :param name: the table name :param format: the format used to save - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: one of `append`, `overwrite`, `error`, `errorifexists`, `ignore` \ + (default: error) :param partitionBy: names of partitioning columns :param options: all other string options """ @@ -741,7 +743,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). @@ -771,7 +774,8 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, snappy, gzip, and lzo). @@ -814,7 +818,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, @@ -874,7 +879,8 @@ def orc(self, path, mode=None, partitionBy=None, compression=None): * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, snappy, zlib, and lzo). @@ -905,7 +911,8 @@ def jdbc(self, url, table, mode=None, properties=None): * ``append``: Append contents of this :class:`DataFrame` to existing data. 
* ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. + * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \ + exists. :param properties: a dictionary of JDBC database connection arguments. Normally at least properties "user" and "password" with their corresponding values. For example { 'user' : 'SYSTEM', 'password' : 'mypassword' } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 8d95b24c00619..e3fa2ced760e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -65,7 +65,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * - `overwrite`: overwrite the existing data. * - `append`: append the data. * - `ignore`: ignore the operation (i.e. no-op). - * - `error`: default option, throw an exception at runtime. + * - `error` or `errorifexists`: default option, throw an exception at runtime. * * @since 1.4.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 872ef773e8a3a..af20764f9a968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -184,15 +184,6 @@ private[sql] object SQLUtils extends Logging { colArray } - def saveMode(mode: String): SaveMode = { - mode match { - case "append" => SaveMode.Append - case "overwrite" => SaveMode.Overwrite - case "error" => SaveMode.ErrorIfExists - case "ignore" => SaveMode.Ignore - } - } - def readSqlObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 's' => From 98be55c0fafc3577ab4b106316666b6807bc928f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 9 Nov 2017 16:34:38 +0900 Subject: [PATCH 204/712] [SPARK-22222][CORE][TEST][FOLLOW-UP] Remove redundant and deprecated `Timeouts` ## What changes were proposed in this pull request? Since SPARK-21939, Apache Spark uses `TimeLimits` instead of the deprecated `Timeouts`. This PR fixes the build warning `BufferHolderSparkSubmitSuite.scala` introduced at [SPARK-22222](https://github.com/apache/spark/pull/19460/files#diff-d8cf6e0c229969db94ec8ffc31a9239cR36) by removing the redundant `Timeouts`. ```scala trait Timeouts in package concurrent is deprecated: Please use org.scalatest.concurrent.TimeLimits instead [warn] with Timeouts { ``` ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #19697 from dongjoon-hyun/SPARK-22222. 
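For reference, a minimal sketch of the `TimeLimits` style this follow-up standardizes on is shown below, assuming the ScalaTest 3.0 API; the suite name, timeout, and test body are illustrative placeholders, not code from this patch.

```scala
import org.scalatest.FunSuite
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._

// Illustrative suite: TimeLimits is the non-deprecated replacement for Timeouts.
class ExampleTimeLimitedSuite extends FunSuite with TimeLimits {
  // failAfter requires an implicit Signaler in ScalaTest 3.0.
  implicit val signaler: Signaler = ThreadSignaler

  test("completes within the time limit") {
    failAfter(10.seconds) {
      // the code under test would go here
    }
  }
}
```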
--- .../expressions/codegen/BufferHolderSparkSubmitSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala index 1167d2f3f3891..85682cf6ea670 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.scalatest.concurrent.Timeouts import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.SparkSubmitSuite @@ -32,8 +31,7 @@ class BufferHolderSparkSubmitSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { + with ResetSystemProperties { test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) From c755b0d910d68e7921807f2f2ac1e3fac7a8f357 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 Nov 2017 09:22:33 +0100 Subject: [PATCH 205/712] [SPARK-22463][YARN][SQL][HIVE] add hadoop/hive/hbase/etc configuration files in SPARK_CONF_DIR to distribute archive ## What changes were proposed in this pull request? When I ran self-contained SQL apps such as ```scala import org.apache.spark.sql.SparkSession object ShowHiveTables { def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName("Show Hive Tables") .enableHiveSupport() .getOrCreate() spark.sql("show tables").show() spark.stop() } } ``` in **yarn cluster** mode with `hive-site.xml` correctly placed in `$SPARK_HOME/conf`, they failed to connect to the right Hive metastore because hive-site.xml was not visible on the AM/Driver's classpath. Although submitting them with `--files/--jars local/path/to/hive-site.xml` or putting it under `$HADOOP_CONF_DIR/YARN_CONF_DIR` makes these apps work in cluster mode as they do in client mode, the official doc (see http://spark.apache.org/docs/latest/sql-programming-guide.html#hive-tables) says: > Configuration of Hive is done by placing your hive-site.xml, core-site.xml (for security configuration), and hdfs-site.xml (for HDFS configuration) file in conf/. We should either respect these configuration files in cluster mode too, or modify the doc for Hive tables accordingly. ## How was this patch tested? cc cloud-fan gatorsmile Author: Kent Yao Closes #19663 from yaooqinn/SPARK-21888.
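Before the diff, a simplified sketch of the intended collection logic may help; it assumes `*.xml` files under `SPARK_CONF_DIR` are gathered into the distributed conf archive first, so they shadow same-named files from `HADOOP_CONF_DIR`/`YARN_CONF_DIR`. This is an approximation for illustration, not the actual `Client.createConfArchive` implementation.

```scala
import java.io.File
import scala.collection.mutable

// Simplified sketch: collect conf files, letting the first occurrence of a file name win.
val confFiles = mutable.HashMap.empty[String, File]
for {
  envKey  <- Seq("SPARK_CONF_DIR", "HADOOP_CONF_DIR", "YARN_CONF_DIR")
  dirPath <- sys.env.get(envKey)
  dir      = new File(dirPath)
  if dir.isDirectory
  f <- dir.listFiles()
  if f.isFile && f.getName.endsWith(".xml")
} {
  // Because SPARK_CONF_DIR is visited first, its hive-site.xml (etc.) takes precedence.
  confFiles.getOrElseUpdate(f.getName, f)
}
```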
--- .../org/apache/spark/deploy/yarn/Client.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1fe25c4ddaabf..99e7d46ca5c96 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} +import java.io.{FileSystem => _, _} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets @@ -687,6 +687,19 @@ private[spark] class Client( private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() + // SPARK_CONF_DIR shows up in the classpath before HADOOP_CONF_DIR/YARN_CONF_DIR + sys.env.get("SPARK_CONF_DIR").foreach { localConfDir => + val dir = new File(localConfDir) + if (dir.isDirectory) { + val files = dir.listFiles(new FileFilter { + override def accept(pathname: File): Boolean = { + pathname.isFile && pathname.getName.endsWith(".xml") + } + }) + files.foreach { f => hadoopConfFiles(f.getName) = f } + } + } + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) From fe93c0bf6115a62486ebf9f494da5dc368c24418 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Thu, 9 Nov 2017 11:46:01 +0100 Subject: [PATCH 206/712] [DOC] update the API doc and modify the stage API description ## What changes were proposed in this pull request? **1.stage api modify the description format** A list of all stages for a given application.
    ?status=[active|complete|pending|failed] list only stages in the state. content should be included in fix before: ![1](https://user-images.githubusercontent.com/26266482/31753100-201f3432-b4c1-11e7-9e8d-54b62b96c17f.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/31753102-23b174de-b4c1-11e7-96ad-fd79d10440b9.png) **2.add version api doc '/api/v1/version' in monitoring.md** fix after: ![3](https://user-images.githubusercontent.com/26266482/31753087-0fd3a036-b4c1-11e7-802f-a6dc86a2a4b0.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #19532 from guoxiaolongzte/SPARK-22311. --- docs/monitoring.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index 1ae43185d22f8..f8d3ce91a0691 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -311,8 +311,10 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/stages - A list of all stages for a given application. -
    ?status=[active|complete|pending|failed] list only stages in the state. + + A list of all stages for a given application. +
    ?status=[active|complete|pending|failed] list only stages in the state. + /applications/[app-id]/stages/[stage-id] @@ -398,7 +400,11 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/environment Environment details of the given application. - + + + /version + Get the current spark version. + The number of jobs and stages which can retrieved is constrained by the same retention From 40a8aefaf3e97e80b23fb05d4afdcc30e1922312 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Nov 2017 11:54:50 +0100 Subject: [PATCH 207/712] [SPARK-22442][SQL] ScalaReflection should produce correct field names for special characters ## What changes were proposed in this pull request? For a class with field name of special characters, e.g.: ```scala case class MyType(`field.1`: String, `field 2`: String) ``` Although we can manipulate DataFrame/Dataset, the field names are encoded: ```scala scala> val df = Seq(MyType("a", "b"), MyType("c", "d")).toDF df: org.apache.spark.sql.DataFrame = [field$u002E1: string, field$u00202: string] scala> df.as[MyType].collect res7: Array[MyType] = Array(MyType(a,b), MyType(c,d)) ``` It causes resolving problem when we try to convert the data with non-encoded field names: ```scala spark.read.json(path).as[MyType] ... [info] org.apache.spark.sql.AnalysisException: cannot resolve '`field$u002E1`' given input columns: [field 2, fie ld.1]; [info] at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) ... ``` We should use decoded field name in Dataset schema. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #19664 from viirya/SPARK-22442. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 +++++---- .../catalyst/expressions/objects/objects.scala | 11 +++++++---- .../sql/catalyst/ScalaReflectionSuite.scala | 18 +++++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++ 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 17e595f9c5d8d..f62553ddd3971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -146,7 +146,7 @@ object ScalaReflection extends ScalaReflection { def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + .getOrElse(UnresolvedAttribute.quoted(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -675,7 +675,7 @@ object ScalaReflection extends ScalaReflection { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) val t = classSymbol.selfType - constructParams(t).map(_.name.toString) + constructParams(t).map(_.name.decodedName.toString) } /** @@ -855,11 +855,12 @@ trait ScalaReflection { // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) if (actualTypeArgs.nonEmpty) { params.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + p.name.decodedName.toString -> + p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) } } else { params.map { p => - p.name.toString -> p.typeSignature + p.name.decodedName.toString -> 
p.typeSignature } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6ae3490a3f863..f2eee991c9865 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} @@ -214,11 +215,13 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val encodedFunctionName = TermName(functionName).encodedName.toString + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - val m = cls.getMethods.find(_.getName == functionName) + val m = cls.getMethods.find(_.getName == encodedFunctionName) if (m.isEmpty) { - sys.error(s"Couldn't find $functionName on $cls") + sys.error(s"Couldn't find $encodedFunctionName on $cls") } else { m } @@ -247,7 +250,7 @@ case class Invoke( } val evaluate = if (returnPrimitive) { - getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") + getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)") } else { val funcResult = ctx.freshName("funcResult") // If the function can return null, we do an extra check to make sure our null bit is still @@ -265,7 +268,7 @@ case class Invoke( } s""" Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")} $assignResult """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index a5b9855e959d4..f77af5db3279b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -79,6 +80,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } +case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String) + object TestingUDT { @SQLUserDefinedType(udt = classOf[NestedStructUDT]) class NestedStruct(val a: Integer, val b: Long, val c: Double) @@ -335,4 +338,17 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(linkedHashMapDeserializer.dataType == 
ObjectType(classOf[LHMap[_, _]])) } + test("SPARK-22442: Generate correct field names for special characters") { + val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val deserializer = deserializerFor[SpecialCharAsFieldData] + assert(serializer.dataType(0).name == "field.1") + assert(serializer.dataType(1).name == "field 2") + + val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { + case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts + }} + assert(argumentsFields(0) == Seq("field.1")) + assert(argumentsFields(1) == Seq("field 2")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 1537ce3313c09..c67165c7abca6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1398,6 +1398,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val actual = kvDataset.toString assert(expected === actual) } + + test("SPARK-22442: Generate correct field names for special characters") { + withTempPath { dir => + val path = dir.getCanonicalPath + val data = """{"field.1": 1, "field 2": 2}""" + Seq(data).toDF().repartition(1).write.text(path) + val ds = spark.read.json(path).as[SpecialCharClass] + checkDataset(ds, SpecialCharClass("1", "2")) + } + } } case class SingleData(id: Int) @@ -1487,3 +1497,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA) case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) + +case class SpecialCharClass(`field.1`: String, `field 2`: String) From 6793a3dac0a44570625044e1eb30fa578fa2f142 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 9 Nov 2017 11:57:56 +0100 Subject: [PATCH 208/712] [SPARK-22405][SQL] Add new alter table and alter database related ExternalCatalogEvent ## What changes were proposed in this pull request? We're building a data lineage tool in which we need to monitor the metadata changes in ExternalCatalog. The current ExternalCatalog already provides several useful events, such as "CreateDatabaseEvent", for a custom SparkListener to use. However, some events are still missing, such as alter database and alter table events. This PR proposes to add the new ExternalCatalogEvents. ## How was this patch tested? Enriched the current UTs and tested on a local cluster. cc hvanhovell, please let me know your comments on the current proposal, thanks. Author: jerryshao Closes #19649 from jerryshao/SPARK-22405.
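To illustrate how a lineage tool might consume the new events, here is a minimal listener sketch. It assumes catalog events are forwarded to the Spark listener bus as `SparkListenerEvent`s; the class name and log messages are hypothetical.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.catalyst.catalog.{AlterDatabaseEvent, AlterTableEvent}

// Hypothetical auditor reacting to the alter events added by this patch.
class CatalogChangeAuditor extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: AlterTableEvent    => println(s"table altered: ${e.database}.${e.name} (kind=${e.kind})")
    case e: AlterDatabaseEvent => println(s"database altered: ${e.database}")
    case _                     => // ignore unrelated events
  }
}

// Registration sketch: spark.sparkContext.addSparkListener(new CatalogChangeAuditor)
```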
--- .../catalyst/catalog/ExternalCatalog.scala | 35 +++++++++++++++-- .../catalyst/catalog/InMemoryCatalog.scala | 8 ++-- .../spark/sql/catalyst/catalog/events.scala | 38 ++++++++++++++++++- .../catalog/ExternalCatalogEventSuite.scala | 22 +++++++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 18 ++++++--- 5 files changed, 107 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 223094d485936..45b4f013620c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -87,7 +87,14 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - def alterDatabase(dbDefinition: CatalogDatabase): Unit + final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + doAlterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -147,7 +154,15 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - def alterTable(tableDefinition: CatalogTable): Unit + final def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + doAlterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + protected def doAlterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -158,10 +173,22 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + doAlterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. 
*/ - def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + doAlterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9504140d51e99..8eacfa058bd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def alterTableDataSchema( + override def doAlterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def alterTableStats( + override def doAlterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala index 742a51e640383..e7d41644392d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -61,6 +61,16 @@ case class DropDatabasePreEvent(database: String) extends DatabaseEvent */ case class DropDatabaseEvent(database: String) extends DatabaseEvent +/** + * Event fired before a database is altered. + */ +case class AlterDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database is altered. + */ +case class AlterDatabaseEvent(database: String) extends DatabaseEvent + /** * Event fired when a table is created, dropped or renamed. */ @@ -110,7 +120,33 @@ case class RenameTableEvent( extends TableEvent /** - * Event fired when a function is created, dropped or renamed. + * String to indicate which part of table is altered. If a plain alterTable API is called, then + * type will generally be Table. + */ +object AlterTableKind extends Enumeration { + val TABLE = "table" + val DATASCHEMA = "dataSchema" + val STATS = "stats" +} + +/** + * Event fired before a table is altered. 
+ */ +case class AlterTablePreEvent( + database: String, + name: String, + kind: String) extends TableEvent + +/** + * Event fired after a table is altered. + */ +case class AlterTableEvent( + database: String, + name: String, + kind: String) extends TableEvent + +/** + * Event fired when a function is created, dropped, altered or renamed. */ trait FunctionEvent extends DatabaseEvent { /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 087c26f23f383..1acbe34d9a075 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -75,6 +75,11 @@ class ExternalCatalogEventSuite extends SparkFunSuite { } checkEvents(CreateDatabasePreEvent("db5") :: Nil) + // ALTER + val newDbDefinition = dbDefinition.copy(description = "test") + catalog.alterDatabase(newDbDefinition) + checkEvents(AlterDatabasePreEvent("db5") :: AlterDatabaseEvent("db5") :: Nil) + // DROP intercept[AnalysisException] { catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false) @@ -119,6 +124,23 @@ class ExternalCatalogEventSuite extends SparkFunSuite { } checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil) + // ALTER + val newTableDefinition = tableDefinition.copy(tableType = CatalogTableType.EXTERNAL) + catalog.alterTable(newTableDefinition) + checkEvents(AlterTablePreEvent("db5", "tbl1", AlterTableKind.TABLE) :: + AlterTableEvent("db5", "tbl1", AlterTableKind.TABLE) :: Nil) + + // ALTER schema + val newSchema = new StructType().add("id", "long", nullable = false) + catalog.alterTableDataSchema("db5", "tbl1", newSchema) + checkEvents(AlterTablePreEvent("db5", "tbl1", AlterTableKind.DATASCHEMA) :: + AlterTableEvent("db5", "tbl1", AlterTableKind.DATASCHEMA) :: Nil) + + // ALTER stats + catalog.alterTableStats("db5", "tbl1", None) + checkEvents(AlterTablePreEvent("db5", "tbl1", AlterTableKind.STATS) :: + AlterTableEvent("db5", "tbl1", AlterTableKind.STATS) :: Nil) + // RENAME catalog.renameTable("db5", "tbl1", "tbl2") checkEvents( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index f8a947bf527e7..7cd772544a96a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. 
*/ - override def alterTable(tableDefinition: CatalogTable): Unit = withClient { + override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -619,8 +619,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def alterTableDataSchema( - db: String, table: String, newDataSchema: StructType): Unit = withClient { + /** + * Alter the data schema of a table identified by the provided database and table name. The new + * data schema should not have conflict column names with the existing partition columns, and + * should still contain all the existing data columns. + */ + override def doAlterTableDataSchema( + db: String, + table: String, + newDataSchema: StructType): Unit = withClient { requireTableExists(db, table) val oldTable = getTable(db, table) verifyDataSchema(oldTable.identifier, oldTable.tableType, newDataSchema) @@ -648,7 +655,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def alterTableStats( + /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ + override def doAlterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { From 77f74539ec7a445e24736029fb198b48ffd50ea9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Nov 2017 16:35:06 +0200 Subject: [PATCH 209/712] [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns ## What changes were proposed in this pull request? Current ML's Bucketizer can only bin a column of continuous features. If a dataset has thousands of of continuous columns needed to bin, we will result in thousands of ML stages. It is inefficient regarding query planning and execution. We should have a type of bucketizer that can bin a lot of columns all at once. It would need to accept an list of arrays of split points to correspond to the columns to bin, but it might make things more efficient by replacing thousands of stages with just one. This current approach in this patch is to add a new `MultipleBucketizerInterface` for this purpose. `Bucketizer` now extends this new interface. ### Performance Benchmarking using the test dataset provided in JIRA SPARK-20392 (blockbuster.csv). The ML pipeline includes 2 `StringIndexer`s and 1 `MultipleBucketizer` or 137 `Bucketizer`s to bin 137 input columns with the same splits. Then count the time to transform the dataset. MultipleBucketizer: 3352 ms Bucketizer: 51512 ms ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17819 from viirya/SPARK-20542. 
--- .../examples/ml/JavaBucketizerExample.java | 41 +++ .../spark/examples/ml/BucketizerExample.scala | 36 ++- .../apache/spark/ml/feature/Bucketizer.scala | 122 +++++++-- .../org/apache/spark/ml/param/params.scala | 39 +++ .../ml/param/shared/SharedParamsCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 17 ++ .../spark/ml/feature/JavaBucketizerSuite.java | 35 +++ .../spark/ml/feature/BucketizerSuite.scala | 239 +++++++++++++++++- .../apache/spark/ml/param/ParamsSuite.scala | 38 ++- .../scala/org/apache/spark/sql/Dataset.scala | 21 +- .../org/apache/spark/sql/DataFrameSuite.scala | 28 ++ 11 files changed, 592 insertions(+), 25 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java index f00993833321d..3e49bf04ac892 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -33,6 +33,13 @@ import org.apache.spark.sql.types.StructType; // $example off$ +/** + * An example for Bucketizer. + * Run with + *
    + * bin/run-example ml.JavaBucketizerExample
    + *
    + */ public class JavaBucketizerExample { public static void main(String[] args) { SparkSession spark = SparkSession @@ -68,6 +75,40 @@ public static void main(String[] args) { bucketedData.show(); // $example off$ + // $example on$ + // Bucketize multiple columns at one pass. + double[][] splitsArray = { + {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}, + {Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY} + }; + + List data2 = Arrays.asList( + RowFactory.create(-999.9, -999.9), + RowFactory.create(-0.5, -0.2), + RowFactory.create(-0.3, -0.1), + RowFactory.create(0.0, 0.0), + RowFactory.create(0.2, 0.4), + RowFactory.create(999.9, 999.9) + ); + StructType schema2 = new StructType(new StructField[]{ + new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features2", DataTypes.DoubleType, false, Metadata.empty()) + }); + Dataset dataFrame2 = spark.createDataFrame(data2, schema2); + + Bucketizer bucketizer2 = new Bucketizer() + .setInputCols(new String[] {"features1", "features2"}) + .setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"}) + .setSplitsArray(splitsArray); + // Transform original data into its bucket index. + Dataset bucketedData2 = bucketizer2.transform(dataFrame2); + + System.out.println("Bucketizer output with [" + + (bucketizer2.getSplitsArray()[0].length-1) + ", " + + (bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column"); + bucketedData2.show(); + // $example off$ + spark.stop(); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala index 04e4eccd436ed..7e65f9c88907d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -22,7 +22,13 @@ package org.apache.spark.examples.ml import org.apache.spark.ml.feature.Bucketizer // $example off$ import org.apache.spark.sql.SparkSession - +/** + * An example for Bucketizer. + * Run with + * {{{ + * bin/run-example ml.BucketizerExample + * }}} + */ object BucketizerExample { def main(args: Array[String]): Unit = { val spark = SparkSession @@ -48,6 +54,34 @@ object BucketizerExample { bucketedData.show() // $example off$ + // $example on$ + val splitsArray = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity)) + + val data2 = Array( + (-999.9, -999.9), + (-0.5, -0.2), + (-0.3, -0.1), + (0.0, 0.0), + (0.2, 0.4), + (999.9, 999.9)) + val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2") + + val bucketizer2 = new Bucketizer() + .setInputCols(Array("features1", "features2")) + .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2")) + .setSplitsArray(splitsArray) + + // Transform original data into its bucket index. 
+ val bucketedData2 = bucketizer2.transform(dataFrame2) + + println(s"Bucketizer output with [" + + s"${bucketizer2.getSplitsArray(0).length-1}, " + + s"${bucketizer2.getSplitsArray(1).length-1}] buckets for each input column") + bucketedData2.show() + // $example off$ + spark.stop() } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 6a11a75d1d569..e07f2a107badb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.expressions.UserDefinedFunction @@ -32,12 +32,16 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that + * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and + * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is + * only used for single column usage, and `splitsArray` is for multiple columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol - with DefaultParamsWritable { + with HasInputCols with HasOutputCols with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("bucketizer")) @@ -81,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** * Param for how to handle invalid entries. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special - * additional bucket). + * additional bucket). Note that in the multiple column case, the invalid handling is applied + * to all columns. That said for 'error' it will throw an error if any invalids are found in + * any column, for 'skip' it will skip rows with any invalids in any columns, etc. * Default: "error" * @group param */ @@ -96,9 +102,59 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setHandleInvalid(value: String): this.type = set(handleInvalid, value) setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + /** + * Parameter for specifying multiple splits parameters. Each element in this array can be used to + * map continuous features into buckets. + * + * @group param + */ + @Since("2.3.0") + val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray", + "The array of split points for mapping continuous features into buckets for multiple " + + "columns. For each input column, with n+1 splits, there are n buckets. 
A bucket defined by " + + "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " + + "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " + + "explicitly provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.", + Bucketizer.checkSplitsArray) + + /** @group getParam */ + @Since("2.3.0") + def getSplitsArray: Array[Array[Double]] = $(splitsArray) + + /** @group setParam */ + @Since("2.3.0") + def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + /** + * Determines whether this `Bucketizer` is going to map multiple columns. If and only if + * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified + * by `inputCol`. A warning will be printed if both are set. + */ + private[feature] def isBucketizeMultipleColumns(): Boolean = { + if (isSet(inputCols) && isSet(inputCol)) { + logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + + "`Bucketizer` only map one column specified by `inputCol`") + false + } else if (isSet(inputCols)) { + true + } else { + false + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema) + val transformedSchema = transformSchema(dataset.schema) + val (filteredDataset, keepInvalid) = { if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN values in the dataset @@ -108,26 +164,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val bucketizer: UserDefinedFunction = udf { (feature: Double) => - Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - }.withName("bucketizer") + val seqOfSplits = if (isBucketizeMultipleColumns()) { + $(splitsArray).toSeq + } else { + Seq($(splits)) + } - val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) - val newField = prepOutputField(filteredDataset.schema) - filteredDataset.withColumn($(outputCol), newCol, newField.metadata) + val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) => + udf { (feature: Double) => + Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid) + }.withName(s"bucketizer_$idx") + } + + val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } + val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) => + bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType)) + } + val metadata = outputColumns.map { col => + transformedSchema(col).metadata + } + filteredDataset.withColumns(outputColumns, newCols, metadata) } - private def prepOutputField(schema: StructType): StructField = { - val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray - val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), + private def prepOutputField(splits: Array[Double], outputCol: String): StructField = { + val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray + val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true), 
values = Some(buckets)) attr.toStructField() } @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkNumericType(schema, $(inputCol)) - SchemaUtils.appendColumn(schema, prepOutputField(schema)) + if (isBucketizeMultipleColumns()) { + var transformedSchema = schema + $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + SchemaUtils.checkNumericType(transformedSchema, inputCol) + transformedSchema = SchemaUtils.appendColumn(transformedSchema, + prepOutputField($(splitsArray)(idx), outputCol)) + } + transformedSchema + } else { + SchemaUtils.checkNumericType(schema, $(inputCol)) + SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol))) + } } @Since("1.4.1") @@ -163,6 +246,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } + /** + * Check each splits in the splits array. + */ + private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = { + splitsArray.forall(checkSplits(_)) + } + /** * Binary searching in several buckets to place each data point. * @param splits array of split points diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ac68b825af537..8985f2af90a9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -490,6 +490,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array } } +/** + * :: DeveloperApi :: + * Specialized version of `Param[Array[Array[Double]]]` for Java. + */ +@DeveloperApi +class DoubleArrayArrayParam( + parent: Params, + name: String, + doc: String, + isValid: Array[Array[Double]] => Boolean) + extends Param[Array[Array[Double]]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + /** Creates a param pair with a `java.util.List` of values (for Java and Python). */ + def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] = + w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray) + + override def jsonEncode(value: Array[Array[Double]]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode)))) + } + + override def jsonDecode(json: String): Array[Array[Double]] = { + parse(json) match { + case JArray(values) => + values.map { + case JArray(values) => + values.map(DoubleParam.jValueDecode).toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].") + }.toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].") + } + } +} + /** * :: DeveloperApi :: * Specialized version of `Param[Array[Int]]` for Java. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a932d28fadbd8..20a1db854e3a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), + ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index e6bdf5236e72d..0d5fb28ae783c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -257,6 +257,23 @@ trait HasOutputCol extends Params { final def getOutputCol: String = $(outputCol) } +/** + * Trait for shared param outputCols. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasOutputCols extends Params { + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) +} + /** * Trait for shared param checkpointInterval. This trait may be changed or * removed between minor versions. 
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 87639380bdcf4..e65265bf74a88 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -61,4 +61,39 @@ public void bucketizerTest() { Assert.assertTrue((index >= 0) && (index <= 1)); } } + + @Test + public void bucketizerMultipleColumnsTest() { + double[][] splitsArray = { + {-0.5, 0.0, 0.5}, + {-0.5, 0.0, 0.2, 0.5} + }; + + StructType schema = new StructType(new StructField[]{ + new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()), + }); + Dataset dataset = spark.createDataFrame( + Arrays.asList( + RowFactory.create(-0.5, -0.5), + RowFactory.create(-0.3, -0.3), + RowFactory.create(0.0, 0.0), + RowFactory.create(0.2, 0.3)), + schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCols(new String[] {"feature1", "feature2"}) + .setOutputCols(new String[] {"result1", "result2"}) + .setSplitsArray(splitsArray); + + List result = bucketizer.transform(dataset).select("result1", "result2").collectAsList(); + + for (Row r : result) { + double index1 = r.getDouble(0); + Assert.assertTrue((index1 >= 0) && (index1 <= 1)); + + double index2 = r.getDouble(1); + Assert.assertTrue((index2 >= 0) && (index2 <= 2)); + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 420fb17ddce8c..748dbd1b995d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.ml.feature import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -187,6 +188,220 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } } + + test("multiple columns: Bucket continuous features, without -inf,inf") { + // Check a set of valid feature values. 
+ val splits = Array(Array(-0.5, 0.0, 0.5), Array(-0.1, 0.3, 0.5)) + val validData1 = Array(-0.5, -0.3, 0.0, 0.2) + val validData2 = Array(0.5, 0.3, 0.0, -0.1) + val expectedBuckets1 = Array(0.0, 0.0, 1.0, 1.0) + val expectedBuckets2 = Array(1.0, 1.0, 0.0, 0.0) + + val data = (0 until validData1.length).map { idx => + (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer1: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer1.isBucketizeMultipleColumns()) + + bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") + BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + + // Check for exceptions when using a set of invalid feature values. + val invalidData1 = Array(-0.9) ++ validData1 + val invalidData2 = Array(0.51) ++ validData1 + val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx") + + val bucketizer2: Bucketizer = new Bucketizer() + .setInputCols(Array("feature")) + .setOutputCols(Array("result")) + .setSplitsArray(Array(splits(0))) + + assert(bucketizer2.isBucketizeMultipleColumns()) + + withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer2.transform(badDF1).collect() + } + } + val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx") + withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer2.transform(badDF2).collect() + } + } + } + + test("multiple columns: Bucket continuous features, with -inf,inf") { + val splits = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.3, 0.2, 0.5, Double.PositiveInfinity)) + + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) + val validData2 = Array(-0.1, -0.5, -0.2, 0.0, 0.1, 0.3, 0.5) + val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) + val expectedBuckets2 = Array(1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 3.0) + + val data = (0 until validData1.length).map { idx => + (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer.isBucketizeMultipleColumns()) + + BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + } + + test("multiple columns: Bucket continuous features, with NaN data but non-NaN splits") { + val splits = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.1, 0.2, 0.6, Double.PositiveInfinity)) + + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, 0.8, Double.NaN, Double.NaN) + val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0) + val expectedBuckets2 = Array(2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0) + + val data = (0 until validData1.length).map { idx => 
+ (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer.isBucketizeMultipleColumns()) + + bucketizer.setHandleInvalid("keep") + BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + + bucketizer.setHandleInvalid("skip") + val skipResults1: Array[Double] = bucketizer.transform(dataFrame) + .select("result1").as[Double].collect() + assert(skipResults1.length === 7) + assert(skipResults1.forall(_ !== 4.0)) + + val skipResults2: Array[Double] = bucketizer.transform(dataFrame) + .select("result2").as[Double].collect() + assert(skipResults2.length === 7) + assert(skipResults2.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } + } + + test("multiple columns: Bucket continuous features, with NaN splits") { + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) + withClue("Invalid NaN split was not caught during Bucketizer initialization") { + intercept[IllegalArgumentException] { + new Bucketizer().setSplitsArray(Array(splits)) + } + } + } + + test("multiple columns: read/write") { + val t = new Bucketizer() + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) + assert(t.isBucketizeMultipleColumns()) + testDefaultReadWrite(t) + } + + test("Bucketizer in a pipeline") { + val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0)) + .toDF("feature1", "feature2", "expected1", "expected2") + + val bucket = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) + + assert(bucket.isBucketizeMultipleColumns()) + + val pl = new Pipeline() + .setStages(Array(bucket)) + .fit(df) + pl.transform(df).select("result1", "expected1", "result2", "expected2") + + BucketizerSuite.checkBucketResults(pl.transform(df), + Seq("result1", "result2"), Seq("expected1", "expected2")) + } + + test("Compare single/multiple column(s) Bucketizer in pipeline") { + val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0)) + .toDF("feature1", "feature2", "expected1", "expected2") + + val multiColsBucket = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) + + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsBucket)) + .fit(df) + + val bucketForCol1 = new Bucketizer() + .setInputCol("feature1") + .setOutputCol("result1") + .setSplits(Array(-0.5, 0.0, 0.5)) + val bucketForCol2 = new Bucketizer() + .setInputCol("feature2") + .setOutputCol("result2") + .setSplits(Array(-0.5, 0.0, 0.5)) + + val plForSingleCol = new Pipeline() + .setStages(Array(bucketForCol1, bucketForCol2)) + .fit(df) + + val resultForSingleCol = plForSingleCol.transform(df) + .select("result1", "expected1", "result2", "expected2") + .collect() + val resultForMultiCols = plForMultiCols.transform(df) + .select("result1", 
"expected1", "result2", "expected2") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && + rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && + rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) && + rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3)) + } + } + + test("Both inputCol and inputCols are set") { + val bucket = new Bucketizer() + .setInputCol("feature1") + .setOutputCol("result") + .setSplits(Array(-0.5, 0.0, 0.5)) + .setInputCols(Array("feature1", "feature2")) + + // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. + assert(bucket.isBucketizeMultipleColumns() == false) + } } private object BucketizerSuite extends SparkFunSuite { @@ -220,4 +435,26 @@ private object BucketizerSuite extends SparkFunSuite { i += 1 } } + + /** Checks if bucketized results match expected ones. */ + def checkBucketResults( + bucketResult: DataFrame, + resultColumns: Seq[String], + expectedColumns: Seq[String]): Unit = { + assert(resultColumns.length == expectedColumns.length, + s"Given ${resultColumns.length} result columns doesn't match " + + s"${expectedColumns.length} expected columns.") + assert(resultColumns.length > 0, "At least one result and expected columns are needed.") + + val allColumns = resultColumns ++ expectedColumns + bucketResult.select(allColumns.head, allColumns.tail: _*).collect().foreach { + case row => + for (idx <- 0 until row.length / 2) { + val result = row.getDouble(idx) + val expected = row.getDouble(idx + row.length / 2) + assert(result === expected, "The feature value is not correct after bucketing. " + + s"Expected $expected but found $result.") + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 78a33e05e0e48..85198ad4c913a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -121,10 +121,10 @@ class ParamsSuite extends SparkFunSuite { { // DoubleArrayParam val param = new DoubleArrayParam(dummy, "name", "doc") val values: Seq[Array[Double]] = Seq( - Array(), - Array(1.0), - Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, - Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) + Array(), + Array(1.0), + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) for (value <- values) { val json = param.jsonEncode(value) val decoded = param.jsonDecode(json) @@ -139,6 +139,36 @@ class ParamsSuite extends SparkFunSuite { } } + { // DoubleArrayArrayParam + val param = new DoubleArrayArrayParam(dummy, "name", "doc") + val values: Seq[Array[Array[Double]]] = Seq( + Array(Array()), + Array(Array(1.0)), + Array(Array(1.0), Array(2.0)), + Array( + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity), + Array(Double.MaxValue, Double.PositiveInfinity, Double.MinPositiveValue, 1.0, + Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0) + )) + + for (value <- values) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + assert(decoded.length === value.length) + decoded.zip(value).foreach { case (actualArray, expectedArray) => + 
assert(actualArray.length === expectedArray.length) + actualArray.zip(expectedArray).foreach { case (actual, expected) => + if (expected.isNaN) { + assert(actual.isNaN) + } else { + assert(actual === expected) + } + } + } + } + } + { // StringArrayParam val param = new StringArrayParam(dummy, "name", "doc") val values: Seq[Array[String]] = Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bd99ec52ce93f..5eb2affa0bd8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2135,12 +2135,27 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by adding a column with metadata. + * Returns a new Dataset by adding columns with metadata. */ - private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { - withColumn(colName, col.as(colName, metadata)) + private[spark] def withColumns( + colNames: Seq[String], + cols: Seq[Column], + metadata: Seq[Metadata]): DataFrame = { + require(colNames.size == metadata.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of metadata elements: ${metadata.size}") + val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) => + col.as(colName, metadata) + } + withColumns(colNames, newCols) } + /** + * Returns a new Dataset by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = + withColumns(Seq(colName), Seq(col), Seq(metadata)) + /** * Returns a new Dataset with a column renamed. * This is a no-op if schema doesn't contain existingName. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 17c88b0690800..31bfa77e76329 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -686,6 +686,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("withColumns: given metadata") { + def buildMetadata(num: Int): Seq[Metadata] = { + (0 until num).map { n => + val builder = new MetadataBuilder + builder.putLong("key", n.toLong) + builder.build() + } + } + + val df = testData.toDF().withColumns( + Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2), + buildMetadata(2)) + + df.select("newCol1", "newCol2").schema.zipWithIndex.foreach { case (col, idx) => + assert(col.metadata.getLong("key").toInt === idx) + } + + val err = intercept[IllegalArgumentException] { + testData.toDF().withColumns( + Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2), + buildMetadata(1)) + } + assert(err.getMessage.contains( + "The size of column names: 2 isn't equal to the size of metadata elements: 1")) + } + test("replace column using withColumn") { val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) From 6b19c0735d5dc9287c13dfc255eb353d827800e2 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 9 Nov 2017 09:53:41 -0800 Subject: [PATCH 210/712] [MINOR][CORE] Fix nits in MetricsSystemSuite ## What changes were proposed in this pull request? Fixing nits in MetricsSystemSuite file 1) Using Sink instead of Source while casting 2) Using meaningful naming for variables, which reflect their usage ## How was this patch tested? 
Ran the tests locally and all of them passing Author: Srinivasa Reddy Vundela Closes #19699 from vundela/master. --- .../spark/metrics/MetricsSystemSuite.scala | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 61db6af830cc5..a7a24114f17e2 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.internal.config._ +import org.apache.spark.metrics.sink.Sink import org.apache.spark.metrics.source.{Source, StaticSources} class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ @@ -42,7 +43,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) - val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) + val sinks = PrivateMethod[ArrayBuffer[Sink]]('sinks) assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 0) @@ -53,7 +54,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) - val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) + val sinks = PrivateMethod[ArrayBuffer[Sink]]('sinks) assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 1) @@ -126,9 +127,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === s"$appId.$executorId.${source.sourceName}") } @@ -142,9 +143,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) } @@ -158,9 +159,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.app.id", appId) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = 
driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) } @@ -176,9 +177,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "testInstance" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val testMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = testMetricsSystem.buildRegistryName(source) // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name. assert(metricName != s"$appId.$executorId.${source.sourceName}") @@ -200,9 +201,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, "${spark.app.name}") val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === s"$appName.$executorId.${source.sourceName}") } @@ -218,9 +219,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, namespaceToResolve) val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) // If the user set the spark.metrics.namespace property to an expansion of another property // (say ${spark.doesnotexist}, the unresolved name (i.e. literally ${spark.doesnotexist}) // is used as the root logger name. @@ -238,9 +239,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set(METRICS_NAMESPACE, "${spark.app.name}") val instanceName = "executor" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val executorMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = executorMetricsSystem.buildRegistryName(source) assert(metricName === source.sourceName) } @@ -259,9 +260,9 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM conf.set("spark.executor.id", executorId) val instanceName = "testInstance" - val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + val testMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) - val metricName = driverMetricsSystem.buildRegistryName(source) + val metricName = testMetricsSystem.buildRegistryName(source) // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name. assert(metricName != s"$appId.$executorId.${source.sourceName}") From 6ae12715c7286abc87cc7a22b1fa955e844077d5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 9 Nov 2017 15:46:16 -0600 Subject: [PATCH 211/712] [SPARK-20647][CORE] Port StorageTab to the new UI backend. 
This required adding information about StreamBlockId to the store, which is not available yet via the API. So an internal type was added until there's a need to expose that information in the API. The UI only lists RDDs that have cached partitions, and that information wasn't being correctly captured in the listener, so that's also fixed, along with some minor (internal) API adjustments so that the UI can get the correct data. Because of the way partitions are cached, some optimizations w.r.t. how often the data is flushed to the store could not be applied to this code; because of that, some different ways to make the code more performant were added to the data structures tracking RDD blocks, with the goal of avoiding expensive copies when lots of blocks are being updated. Tested with existing and updated unit tests. Author: Marcelo Vanzin Closes #19679 from vanzin/SPARK-20647. --- .../spark/status/AppStatusListener.scala | 97 ++++++--- .../apache/spark/status/AppStatusStore.scala | 10 +- .../org/apache/spark/status/LiveEntity.scala | 170 ++++++++++++--- .../spark/status/api/v1/AllRDDResource.scala | 81 +------ .../spark/status/api/v1/OneRDDResource.scala | 10 +- .../org/apache/spark/status/storeTypes.scala | 16 ++ .../spark/storage/BlockStatusListener.scala | 100 --------- .../spark/storage/StorageStatusListener.scala | 111 ---------- .../scala/org/apache/spark/ui/SparkUI.scala | 17 +- .../org/apache/spark/ui/storage/RDDPage.scala | 16 +- .../apache/spark/ui/storage/StoragePage.scala | 78 ++++--- .../apache/spark/ui/storage/StorageTab.scala | 68 +----- .../spark/status/AppStatusListenerSuite.scala | 186 +++++++++++----- .../apache/spark/status/LiveEntitySuite.scala | 68 ++++++ .../storage/BlockStatusListenerSuite.scala | 113 ---------- .../storage/StorageStatusListenerSuite.scala | 167 -------------- .../spark/ui/storage/StoragePageSuite.scala | 128 +++++++---- .../spark/ui/storage/StorageTabSuite.scala | 205 ------------------ project/MimaExcludes.scala | 2 + 19 files changed, 586 insertions(+), 1057 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala create mode 100644 core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 0469c871362c0..7f2c00c09d43d 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -146,10 +146,19 @@ private[spark] class AppStatusListener( override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { liveExecutors.remove(event.executorId).foreach { exec => + val now = System.nanoTime() exec.isActive = false exec.removeTime = new Date(event.time) exec.removeReason = event.reason - update(exec, System.nanoTime()) + update(exec, now) + + // Remove all RDD distributions that reference the removed executor, in case there wasn't + // a corresponding event. 
+ liveRDDs.values.foreach { rdd => + if (rdd.removeDistribution(exec)) { + update(rdd, now) + } + } } } @@ -465,7 +474,8 @@ private[spark] class AppStatusListener( override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { event.blockUpdatedInfo.blockId match { case block: RDDBlockId => updateRDDBlock(event, block) - case _ => // TODO: API only covers RDD storage. + case stream: StreamBlockId => updateStreamBlock(event, stream) + case _ => } } @@ -502,29 +512,47 @@ private[spark] class AppStatusListener( val maybeExec = liveExecutors.get(executorId) var rddBlocksDelta = 0 + // Update the executor stats first, since they are used to calculate the free memory + // on tracked RDD distributions. + maybeExec.foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + } else { + exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + } + } + exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = newValue(exec.diskUsed, diskDelta) + } + // Update the block entry in the RDD info, keeping track of the deltas above so that we // can update the executor information too. liveRDDs.get(block.rddId).foreach { rdd => + if (updatedStorageLevel.isDefined) { + rdd.storageLevel = updatedStorageLevel.get + } + val partition = rdd.partition(block.name) val executors = if (updatedStorageLevel.isDefined) { - if (!partition.executors.contains(executorId)) { + val current = partition.executors + if (current.contains(executorId)) { + current + } else { rddBlocksDelta = 1 + current :+ executorId } - partition.executors + executorId } else { rddBlocksDelta = -1 - partition.executors - executorId + partition.executors.filter(_ != executorId) } // Only update the partition if it's still stored in some executor, otherwise get rid of it. if (executors.nonEmpty) { - if (updatedStorageLevel.isDefined) { - partition.storageLevel = updatedStorageLevel.get - } - partition.memoryUsed = newValue(partition.memoryUsed, memoryDelta) - partition.diskUsed = newValue(partition.diskUsed, diskDelta) - partition.executors = executors + partition.update(executors, rdd.storageLevel, + newValue(partition.memoryUsed, memoryDelta), + newValue(partition.diskUsed, diskDelta)) } else { rdd.removePartition(block.name) } @@ -532,42 +560,39 @@ private[spark] class AppStatusListener( maybeExec.foreach { exec => if (exec.rddBlocks + rddBlocksDelta > 0) { val dist = rdd.distribution(exec) - dist.memoryRemaining = newValue(dist.memoryRemaining, -memoryDelta) dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) dist.diskUsed = newValue(dist.diskUsed, diskDelta) if (exec.hasMemoryInfo) { if (storageLevel.useOffHeap) { dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) - dist.offHeapRemaining = newValue(dist.offHeapRemaining, -memoryDelta) } else { dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) - dist.onHeapRemaining = newValue(dist.onHeapRemaining, -memoryDelta) } } + dist.lastUpdate = null } else { rdd.removeDistribution(exec) } - } - if (updatedStorageLevel.isDefined) { - rdd.storageLevel = updatedStorageLevel.get + // Trigger an update on other RDDs so that the free memory information is updated. 
+ liveRDDs.values.foreach { otherRdd => + if (otherRdd.info.id != block.rddId) { + otherRdd.distributionOpt(exec).foreach { dist => + dist.lastUpdate = null + update(otherRdd, now) + } + } + } } + rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) update(rdd, now) } + // Finish updating the executor now that we know the delta in the number of blocks. maybeExec.foreach { exec => - if (exec.hasMemoryInfo) { - if (storageLevel.useOffHeap) { - exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) - } else { - exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) - } - } - exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) - exec.diskUsed = newValue(exec.diskUsed, diskDelta) exec.rddBlocks += rddBlocksDelta maybeUpdate(exec, now) } @@ -577,6 +602,26 @@ private[spark] class AppStatusListener( liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId, addTime)) } + private def updateStreamBlock(event: SparkListenerBlockUpdated, stream: StreamBlockId): Unit = { + val storageLevel = event.blockUpdatedInfo.storageLevel + if (storageLevel.isValid) { + val data = new StreamBlockData( + stream.name, + event.blockUpdatedInfo.blockManagerId.executorId, + event.blockUpdatedInfo.blockManagerId.hostPort, + storageLevel.description, + storageLevel.useMemory, + storageLevel.useDisk, + storageLevel.deserialized, + event.blockUpdatedInfo.memSize, + event.blockUpdatedInfo.diskSize) + kvstore.write(data) + } else { + kvstore.delete(classOf[StreamBlockData], + Array(stream.name, event.blockUpdatedInfo.blockManagerId.executorId)) + } + } + private def getOrCreateStage(info: StageInfo): LiveStage = { val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) stage.info = info diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 334407829f9fe..80c8d7d11a3c2 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -209,14 +209,20 @@ private[spark] class AppStatusStore(store: KVStore) { indexed.skip(offset).max(length).asScala.map(_.info).toSeq } - def rddList(): Seq[v1.RDDStorageInfo] = { - store.view(classOf[RDDStorageInfoWrapper]).asScala.map(_.info).toSeq + def rddList(cachedOnly: Boolean = true): Seq[v1.RDDStorageInfo] = { + store.view(classOf[RDDStorageInfoWrapper]).asScala.map(_.info).filter { rdd => + !cachedOnly || rdd.numCachedPartitions > 0 + }.toSeq } def rdd(rddId: Int): v1.RDDStorageInfo = { store.read(classOf[RDDStorageInfoWrapper], rddId).info } + def streamBlocksList(): Seq[StreamBlockData] = { + store.view(classOf[StreamBlockData]).asScala.toSeq + } + def close(): Unit = { store.close() } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 8c48020e246b4..706d94c3a59b9 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -420,79 +420,101 @@ private class LiveStage extends LiveEntity { private class LiveRDDPartition(val blockName: String) { - var executors = Set[String]() - var storageLevel: String = null - var memoryUsed = 0L - var diskUsed = 0L + // Pointers used by RDDPartitionSeq. 
+ @volatile var prev: LiveRDDPartition = null + @volatile var next: LiveRDDPartition = null + + var value: v1.RDDPartitionInfo = null + + def executors: Seq[String] = value.executors + + def memoryUsed: Long = value.memoryUsed + + def diskUsed: Long = value.diskUsed - def toApi(): v1.RDDPartitionInfo = { - new v1.RDDPartitionInfo( + def update( + executors: Seq[String], + storageLevel: String, + memoryUsed: Long, + diskUsed: Long): Unit = { + value = new v1.RDDPartitionInfo( blockName, storageLevel, memoryUsed, diskUsed, - executors.toSeq.sorted) + executors) } } -private class LiveRDDDistribution(val exec: LiveExecutor) { +private class LiveRDDDistribution(exec: LiveExecutor) { - var memoryRemaining = exec.maxMemory + val executorId = exec.executorId var memoryUsed = 0L var diskUsed = 0L var onHeapUsed = 0L var offHeapUsed = 0L - var onHeapRemaining = 0L - var offHeapRemaining = 0L + + // Keep the last update handy. This avoids recomputing the API view when not needed. + var lastUpdate: v1.RDDDataDistribution = null def toApi(): v1.RDDDataDistribution = { - new v1.RDDDataDistribution( - exec.hostPort, - memoryUsed, - memoryRemaining, - diskUsed, - if (exec.hasMemoryInfo) Some(onHeapUsed) else None, - if (exec.hasMemoryInfo) Some(offHeapUsed) else None, - if (exec.hasMemoryInfo) Some(onHeapRemaining) else None, - if (exec.hasMemoryInfo) Some(offHeapRemaining) else None) + if (lastUpdate == null) { + lastUpdate = new v1.RDDDataDistribution( + exec.hostPort, + memoryUsed, + exec.maxMemory - exec.memoryUsed, + diskUsed, + if (exec.hasMemoryInfo) Some(onHeapUsed) else None, + if (exec.hasMemoryInfo) Some(offHeapUsed) else None, + if (exec.hasMemoryInfo) Some(exec.totalOnHeap - exec.usedOnHeap) else None, + if (exec.hasMemoryInfo) Some(exec.totalOffHeap - exec.usedOffHeap) else None) + } + lastUpdate } } -private class LiveRDD(info: RDDInfo) extends LiveEntity { +private class LiveRDD(val info: RDDInfo) extends LiveEntity { var storageLevel: String = info.storageLevel.description var memoryUsed = 0L var diskUsed = 0L private val partitions = new HashMap[String, LiveRDDPartition]() + private val partitionSeq = new RDDPartitionSeq() + private val distributions = new HashMap[String, LiveRDDDistribution]() def partition(blockName: String): LiveRDDPartition = { - partitions.getOrElseUpdate(blockName, new LiveRDDPartition(blockName)) + partitions.getOrElseUpdate(blockName, { + val part = new LiveRDDPartition(blockName) + part.update(Nil, storageLevel, 0L, 0L) + partitionSeq.addPartition(part) + part + }) } - def removePartition(blockName: String): Unit = partitions.remove(blockName) + def removePartition(blockName: String): Unit = { + partitions.remove(blockName).foreach(partitionSeq.removePartition) + } def distribution(exec: LiveExecutor): LiveRDDDistribution = { - distributions.getOrElseUpdate(exec.hostPort, new LiveRDDDistribution(exec)) + distributions.getOrElseUpdate(exec.executorId, new LiveRDDDistribution(exec)) } - def removeDistribution(exec: LiveExecutor): Unit = { - distributions.remove(exec.hostPort) + def removeDistribution(exec: LiveExecutor): Boolean = { + distributions.remove(exec.executorId).isDefined } - override protected def doUpdate(): Any = { - val parts = if (partitions.nonEmpty) { - Some(partitions.values.toList.sortBy(_.blockName).map(_.toApi())) - } else { - None - } + def distributionOpt(exec: LiveExecutor): Option[LiveRDDDistribution] = { + distributions.get(exec.executorId) + } + override protected def doUpdate(): Any = { val dists = if (distributions.nonEmpty) { - 
Some(distributions.values.toList.sortBy(_.exec.executorId).map(_.toApi())) + Some(distributions.values.map(_.toApi()).toSeq) } else { None } @@ -506,7 +528,7 @@ private class LiveRDD(info: RDDInfo) extends LiveEntity { memoryUsed, diskUsed, dists, - parts) + Some(partitionSeq)) new RDDStorageInfoWrapper(rdd) } @@ -526,7 +548,7 @@ private object LiveEntityHelpers { .map { acc => new v1.AccumulableInfo( acc.id, - acc.name.map(_.intern()).orNull, + acc.name.orNull, acc.update.map(_.toString()), acc.value.map(_.toString()).orNull) } @@ -534,3 +556,81 @@ private object LiveEntityHelpers { } } + +/** + * A custom sequence of partitions based on a mutable linked list. + * + * The external interface is an immutable Seq, which is thread-safe for traversal. There are no + * guarantees about consistency though - iteration might return elements that have been removed + * or miss added elements. + * + * Internally, the sequence is mutable, and elements can modify the data they expose. Additions and + * removals are O(1). It is not safe to do multiple writes concurrently. + */ +private class RDDPartitionSeq extends Seq[v1.RDDPartitionInfo] { + + @volatile private var _head: LiveRDDPartition = null + @volatile private var _tail: LiveRDDPartition = null + @volatile var count = 0 + + override def apply(idx: Int): v1.RDDPartitionInfo = { + var curr = 0 + var e = _head + while (curr < idx && e != null) { + curr += 1 + e = e.next + } + if (e != null) e.value else throw new IndexOutOfBoundsException(idx.toString) + } + + override def iterator: Iterator[v1.RDDPartitionInfo] = { + new Iterator[v1.RDDPartitionInfo] { + var current = _head + + override def hasNext: Boolean = current != null + + override def next(): v1.RDDPartitionInfo = { + if (current != null) { + val tmp = current + current = tmp.next + tmp.value + } else { + throw new NoSuchElementException() + } + } + } + } + + override def length: Int = count + + def addPartition(part: LiveRDDPartition): Unit = { + part.prev = _tail + if (_tail != null) { + _tail.next = part + } + if (_head == null) { + _head = part + } + _tail = part + count += 1 + } + + def removePartition(part: LiveRDDPartition): Unit = { + count -= 1 + // Remove the partition from the list, but leave the pointers unchanged. That ensures a best + // effort at returning existing elements when iterations still reference the removed partition. 
+ if (part.prev != null) { + part.prev.next = part.next + } + if (part eq _head) { + _head = part.next + } + if (part.next != null) { + part.next.prev = part.prev + } + if (part eq _tail) { + _tail = part.prev + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 1279b281ad8d8..2189e1da91841 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -21,90 +21,11 @@ import javax.ws.rs.core.MediaType import org.apache.spark.storage.{RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.storage.StorageListener @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AllRDDResource(ui: SparkUI) { @GET - def rddList(): Seq[RDDStorageInfo] = { - val storageStatusList = ui.storageListener.activeStorageStatusList - val rddInfos = ui.storageListener.rddInfoList - rddInfos.map{rddInfo => - AllRDDResource.getRDDStorageInfo(rddInfo.id, rddInfo, storageStatusList, - includeDetails = false) - } - } + def rddList(): Seq[RDDStorageInfo] = ui.store.rddList() } - -private[spark] object AllRDDResource { - - def getRDDStorageInfo( - rddId: Int, - listener: StorageListener, - includeDetails: Boolean): Option[RDDStorageInfo] = { - val storageStatusList = listener.activeStorageStatusList - listener.rddInfoList.find { _.id == rddId }.map { rddInfo => - getRDDStorageInfo(rddId, rddInfo, storageStatusList, includeDetails) - } - } - - def getRDDStorageInfo( - rddId: Int, - rddInfo: RDDInfo, - storageStatusList: Seq[StorageStatus], - includeDetails: Boolean): RDDStorageInfo = { - val workers = storageStatusList.map { (rddId, _) } - val blockLocations = StorageUtils.getRddBlockLocations(rddId, storageStatusList) - val blocks = storageStatusList - .flatMap { _.rddBlocksById(rddId) } - .sortWith { _._1.name < _._1.name } - .map { case (blockId, status) => - (blockId, status, blockLocations.getOrElse(blockId, Seq[String]("Unknown"))) - } - - val dataDistribution = if (includeDetails) { - Some(storageStatusList.map { status => - new RDDDataDistribution( - address = status.blockManagerId.hostPort, - memoryUsed = status.memUsedByRdd(rddId), - memoryRemaining = status.memRemaining, - diskUsed = status.diskUsedByRdd(rddId), - onHeapMemoryUsed = Some( - if (!rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), - offHeapMemoryUsed = Some( - if (rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), - onHeapMemoryRemaining = status.onHeapMemRemaining, - offHeapMemoryRemaining = status.offHeapMemRemaining - ) } ) - } else { - None - } - val partitions = if (includeDetails) { - Some(blocks.map { case (id, block, locations) => - new RDDPartitionInfo( - blockName = id.name, - storageLevel = block.storageLevel.description, - memoryUsed = block.memSize, - diskUsed = block.diskSize, - executors = locations - ) - } ) - } else { - None - } - - new RDDStorageInfo( - id = rddId, - name = rddInfo.name, - numPartitions = rddInfo.numPartitions, - numCachedPartitions = rddInfo.numCachedPartitions, - storageLevel = rddInfo.storageLevel.description, - memoryUsed = rddInfo.memSize, - diskUsed = rddInfo.diskSize, - dataDistribution = dataDistribution, - partitions = partitions - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala 
b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index 237aeac185877..ca9758cf0d109 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.NoSuchElementException import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType @@ -26,9 +27,12 @@ private[v1] class OneRDDResource(ui: SparkUI) { @GET def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { - AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( - throw new NotFoundException(s"no rdd found w/ id $rddId") - ) + try { + ui.store.rdd(rddId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException(s"no rdd found w/ id $rddId") + } } } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 23e9a360ddc02..f44b7935bfaa3 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -131,3 +131,19 @@ private[spark] class ExecutorStageSummaryWrapper( private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) } + +private[spark] class StreamBlockData( + val name: String, + val executorId: String, + val hostPort: String, + val storageLevel: String, + val useMemory: Boolean, + val useDisk: Boolean, + val deserialized: Boolean, + val memSize: Long, + val diskSize: Long) { + + @JsonIgnore @KVIndex + def key: Array[String] = Array(name, executorId) + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala deleted file mode 100644 index 0a14fcadf53e0..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.storage - -import scala.collection.mutable - -import org.apache.spark.scheduler._ - -private[spark] case class BlockUIData( - blockId: BlockId, - location: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long) - -/** - * The aggregated status of stream blocks in an executor - */ -private[spark] case class ExecutorStreamBlockStatus( - executorId: String, - location: String, - blocks: Seq[BlockUIData]) { - - def totalMemSize: Long = blocks.map(_.memSize).sum - - def totalDiskSize: Long = blocks.map(_.diskSize).sum - - def numStreamBlocks: Int = blocks.size - -} - -private[spark] class BlockStatusListener extends SparkListener { - - private val blockManagers = - new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] - - override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { - val blockId = blockUpdated.blockUpdatedInfo.blockId - if (!blockId.isInstanceOf[StreamBlockId]) { - // Now we only monitor StreamBlocks - return - } - val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId - val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel - val memSize = blockUpdated.blockUpdatedInfo.memSize - val diskSize = blockUpdated.blockUpdatedInfo.diskSize - - synchronized { - // Drop the update info if the block manager is not registered - blockManagers.get(blockManagerId).foreach { blocksInBlockManager => - if (storageLevel.isValid) { - blocksInBlockManager.put(blockId, - BlockUIData( - blockId, - blockManagerId.hostPort, - storageLevel, - memSize, - diskSize) - ) - } else { - // If isValid is not true, it means we should drop the block. - blocksInBlockManager -= blockId - } - } - } - } - - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { - synchronized { - blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) - } - } - - override def onBlockManagerRemoved( - blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { - blockManagers -= blockManagerRemoved.blockManagerId - } - - def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { - blockManagers.map { case (blockManagerId, blocks) => - ExecutorStreamBlockStatus( - blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) - }.toSeq - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala deleted file mode 100644 index ac60f795915a3..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.storage - -import scala.collection.mutable - -import org.apache.spark.SparkConf -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ - -/** - * :: DeveloperApi :: - * A SparkListener that maintains executor storage status. - * - * This class is thread-safe (unlike JobProgressListener) - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class StorageStatusListener(conf: SparkConf) extends SparkListener { - // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) - private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() - private[storage] val deadExecutorStorageStatus = new mutable.ListBuffer[StorageStatus]() - private[this] val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100) - - def storageStatusList: Seq[StorageStatus] = synchronized { - executorIdToStorageStatus.values.toSeq - } - - def deadStorageStatusList: Seq[StorageStatus] = synchronized { - deadExecutorStorageStatus - } - - /** Update storage status list to reflect updated block statuses */ - private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { - executorIdToStorageStatus.get(execId).foreach { storageStatus => - updatedBlocks.foreach { case (blockId, updatedStatus) => - if (updatedStatus.storageLevel == StorageLevel.NONE) { - storageStatus.removeBlock(blockId) - } else { - storageStatus.updateBlock(blockId, updatedStatus) - } - } - } - } - - /** Update storage status list to reflect the removal of an RDD from the cache */ - private def updateStorageStatus(unpersistedRDDId: Int) { - storageStatusList.foreach { storageStatus => - storageStatus.rddBlocksById(unpersistedRDDId).foreach { case (blockId, _) => - storageStatus.removeBlock(blockId) - } - } - } - - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { - updateStorageStatus(unpersistRDD.rddId) - } - - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { - synchronized { - val blockManagerId = blockManagerAdded.blockManagerId - val executorId = blockManagerId.executorId - // The onHeap and offHeap memory are always defined for new applications, - // but they can be missing if we are replaying old event logs. - val storageStatus = new StorageStatus(blockManagerId, blockManagerAdded.maxMem, - blockManagerAdded.maxOnHeapMem, blockManagerAdded.maxOffHeapMem) - executorIdToStorageStatus(executorId) = storageStatus - - // Try to remove the dead storage status if same executor register the block manager twice. 
- deadExecutorStorageStatus.zipWithIndex.find(_._1.blockManagerId.executorId == executorId) - .foreach(toRemoveExecutor => deadExecutorStorageStatus.remove(toRemoveExecutor._2)) - } - } - - override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { - synchronized { - val executorId = blockManagerRemoved.blockManagerId.executorId - executorIdToStorageStatus.remove(executorId).foreach { status => - deadExecutorStorageStatus += status - } - if (deadExecutorStorageStatus.size > retainedDeadExecutors) { - deadExecutorStorageStatus.trimStart(1) - } - } - } - - override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { - val executorId = blockUpdated.blockUpdatedInfo.blockManagerId.executorId - val blockId = blockUpdated.blockUpdatedInfo.blockId - val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel - val memSize = blockUpdated.blockUpdatedInfo.memSize - val diskSize = blockUpdated.blockUpdatedInfo.diskSize - val blockStatus = BlockStatus(storageLevel, memSize, diskSize) - updateStorageStatus(executorId, Seq((blockId, blockStatus))) - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 79d40b6a90c3c..e93ade001c607 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -26,13 +26,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1._ -import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.EnvironmentTab import org.apache.spark.ui.exec.ExecutorsTab import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener -import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.ui.storage.StorageTab import org.apache.spark.util.Utils /** @@ -43,9 +42,7 @@ private[spark] class SparkUI private ( val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, - val storageStatusListener: StorageStatusListener, val jobProgressListener: JobProgressListener, - val storageListener: StorageListener, val operationGraphListener: RDDOperationGraphListener, var appName: String, val basePath: String, @@ -69,7 +66,7 @@ private[spark] class SparkUI private ( attachTab(jobsTab) val stagesTab = new StagesTab(this, store) attachTab(stagesTab) - attachTab(new StorageTab(this)) + attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) @@ -186,18 +183,12 @@ private[spark] object SparkUI { addListenerFn(listener) listener } - - val storageStatusListener = new StorageStatusListener(conf) - val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) - addListenerFn(storageStatusListener) - addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(store, sc, conf, securityManager, storageStatusListener, jobProgressListener, - storageListener, operationGraphListener, appName, basePath, lastUpdateTime, startTime, - appSparkVersion) + new SparkUI(store, sc, conf, securityManager, jobProgressListener, operationGraphListener, + appName, basePath, lastUpdateTime, startTime, appSparkVersion) } } 
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index e8ff08f7d88ff..02cee7f8c5b33 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -22,13 +22,13 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} -import org.apache.spark.status.api.v1.{AllRDDResource, RDDDataDistribution, RDDPartitionInfo} -import org.apache.spark.ui.{PagedDataSource, PagedTable, UIUtils, WebUIPage} +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo} +import org.apache.spark.ui._ import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ -private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { - private val listener = parent.listener +private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("rdd") { def render(request: HttpServletRequest): Seq[Node] = { // stripXSS is called first to remove suspicious characters used in XSS attacks @@ -48,11 +48,13 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val blockPrevPageSize = Option(parameterBlockPrevPageSize).map(_.toInt).getOrElse(blockPageSize) val rddId = parameterId.toInt - val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) - .getOrElse { + val rddStorageInfo = try { + store.rdd(rddId) + } catch { + case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) - } + } // Worker table val workerTable = UIUtils.listingTable(workerHeader, workerRow, diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b6c764d1728e4..b8aec9890247a 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -19,23 +19,23 @@ package org.apache.spark.ui.storage import javax.servlet.http.HttpServletRequest +import scala.collection.SortedMap import scala.xml.Node -import org.apache.spark.storage._ -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.status.{AppStatusStore, StreamBlockData} +import org.apache.spark.status.api.v1 +import org.apache.spark.ui._ import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ -private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - private val listener = parent.listener +private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(listener.rddInfoList) ++ - receiverBlockTables(listener.allExecutorStreamBlockStatus.sortBy(_.executorId)) + val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) UIUtils.headerSparkPage("Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[RDDInfo]): Seq[Node] = { + private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. 
Nil @@ -58,7 +58,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: RDDInfo): Seq[Node] = { + private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} @@ -67,35 +67,40 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.name} - {rdd.storageLevel.description} + {rdd.storageLevel} {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} - {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.diskSize)} + {Utils.bytesToString(rdd.memoryUsed)} + {Utils.bytesToString(rdd.diskUsed)} // scalastyle:on } - private[storage] def receiverBlockTables(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { - if (statuses.map(_.numStreamBlocks).sum == 0) { + private[storage] def receiverBlockTables(blocks: Seq[StreamBlockData]): Seq[Node] = { + if (blocks.isEmpty) { // Don't show the tables if there is no stream block Nil } else { - val blocks = statuses.flatMap(_.blocks).groupBy(_.blockId).toSeq.sortBy(_._1.toString) + val sorted = blocks.groupBy(_.name).toSeq.sortBy(_._1.toString)

    Receiver Blocks

    - {executorMetricsTable(statuses)} - {streamBlockTable(blocks)} + {executorMetricsTable(blocks)} + {streamBlockTable(sorted)}
    } } - private def executorMetricsTable(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { + private def executorMetricsTable(blocks: Seq[StreamBlockData]): Seq[Node] = { + val blockManagers = SortedMap(blocks.groupBy(_.executorId).toSeq: _*) + .map { case (id, blocks) => + new ExecutorStreamSummary(blocks) + } +
    Aggregated Block Metrics by Executor
    - {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, statuses, + {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, blockManagers, id = Some("storage-by-executor-stream-blocks"))}
    } @@ -107,7 +112,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Total Size on Disk", "Stream Blocks") - private def executorMetricsTableRow(status: ExecutorStreamBlockStatus): Seq[Node] = { + private def executorMetricsTableRow(status: ExecutorStreamSummary): Seq[Node] = { {status.executorId} @@ -127,7 +132,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { } - private def streamBlockTable(blocks: Seq[(BlockId, Seq[BlockUIData])]): Seq[Node] = { + private def streamBlockTable(blocks: Seq[(String, Seq[StreamBlockData])]): Seq[Node] = { if (blocks.isEmpty) { Nil } else { @@ -151,7 +156,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Size") /** Render a stream block */ - private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { + private def streamBlockTableRow(block: (String, Seq[StreamBlockData])): Seq[Node] = { val replications = block._2 assert(replications.nonEmpty) // This must be true because it's the result of "groupBy" if (replications.size == 1) { @@ -163,33 +168,36 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { } private def streamBlockTableSubrow( - blockId: BlockId, block: BlockUIData, replication: Int, firstSubrow: Boolean): Seq[Node] = { + blockId: String, + block: StreamBlockData, + replication: Int, + firstSubrow: Boolean): Seq[Node] = { val (storageLevel, size) = streamBlockStorageLevelDescriptionAndSize(block) { if (firstSubrow) { - {block.blockId.toString} + {block.name} {replication.toString} } } - {block.location} + {block.hostPort} {storageLevel} {Utils.bytesToString(size)} } private[storage] def streamBlockStorageLevelDescriptionAndSize( - block: BlockUIData): (String, Long) = { - if (block.storageLevel.useDisk) { + block: StreamBlockData): (String, Long) = { + if (block.useDisk) { ("Disk", block.diskSize) - } else if (block.storageLevel.useMemory && block.storageLevel.deserialized) { + } else if (block.useMemory && block.deserialized) { ("Memory", block.memSize) - } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { + } else if (block.useMemory && !block.deserialized) { ("Memory Serialized", block.memSize) } else { throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") @@ -197,3 +205,17 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { } } + +private class ExecutorStreamSummary(blocks: Seq[StreamBlockData]) { + + def executorId: String = blocks.head.executorId + + def location: String = blocks.head.hostPort + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def numStreamBlocks: Int = blocks.size + +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 148efb134e14f..688efa24ade0c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -17,71 +17,13 @@ package org.apache.spark.ui.storage -import scala.collection.mutable - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.scheduler._ -import org.apache.spark.storage._ +import org.apache.spark.status.AppStatusStore import org.apache.spark.ui._ /** Web UI showing storage status of all RDD's in the given SparkContext. 
*/ -private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { - val listener = parent.storageListener - - attachPage(new StoragePage(this)) - attachPage(new RDDPage(this)) -} - -/** - * :: DeveloperApi :: - * A SparkListener that prepares information to be displayed on the BlockManagerUI. - * - * This class is thread-safe (unlike JobProgressListener) - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { - - private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - - def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList - - /** Filter RDD info to include only those with cached partitions */ - def rddInfoList: Seq[RDDInfo] = synchronized { - _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq - } - - /** Update the storage info of the RDDs whose blocks are among the given updated blocks */ - private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { - val rddIdsToUpdate = updatedBlocks.flatMap { case (bid, _) => bid.asRDDId.map(_.rddId) }.toSet - val rddInfosToUpdate = _rddInfoMap.values.toSeq.filter { s => rddIdsToUpdate.contains(s.id) } - StorageUtils.updateRddInfo(rddInfosToUpdate, activeStorageStatusList) - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - val rddInfos = stageSubmitted.stageInfo.rddInfos - rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info).name = info.name } - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { - // Remove all partitions that are no longer cached in current completed stage - val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet - _rddInfoMap.retain { case (id, info) => - !completedRddIds.contains(id) || info.numCachedPartitions > 0 - } - } - - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { - _rddInfoMap.remove(unpersistRDD.rddId) - } +private[ui] class StorageTab(parent: SparkUI, store: AppStatusStore) + extends SparkUITab(parent, "storage") { - override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { - super.onBlockUpdated(blockUpdated) - val blockId = blockUpdated.blockUpdatedInfo.blockId - val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel - val memSize = blockUpdated.blockUpdatedInfo.memSize - val diskSize = blockUpdated.blockUpdatedInfo.diskSize - val blockStatus = BlockStatus(storageLevel, memSize, diskSize) - updateRDDInfo(Seq((blockId, blockStatus))) - } + attachPage(new StoragePage(this, store)) + attachPage(new RDDPage(this, store)) } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 867d35f231dc0..ba082bc93dd42 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -578,145 +578,160 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } - val rdd1b1 = RDDBlockId(1, 1) + val rdd1b1 = RddBlock(1, 1, 1L, 2L) + val rdd1b2 = RddBlock(1, 2, 3L, 4L) + val rdd2b1 = RddBlock(2, 1, 5L, 6L) val level = StorageLevel.MEMORY_AND_DISK - // Submit a stage and make sure the RDD is recorded. 
- val rddInfo = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) - val stage = new StageInfo(1, 0, "stage1", 4, Seq(rddInfo), Nil, "details1") + // Submit a stage and make sure the RDDs are recorded. + val rdd1Info = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) + val rdd2Info = new RDDInfo(rdd2b1.rddId, "rdd2", 1, level, Nil) + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rdd1Info, rdd2Info), Nil, "details1") listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.name === rddInfo.name) - assert(wrapper.info.numPartitions === rddInfo.numPartitions) - assert(wrapper.info.storageLevel === rddInfo.storageLevel.description) + assert(wrapper.info.name === rdd1Info.name) + assert(wrapper.info.numPartitions === rdd1Info.numPartitions) + assert(wrapper.info.storageLevel === rdd1Info.storageLevel.description) } // Add partition 1 replicated on two block managers. - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, level, 1L, 1L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b1.blockId, level, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 1L) - assert(wrapper.info.diskUsed === 1L) + assert(wrapper.info.numCachedPartitions === 1L) + assert(wrapper.info.memoryUsed === rdd1b1.memSize) + assert(wrapper.info.diskUsed === rdd1b1.diskSize) assert(wrapper.info.dataDistribution.isDefined) assert(wrapper.info.dataDistribution.get.size === 1) val dist = wrapper.info.dataDistribution.get.head assert(dist.address === bm1.hostPort) - assert(dist.memoryUsed === 1L) - assert(dist.diskUsed === 1L) + assert(dist.memoryUsed === rdd1b1.memSize) + assert(dist.diskUsed === rdd1b1.diskSize) assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) assert(wrapper.info.partitions.isDefined) assert(wrapper.info.partitions.get.size === 1) val part = wrapper.info.partitions.get.head - assert(part.blockName === rdd1b1.name) + assert(part.blockName === rdd1b1.blockId.name) assert(part.storageLevel === level.description) - assert(part.memoryUsed === 1L) - assert(part.diskUsed === 1L) + assert(part.memoryUsed === rdd1b1.memSize) + assert(part.diskUsed === rdd1b1.diskSize) assert(part.executors === Seq(bm1.executorId)) } check[ExecutorSummaryWrapper](bm1.executorId) { exec => assert(exec.info.rddBlocks === 1L) - assert(exec.info.memoryUsed === 1L) - assert(exec.info.diskUsed === 1L) + assert(exec.info.memoryUsed === rdd1b1.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize) } - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, level, 1L, 1L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm2, rdd1b1.blockId, level, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 2L) - assert(wrapper.info.diskUsed === 2L) + assert(wrapper.info.numCachedPartitions === 1L) + assert(wrapper.info.memoryUsed === rdd1b1.memSize * 2) + assert(wrapper.info.diskUsed === rdd1b1.diskSize * 2) assert(wrapper.info.dataDistribution.get.size === 2L) assert(wrapper.info.partitions.get.size === 1L) val dist = wrapper.info.dataDistribution.get.find(_.address == bm2.hostPort).get - assert(dist.memoryUsed === 1L) - assert(dist.diskUsed === 1L) + assert(dist.memoryUsed === rdd1b1.memSize) + assert(dist.diskUsed === rdd1b1.diskSize) assert(dist.memoryRemaining === maxMemory 
- dist.memoryUsed) val part = wrapper.info.partitions.get(0) - assert(part.memoryUsed === 2L) - assert(part.diskUsed === 2L) + assert(part.memoryUsed === rdd1b1.memSize * 2) + assert(part.diskUsed === rdd1b1.diskSize * 2) assert(part.executors === Seq(bm1.executorId, bm2.executorId)) } check[ExecutorSummaryWrapper](bm2.executorId) { exec => assert(exec.info.rddBlocks === 1L) - assert(exec.info.memoryUsed === 1L) - assert(exec.info.diskUsed === 1L) + assert(exec.info.memoryUsed === rdd1b1.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize) } // Add a second partition only to bm 1. - val rdd1b2 = RDDBlockId(1, 2) - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b2, level, - 3L, 3L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b2.blockId, level, rdd1b2.memSize, rdd1b2.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 5L) - assert(wrapper.info.diskUsed === 5L) + assert(wrapper.info.numCachedPartitions === 2L) + assert(wrapper.info.memoryUsed === 2 * rdd1b1.memSize + rdd1b2.memSize) + assert(wrapper.info.diskUsed === 2 * rdd1b1.diskSize + rdd1b2.diskSize) assert(wrapper.info.dataDistribution.get.size === 2L) assert(wrapper.info.partitions.get.size === 2L) val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get - assert(dist.memoryUsed === 4L) - assert(dist.diskUsed === 4L) + assert(dist.memoryUsed === rdd1b1.memSize + rdd1b2.memSize) + assert(dist.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize) assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) - val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.name).get + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.blockId.name).get assert(part.storageLevel === level.description) - assert(part.memoryUsed === 3L) - assert(part.diskUsed === 3L) + assert(part.memoryUsed === rdd1b2.memSize) + assert(part.diskUsed === rdd1b2.diskSize) assert(part.executors === Seq(bm1.executorId)) } check[ExecutorSummaryWrapper](bm1.executorId) { exec => assert(exec.info.rddBlocks === 2L) - assert(exec.info.memoryUsed === 4L) - assert(exec.info.diskUsed === 4L) + assert(exec.info.memoryUsed === rdd1b1.memSize + rdd1b2.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize) } // Remove block 1 from bm 1. 
- listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, - StorageLevel.NONE, 1L, 1L))) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b1.blockId, StorageLevel.NONE, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 4L) - assert(wrapper.info.diskUsed === 4L) + assert(wrapper.info.numCachedPartitions === 2L) + assert(wrapper.info.memoryUsed === rdd1b1.memSize + rdd1b2.memSize) + assert(wrapper.info.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize) assert(wrapper.info.dataDistribution.get.size === 2L) assert(wrapper.info.partitions.get.size === 2L) val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get - assert(dist.memoryUsed === 3L) - assert(dist.diskUsed === 3L) + assert(dist.memoryUsed === rdd1b2.memSize) + assert(dist.diskUsed === rdd1b2.diskSize) assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) - val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.name).get + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.blockId.name).get assert(part.storageLevel === level.description) - assert(part.memoryUsed === 1L) - assert(part.diskUsed === 1L) + assert(part.memoryUsed === rdd1b1.memSize) + assert(part.diskUsed === rdd1b1.diskSize) assert(part.executors === Seq(bm2.executorId)) } check[ExecutorSummaryWrapper](bm1.executorId) { exec => assert(exec.info.rddBlocks === 1L) - assert(exec.info.memoryUsed === 3L) - assert(exec.info.diskUsed === 3L) + assert(exec.info.memoryUsed === rdd1b2.memSize) + assert(exec.info.diskUsed === rdd1b2.diskSize) } - // Remove block 2 from bm 2. This should leave only block 2 info in the store. - listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, - StorageLevel.NONE, 1L, 1L))) + // Remove block 1 from bm 2. This should leave only block 2's info in the store. + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm2, rdd1b1.blockId, StorageLevel.NONE, rdd1b1.memSize, rdd1b1.diskSize))) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => - assert(wrapper.info.memoryUsed === 3L) - assert(wrapper.info.diskUsed === 3L) + assert(wrapper.info.numCachedPartitions === 1L) + assert(wrapper.info.memoryUsed === rdd1b2.memSize) + assert(wrapper.info.diskUsed === rdd1b2.diskSize) assert(wrapper.info.dataDistribution.get.size === 1L) assert(wrapper.info.partitions.get.size === 1L) - assert(wrapper.info.partitions.get(0).blockName === rdd1b2.name) + assert(wrapper.info.partitions.get(0).blockName === rdd1b2.blockId.name) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === rdd1b2.memSize) + assert(exec.info.diskUsed === rdd1b2.diskSize) } check[ExecutorSummaryWrapper](bm2.executorId) { exec => @@ -725,12 +740,61 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.diskUsed === 0L) } + // Add a block from a different RDD. Verify the executor is updated correctly and also that + // the distribution data for both rdds is updated to match the remaining memory. 
+ listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd2b1.blockId, level, rdd2b1.memSize, rdd2b1.diskSize))) + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 2L) + assert(exec.info.memoryUsed === rdd1b2.memSize + rdd2b1.memSize) + assert(exec.info.diskUsed === rdd1b2.diskSize + rdd2b1.diskSize) + } + + check[RDDStorageInfoWrapper](rdd1b2.rddId) { wrapper => + assert(wrapper.info.dataDistribution.get.size === 1L) + val dist = wrapper.info.dataDistribution.get(0) + assert(dist.memoryRemaining === maxMemory - rdd2b1.memSize - rdd1b2.memSize ) + } + + check[RDDStorageInfoWrapper](rdd2b1.rddId) { wrapper => + assert(wrapper.info.dataDistribution.get.size === 1L) + + val dist = wrapper.info.dataDistribution.get(0) + assert(dist.memoryUsed === rdd2b1.memSize) + assert(dist.diskUsed === rdd2b1.diskSize) + assert(dist.memoryRemaining === maxMemory - rdd2b1.memSize - rdd1b2.memSize ) + } + // Unpersist RDD1. listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId)) - intercept[NoSuchElementException] { + intercept[NoSuchElementException] { check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () } } + // Update a StreamBlock. + val stream1 = StreamBlockId(1, 1L) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, stream1, level, 1L, 1L))) + + check[StreamBlockData](Array(stream1.name, bm1.executorId)) { stream => + assert(stream.name === stream1.name) + assert(stream.executorId === bm1.executorId) + assert(stream.hostPort === bm1.hostPort) + assert(stream.storageLevel === level.description) + assert(stream.useMemory === level.useMemory) + assert(stream.useDisk === level.useDisk) + assert(stream.deserialized === level.deserialized) + assert(stream.memSize === 1L) + assert(stream.diskSize === 1L) + } + + // Drop a StreamBlock. + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, stream1, StorageLevel.NONE, 0L, 0L))) + intercept[NoSuchElementException] { + check[StreamBlockData](stream1.name) { _ => () } + } } private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) @@ -740,4 +804,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { fn(value) } + private case class RddBlock( + rddId: Int, + partId: Int, + memSize: Long, + diskSize: Long) { + + def blockId: BlockId = RDDBlockId(rddId, partId) + + } + } diff --git a/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala new file mode 100644 index 0000000000000..bb2d2633001f0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.status + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.RDDPartitionInfo + +class LiveEntitySuite extends SparkFunSuite { + + test("partition seq") { + val seq = new RDDPartitionSeq() + val items = (1 to 10).map { i => + val part = newPartition(i) + seq.addPartition(part) + part + }.toList + + checkSize(seq, 10) + + val added = newPartition(11) + seq.addPartition(added) + checkSize(seq, 11) + assert(seq.last.blockName === added.blockName) + + seq.removePartition(items(0)) + assert(seq.head.blockName === items(1).blockName) + assert(!seq.exists(_.blockName == items(0).blockName)) + checkSize(seq, 10) + + seq.removePartition(added) + assert(seq.last.blockName === items.last.blockName) + assert(!seq.exists(_.blockName == added.blockName)) + checkSize(seq, 9) + + seq.removePartition(items(5)) + checkSize(seq, 8) + assert(!seq.exists(_.blockName == items(5).blockName)) + } + + private def checkSize(seq: Seq[_], expected: Int): Unit = { + assert(seq.length === expected) + var count = 0 + seq.iterator.foreach { _ => count += 1 } + assert(count === expected) + } + + private def newPartition(i: Int): LiveRDDPartition = { + val part = new LiveRDDPartition(i.toString) + part.update(Seq(i.toString), i.toString, i, i) + part + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala deleted file mode 100644 index 06acca3943c20..0000000000000 --- a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.storage - -import org.apache.spark.SparkFunSuite -import org.apache.spark.scheduler._ - -class BlockStatusListenerSuite extends SparkFunSuite { - - test("basic functions") { - val blockManagerId = BlockManagerId("0", "localhost", 10000) - val listener = new BlockStatusListener() - - // Add a block manager and a new block status - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId, 0)) - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId, - StreamBlockId(0, 100), - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100))) - // The new block status should be added to the listener - val expectedBlock = BlockUIData( - StreamBlockId(0, 100), - "localhost:10000", - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100 - ) - val expectedExecutorStreamBlockStatus = Seq( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) - ) - assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus) - - // Add the second block manager - val blockManagerId2 = BlockManagerId("1", "localhost", 10001) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId2, 0)) - // Add a new replication of the same block id from the second manager - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId2, - StreamBlockId(0, 100), - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100))) - val expectedBlock2 = BlockUIData( - StreamBlockId(0, 100), - "localhost:10001", - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100 - ) - // Each block manager should contain one block - val expectedExecutorStreamBlockStatus2 = Set( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), - ExecutorStreamBlockStatus("1", "localhost:10001", Seq(expectedBlock2)) - ) - assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus2) - - // Remove a replication of the same block - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId2, - StreamBlockId(0, 100), - StorageLevel.NONE, // StorageLevel.NONE means removing it - memSize = 0, - diskSize = 0))) - // Only the first block manager contains a block - val expectedExecutorStreamBlockStatus3 = Set( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), - ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) - ) - assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus3) - - // Remove the second block manager at first but add a new block status - // from this removed block manager - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId2)) - listener.onBlockUpdated(SparkListenerBlockUpdated( - BlockUpdatedInfo( - blockManagerId2, - StreamBlockId(0, 100), - StorageLevel.MEMORY_AND_DISK, - memSize = 100, - diskSize = 100))) - // The second block manager is removed so we should not see the new block - val expectedExecutorStreamBlockStatus4 = Seq( - ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) - ) - assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus4) - - // Remove the last block manager - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId)) - // No block manager now so we should dop all block managers - assert(listener.allExecutorStreamBlockStatus.isEmpty) - } - -} diff --git 
a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala deleted file mode 100644 index 9835f11a2f7ed..0000000000000 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import org.apache.spark.{SparkConf, SparkFunSuite, Success} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.scheduler._ - -/** - * Test the behavior of StorageStatusListener in response to all relevant events. - */ -class StorageStatusListenerSuite extends SparkFunSuite { - private val bm1 = BlockManagerId("big", "dog", 1) - private val bm2 = BlockManagerId("fat", "duck", 2) - private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) - private val taskInfo2 = new TaskInfo(0, 0, 0, 0, "fat", "duck", TaskLocality.ANY, false) - private val conf = new SparkConf() - - test("block manager added/removed") { - conf.set("spark.ui.retainedDeadExecutors", "1") - val listener = new StorageStatusListener(conf) - - // Block manager add - assert(listener.executorIdToStorageStatus.size === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - assert(listener.executorIdToStorageStatus.size === 1) - assert(listener.executorIdToStorageStatus.get("big").isDefined) - assert(listener.executorIdToStorageStatus("big").blockManagerId === bm1) - assert(listener.executorIdToStorageStatus("big").maxMem === 1000L) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) - assert(listener.executorIdToStorageStatus.size === 2) - assert(listener.executorIdToStorageStatus.get("fat").isDefined) - assert(listener.executorIdToStorageStatus("fat").blockManagerId === bm2) - assert(listener.executorIdToStorageStatus("fat").maxMem === 2000L) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - - // Block manager remove - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm1)) - assert(listener.executorIdToStorageStatus.size === 1) - assert(!listener.executorIdToStorageStatus.get("big").isDefined) - assert(listener.executorIdToStorageStatus.get("fat").isDefined) - assert(listener.deadExecutorStorageStatus.size === 1) - assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("big")) - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm2)) - assert(listener.executorIdToStorageStatus.size === 0) - assert(!listener.executorIdToStorageStatus.get("big").isDefined) - 
assert(!listener.executorIdToStorageStatus.get("fat").isDefined) - assert(listener.deadExecutorStorageStatus.size === 1) - assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("fat")) - } - - test("task end without updated blocks") { - val listener = new StorageStatusListener(conf) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) - val taskMetrics = new TaskMetrics - - // Task end with no updated blocks - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - } - - test("updated blocks") { - val listener = new StorageStatusListener(conf) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) - - val blockUpdateInfos1 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L) - ) - val blockUpdateInfos2 = - Seq(BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L)) - - // Add some new blocks - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - postUpdateBlock(listener, blockUpdateInfos1) - assert(listener.executorIdToStorageStatus("big").numBlocks === 2) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - postUpdateBlock(listener, blockUpdateInfos2) - assert(listener.executorIdToStorageStatus("big").numBlocks === 2) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - - // Dropped the blocks - val droppedBlockInfo1 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.NONE, 0L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L) - ) - val droppedBlockInfo2 = Seq( - BlockUpdatedInfo(bm2, RDDBlockId(1, 2), StorageLevel.NONE, 0L, 0L), - BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L) - ) - - postUpdateBlock(listener, droppedBlockInfo1) - assert(listener.executorIdToStorageStatus("big").numBlocks === 1) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) - assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - postUpdateBlock(listener, 
droppedBlockInfo2) - assert(listener.executorIdToStorageStatus("big").numBlocks === 1) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - } - - test("unpersist RDD") { - val listener = new StorageStatusListener(conf) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - val blockUpdateInfos1 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L) - ) - val blockUpdateInfos2 = - Seq(BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L)) - postUpdateBlock(listener, blockUpdateInfos1) - postUpdateBlock(listener, blockUpdateInfos2) - assert(listener.executorIdToStorageStatus("big").numBlocks === 3) - - // Unpersist RDD - listener.onUnpersistRDD(SparkListenerUnpersistRDD(9090)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 3) - listener.onUnpersistRDD(SparkListenerUnpersistRDD(4)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 2) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) - listener.onUnpersistRDD(SparkListenerUnpersistRDD(1)) - assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - } - - private def postUpdateBlock( - listener: StorageStatusListener, updateBlockInfos: Seq[BlockUpdatedInfo]): Unit = { - updateBlockInfos.foreach { updateBlockInfo => - listener.onBlockUpdated(SparkListenerBlockUpdated(updateBlockInfo)) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 4a48b3c686725..a71521c91d2f2 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -20,39 +20,46 @@ package org.apache.spark.ui.storage import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite +import org.apache.spark.status.StreamBlockData +import org.apache.spark.status.api.v1.RDDStorageInfo import org.apache.spark.storage._ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") - val storagePage = new StoragePage(storageTab) + val storagePage = new StoragePage(storageTab, null) test("rddTable") { - val rdd1 = new RDDInfo(1, + val rdd1 = new RDDStorageInfo(1, "rdd1", 10, - StorageLevel.MEMORY_ONLY, - Seq.empty) - rdd1.memSize = 100 - rdd1.numCachedPartitions = 10 + 10, + StorageLevel.MEMORY_ONLY.description, + 100L, + 0L, + None, + None) - val rdd2 = new RDDInfo(2, + val rdd2 = new RDDStorageInfo(2, "rdd2", 10, - StorageLevel.DISK_ONLY, - Seq.empty) - rdd2.diskSize = 200 - rdd2.numCachedPartitions = 5 - - val rdd3 = new RDDInfo(3, + 5, + StorageLevel.DISK_ONLY.description, + 0L, + 200L, + None, + None) + + val rdd3 = new RDDStorageInfo(3, "rdd3", 10, - StorageLevel.MEMORY_AND_DISK_SER, - Seq.empty) - rdd3.memSize = 400 - rdd3.diskSize = 500 - rdd3.numCachedPartitions = 10 + 10, + StorageLevel.MEMORY_AND_DISK_SER.description, + 400L, + 500L, + None, + None) val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) @@ -91,58 
+98,85 @@ class StoragePageSuite extends SparkFunSuite { } test("streamBlockStorageLevelDescriptionAndSize") { - val memoryBlock = BlockUIData(StreamBlockId(0, 0), + val memoryBlock = new StreamBlockData("0", + "0", "localhost:1111", - StorageLevel.MEMORY_ONLY, - memSize = 100, - diskSize = 0) + StorageLevel.MEMORY_ONLY.description, + true, + false, + true, + 100, + 0) assert(("Memory", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) - val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), + val memorySerializedBlock = new StreamBlockData("0", + "0", "localhost:1111", - StorageLevel.MEMORY_ONLY_SER, + StorageLevel.MEMORY_ONLY_SER.description, + true, + false, + false, memSize = 100, diskSize = 0) assert(("Memory Serialized", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) - val diskBlock = BlockUIData(StreamBlockId(0, 0), + val diskBlock = new StreamBlockData("0", + "0", "localhost:1111", - StorageLevel.DISK_ONLY, - memSize = 0, - diskSize = 100) + StorageLevel.DISK_ONLY.description, + false, + true, + false, + 0, + 100) assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) } test("receiverBlockTables") { val blocksForExecutor0 = Seq( - BlockUIData(StreamBlockId(0, 0), + new StreamBlockData(StreamBlockId(0, 0).name, + "0", "localhost:10000", - StorageLevel.MEMORY_ONLY, - memSize = 100, - diskSize = 0), - BlockUIData(StreamBlockId(1, 1), + StorageLevel.MEMORY_ONLY.description, + true, + false, + true, + 100, + 0), + new StreamBlockData(StreamBlockId(1, 1).name, + "0", "localhost:10000", - StorageLevel.DISK_ONLY, - memSize = 0, - diskSize = 100) + StorageLevel.DISK_ONLY.description, + false, + true, + false, + 0, + 100) ) - val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) val blocksForExecutor1 = Seq( - BlockUIData(StreamBlockId(0, 0), + new StreamBlockData(StreamBlockId(0, 0).name, + "1", "localhost:10001", - StorageLevel.MEMORY_ONLY, + StorageLevel.MEMORY_ONLY.description, + true, + false, + true, memSize = 100, diskSize = 0), - BlockUIData(StreamBlockId(1, 1), + new StreamBlockData(StreamBlockId(1, 1).name, + "1", "localhost:10001", - StorageLevel.MEMORY_ONLY_SER, - memSize = 100, - diskSize = 0) + StorageLevel.MEMORY_ONLY_SER.description, + true, + false, + false, + 100, + 0) ) - val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) - val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) + + val xmlNodes = storagePage.receiverBlockTables(blocksForExecutor0 ++ blocksForExecutor1) val executorTable = (xmlNodes \\ "table")(0) val executorHeaders = Seq( @@ -190,8 +224,6 @@ class StoragePageSuite extends SparkFunSuite { test("empty receiverBlockTables") { assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) - val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) - val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) - assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) } + } diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala deleted file mode 100644 index 79f02f2e50bbd..0000000000000 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.storage - -import org.scalatest.BeforeAndAfter - -import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.storage._ - -/** - * Test various functionality in the StorageListener that supports the StorageTab. - */ -class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { - private var bus: SparkListenerBus = _ - private var storageStatusListener: StorageStatusListener = _ - private var storageListener: StorageListener = _ - private val memAndDisk = StorageLevel.MEMORY_AND_DISK - private val memOnly = StorageLevel.MEMORY_ONLY - private val none = StorageLevel.NONE - private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) - private val taskInfo1 = new TaskInfo(1, 1, 1, 1, "big", "cat", TaskLocality.ANY, false) - private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly, Seq(10)) - private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly, Seq(10)) - private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk, Seq(10)) - private def rddInfo3 = new RDDInfo(3, "grace", 400, memAndDisk, Seq(10)) - private val bm1 = BlockManagerId("big", "dog", 1) - - before { - val conf = new SparkConf() - bus = new ReplayListenerBus() - storageStatusListener = new StorageStatusListener(conf) - storageListener = new StorageListener(storageStatusListener) - bus.addListener(storageStatusListener) - bus.addListener(storageListener) - } - - test("stage submitted / completed") { - assert(storageListener._rddInfoMap.isEmpty) - assert(storageListener.rddInfoList.isEmpty) - - // 2 RDDs are known, but none are cached - val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0, rddInfo1), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 2) - assert(storageListener.rddInfoList.isEmpty) - - // 4 RDDs are known, but only 2 are cached - val rddInfo2Cached = rddInfo2 - val rddInfo3Cached = rddInfo3 - rddInfo2Cached.numCachedPartitions = 1 - rddInfo3Cached.numCachedPartitions = 1 - val stageInfo1 = new StageInfo( - 1, 0, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) - assert(storageListener._rddInfoMap.size === 4) - assert(storageListener.rddInfoList.size === 2) - - // Submitting RDDInfos with duplicate IDs does nothing - val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY, Seq(10)) - rddInfo0Cached.numCachedPartitions = 1 - val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0Cached), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached)) - assert(storageListener._rddInfoMap.size === 4) - assert(storageListener.rddInfoList.size === 2) - - // We only keep around the RDDs that are cached - 
bus.postToAll(SparkListenerStageCompleted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 2) - assert(storageListener.rddInfoList.size === 2) - } - - test("unpersist") { - val rddInfo0Cached = rddInfo0 - val rddInfo1Cached = rddInfo1 - rddInfo0Cached.numCachedPartitions = 1 - rddInfo1Cached.numCachedPartitions = 1 - val stageInfo0 = new StageInfo( - 0, 0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 2) - assert(storageListener.rddInfoList.size === 2) - bus.postToAll(SparkListenerUnpersistRDD(0)) - assert(storageListener._rddInfoMap.size === 1) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerUnpersistRDD(4)) // doesn't exist - assert(storageListener._rddInfoMap.size === 1) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerUnpersistRDD(1)) - assert(storageListener._rddInfoMap.size === 0) - assert(storageListener.rddInfoList.size === 0) - } - - test("block update") { - val myRddInfo0 = rddInfo0 - val myRddInfo1 = rddInfo1 - val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo( - 0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), Seq.empty, "details") - bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener._rddInfoMap.size === 3) - assert(storageListener.rddInfoList.size === 0) // not cached - assert(!storageListener._rddInfoMap(0).isCached) - assert(!storageListener._rddInfoMap(1).isCached) - assert(!storageListener._rddInfoMap(2).isCached) - - // Some blocks updated - val blockUpdateInfos = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(0, 100), memAndDisk, 400L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(0, 101), memAndDisk, 0L, 400L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 20), memAndDisk, 0L, 240L) - ) - postUpdateBlocks(bus, blockUpdateInfos) - assert(storageListener._rddInfoMap(0).memSize === 400L) - assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 2) - assert(storageListener._rddInfoMap(0).isCached) - assert(storageListener._rddInfoMap(1).memSize === 0L) - assert(storageListener._rddInfoMap(1).diskSize === 240L) - assert(storageListener._rddInfoMap(1).numCachedPartitions === 1) - assert(storageListener._rddInfoMap(1).isCached) - assert(!storageListener._rddInfoMap(2).isCached) - assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) - - // Drop some blocks - val blockUpdateInfos2 = Seq( - BlockUpdatedInfo(bm1, RDDBlockId(0, 100), none, 0L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(1, 20), none, 0L, 0L), - BlockUpdatedInfo(bm1, RDDBlockId(2, 40), none, 0L, 0L), // doesn't actually exist - BlockUpdatedInfo(bm1, RDDBlockId(4, 80), none, 0L, 0L) // doesn't actually exist - ) - postUpdateBlocks(bus, blockUpdateInfos2) - assert(storageListener._rddInfoMap(0).memSize === 0L) - assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 1) - assert(storageListener._rddInfoMap(0).isCached) - assert(!storageListener._rddInfoMap(1).isCached) - assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) - assert(!storageListener._rddInfoMap(2).isCached) - assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) - } - - test("verify StorageTab contains all cached rdds") { - - val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly, Seq(4)) - val rddInfo1 = new RDDInfo(1, 
"rdd1", 1, memOnly, Seq(4)) - val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), Seq.empty, "details") - val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details") - val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L)) - val blockUpdateInfos2 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(1, 1), memOnly, 200L, 0L)) - bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - assert(storageListener.rddInfoList.size === 0) - postUpdateBlocks(bus, blockUpdateInfos1) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) - assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerStageCompleted(stageInfo0)) - assert(storageListener.rddInfoList.size === 1) - postUpdateBlocks(bus, blockUpdateInfos2) - assert(storageListener.rddInfoList.size === 2) - bus.postToAll(SparkListenerStageCompleted(stageInfo1)) - assert(storageListener.rddInfoList.size === 2) - } - - test("verify StorageTab still contains a renamed RDD") { - val rddInfo = new RDDInfo(0, "original_name", 1, memOnly, Seq(4)) - val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo), Seq.empty, "details") - bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) - val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L)) - postUpdateBlocks(bus, blockUpdateInfos1) - assert(storageListener.rddInfoList.size == 1) - - val newName = "new_name" - val rddInfoRenamed = new RDDInfo(0, newName, 1, memOnly, Seq(4)) - val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfoRenamed), Seq.empty, "details") - bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) - assert(storageListener.rddInfoList.size == 1) - assert(storageListener.rddInfoList.head.name == newName) - } - - private def postUpdateBlocks( - bus: SparkListenerBus, blockUpdateInfos: Seq[BlockUpdatedInfo]): Unit = { - blockUpdateInfos.foreach { blockUpdateInfo => - bus.postToAll(SparkListenerBlockUpdated(blockUpdateInfo)) - } - } -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0c31b2b4a9402..090375f65552e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -41,6 +41,8 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.storage.StorageListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 9eb7096c47a7e5f98de19512515386b3d0698039 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 9 Nov 2017 16:05:47 -0800 Subject: [PATCH 212/712] [SPARK-22483][CORE] Exposing java.nio bufferedPool memory metrics to Metric System ## What changes were proposed in this pull request? Adds java.nio bufferedPool memory metrics to metrics system which includes both direct and mapped memory. ## How was this patch tested? 
Manually tested; checked that the direct and mapped memory metrics are also available in the metrics system using the Console sink. Here is the sample console output: application_1509655862825_0016.2.jvm.direct.capacity value = 19497 application_1509655862825_0016.2.jvm.direct.count value = 6 application_1509655862825_0016.2.jvm.direct.used value = 19498 application_1509655862825_0016.2.jvm.mapped.capacity value = 0 application_1509655862825_0016.2.jvm.mapped.count value = 0 application_1509655862825_0016.2.jvm.mapped.used value = 0 Author: Srinivasa Reddy Vundela Closes #19709 from vundela/SPARK-22483. --- .../scala/org/apache/spark/metrics/source/JvmSource.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala index 635bff2cd7ec8..dcaa0f19295e8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -17,8 +17,10 @@ package org.apache.spark.metrics.source +import java.lang.management.ManagementFactory + import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} +import com.codahale.metrics.jvm.{BufferPoolMetricSet, GarbageCollectorMetricSet, MemoryUsageGaugeSet} private[spark] class JvmSource extends Source { override val sourceName = "jvm" @@ -26,4 +28,6 @@ private[spark] class JvmSource extends Source { metricRegistry.registerAll(new GarbageCollectorMetricSet) metricRegistry.registerAll(new MemoryUsageGaugeSet) + metricRegistry.registerAll( + new BufferPoolMetricSet(ManagementFactory.getPlatformMBeanServer)) } From 11c4021044f3a302449a2ea76811e73f5c99a26a Mon Sep 17 00:00:00 2001 From: Wing Yew Poon Date: Thu, 9 Nov 2017 16:20:55 -0800 Subject: [PATCH 213/712] [SPARK-22403][SS] Add optional checkpointLocation argument to StructuredKafkaWordCount example ## What changes were proposed in this pull request? When run in YARN cluster mode, the StructuredKafkaWordCount example fails because Spark tries to create a temporary checkpoint location in a subdirectory of the path given by java.io.tmpdir, and YARN sets java.io.tmpdir to a path in the local filesystem that usually does not correspond to an existing path in the distributed filesystem. Add an optional checkpointLocation argument to the StructuredKafkaWordCount example so that users can specify the checkpoint location and avoid this issue. ## How was this patch tested? Built and ran the example manually in YARN client and cluster modes. Author: Wing Yew Poon Closes #19703 from wypoon/SPARK-22403. --- .../sql/streaming/StructuredKafkaWordCount.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala index c26f73e788814..2aab49c8891d1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala @@ -18,11 +18,14 @@ // scalastyle:off println package org.apache.spark.examples.sql.streaming +import java.util.UUID + import org.apache.spark.sql.SparkSession /** * Consumes messages from one or more topics in Kafka and does wordcount.
* Usage: StructuredKafkaWordCount <bootstrap-servers> <subscribe-type> <topics> + * [<checkpoint-location>] * <bootstrap-servers> The Kafka "bootstrap.servers" configuration. A * comma-separated list of host:port. * <subscribe-type> There are three kinds of type, i.e. 'assign', 'subscribe', @@ -36,6 +39,8 @@ import org.apache.spark.sql.SparkSession * |- Only one of "assign, "subscribe" or "subscribePattern" options can be * | specified for Kafka source. * <topics> Different value format depends on the value of 'subscribe-type'. + * <checkpoint-location> Directory in which to create checkpoints. If not + * provided, defaults to a randomized directory in /tmp. * * Example: * `$ bin/run-example \ @@ -46,11 +51,13 @@ object StructuredKafkaWordCount { def main(args: Array[String]): Unit = { if (args.length < 3) { System.err.println("Usage: StructuredKafkaWordCount <bootstrap-servers> " + - "<subscribe-type> <topics>") + "<subscribe-type> <topics> [<checkpoint-location>]") System.exit(1) } - val Array(bootstrapServers, subscribeType, topics) = args + val Array(bootstrapServers, subscribeType, topics, _*) = args + val checkpointLocation = + if (args.length > 3) args(3) else "/tmp/temporary-" + UUID.randomUUID.toString val spark = SparkSession .builder @@ -76,6 +83,7 @@ object StructuredKafkaWordCount { val query = wordCounts.writeStream .outputMode("complete") .format("console") + .option("checkpointLocation", checkpointLocation) .start() query.awaitTermination() From 64c989495a4ac7e23a22672448a622862a63486a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 9 Nov 2017 16:40:19 -0800 Subject: [PATCH 214/712] [SPARK-22485][BUILD] Use `exclude[Problem]` instead of `excludePackage` in MiMa ## What changes were proposed in this pull request? `excludePackage` is deprecated, as shown in the [following](https://github.com/lightbend/migration-manager/blob/master/core/src/main/scala/com/typesafe/tools/mima/core/Filters.scala#L33-L36), and now emits deprecation warnings. This PR uses `exclude[Problem](packageName + ".*")` instead. ```scala deprecated("Replace with ProblemFilters.exclude[Problem](\"my.package.*\")", "0.1.15") def excludePackage(packageName: String): ProblemFilter = { exclude[Problem](packageName + ".*") } ``` ## How was this patch tested? Pass the Jenkins MiMa. Author: Dongjoon Hyun Closes #19710 from dongjoon-hyun/SPARK-22485. --- project/MimaBuild.scala | 4 ++-- project/MimaExcludes.scala | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index de0655b6cb357..2ef0e7b40d940 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -44,7 +44,7 @@ object MimaBuild { // Exclude a single class def excludeClass(className: String) = Seq( - excludePackage(className), + ProblemFilters.exclude[Problem](className + ".*"), ProblemFilters.exclude[MissingClassProblem](className), ProblemFilters.exclude[MissingTypesProblem](className) ) @@ -56,7 +56,7 @@ object MimaBuild { // Exclude a Spark package, that is in the package org.apache.spark def excludeSparkPackage(packageName: String) = { - excludePackage("org.apache.spark." + packageName) + ProblemFilters.exclude[Problem]("org.apache.spark."
+ packageName + ".*") } def ignoredABIProblems(base: File, currentSparkVersion: String) = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 090375f65552e..e6f136c7c8b0a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -242,17 +242,17 @@ object MimaExcludes { // Exclude rules for 2.0.x lazy val v20excludes = { Seq( - excludePackage("org.apache.spark.rpc"), - excludePackage("org.spark-project.jetty"), - excludePackage("org.spark_project.jetty"), - excludePackage("org.apache.spark.internal"), - excludePackage("org.apache.spark.unused"), - excludePackage("org.apache.spark.unsafe"), - excludePackage("org.apache.spark.memory"), - excludePackage("org.apache.spark.util.collection.unsafe"), - excludePackage("org.apache.spark.sql.catalyst"), - excludePackage("org.apache.spark.sql.execution"), - excludePackage("org.apache.spark.sql.internal"), + ProblemFilters.exclude[Problem]("org.apache.spark.rpc.*"), + ProblemFilters.exclude[Problem]("org.spark-project.jetty.*"), + ProblemFilters.exclude[Problem]("org.spark_project.jetty.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.internal.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unused.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.memory.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.util.collection.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), ProblemFilters.exclude[MissingMethodProblem]( From f5fe63f7b8546b0102d7bfaf3dde77379f58a4d1 Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Thu, 9 Nov 2017 16:42:33 -0800 Subject: [PATCH 215/712] =?UTF-8?q?[SPARK-22287][MESOS]=20SPARK=5FDAEMON?= =?UTF-8?q?=5FMEMORY=20not=20honored=20by=20MesosClusterD=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ispatcher ## What changes were proposed in this pull request? Allow JVM max heap size to be controlled for MesosClusterDispatcher via SPARK_DAEMON_MEMORY environment variable. ## How was this patch tested? Tested on local Mesos cluster Author: Paul Mackles Closes #19515 from pmackles/SPARK-22287. 
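As a rough usage sketch (the 2g heap size, the ZooKeeper URL, and the restart commands below are illustrative assumptions, not taken from the patch itself), the dispatcher heap can now be sized the same way as the other Spark daemons:

```sh
# Sketch only: 2g and the Mesos master URL are placeholder values.
# Set the daemon heap in conf/spark-env.sh (or in the environment of the user
# that starts the dispatcher).
export SPARK_DAEMON_MEMORY=2g

# Restart the dispatcher so the launcher rebuilds the JVM command with the new heap size.
sbin/stop-mesos-dispatcher.sh
sbin/start-mesos-dispatcher.sh --master mesos://zk://zk1:2181/mesos
```

With this change, `SparkClassCommandBuilder` uses `SPARK_DAEMON_MEMORY` as the memory key for `MesosClusterDispatcher`, so the exported value drives the dispatcher's max heap rather than the launcher default.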
--- .../java/org/apache/spark/launcher/SparkClassCommandBuilder.java | 1 + 1 file changed, 1 insertion(+) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 32724acdc362c..fd056bb90e0c4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -81,6 +81,7 @@ public List<String> buildCommand(Map<String, String> env) case "org.apache.spark.deploy.mesos.MesosClusterDispatcher": javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); extraClassPath = getenv("SPARK_DAEMON_CLASSPATH"); + memKey = "SPARK_DAEMON_MEMORY"; break; case "org.apache.spark.deploy.ExternalShuffleService": case "org.apache.spark.deploy.mesos.MesosExternalShuffleService": From b57ed2245c705fb0964462cf4492b809ade836c6 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Thu, 9 Nov 2017 19:11:30 -0800 Subject: [PATCH 216/712] [SPARK-22308][TEST-MAVEN] Support alternative unit testing styles in external applications Continuation of PR#19528 (https://github.com/apache/spark/pull/19529#issuecomment-340252119). The problem with the Maven build in the previous PR was the new tests: the creation of a Spark session outside the tests meant there was more than one Spark session around at a time. I was using the Spark session outside the tests so that the tests could share data; I've changed it so that each test creates the data anew. Author: Nathan Kronenfeld Author: Nathan Kronenfeld Closes #19705 from nkronenfeld/alternative-style-tests-2. --- .../org/apache/spark/SharedSparkContext.scala | 17 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../spark/sql/test/GenericFlatSpecSuite.scala | 47 +++++ .../spark/sql/test/GenericFunSpecSuite.scala | 49 +++++ .../spark/sql/test/GenericWordSpecSuite.scala | 53 ++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 173 ++++++++++-------- .../spark/sql/test/SharedSQLContext.scala | 84 +-------- .../spark/sql/test/SharedSparkSession.scala | 119 ++++++++++++ 8 files changed, 387 insertions(+), 165 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 6aedcb1271ff6..1aa1c421d792e 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) + /** + * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in + * test using styles other than FunSuite, there is often code that relies on the session between + * test group constructs and the actual tests, which may need this session. It is purely a + * semantic difference, but semantically, it makes more sense to call 'initializeContext' between + * a 'describe' and an 'it' call than it does to call 'beforeAll'.
+ */ + protected def initializeContext(): Unit = { + if (null == _sc) { + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + } + override def beforeAll() { super.beforeAll() - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + initializeContext() } override def afterAll() { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 10bdfafd6f933..82c5307d54360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import org.scalatest.Suite + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PlanTestBase + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. + */ +trait PlanTestBase extends PredicateHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala new file mode 100644 index 0000000000000..14ac479e89754 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.test + +import org.scalatest.FlatSpec + +import org.apache.spark.sql.Dataset + +/** + * The purpose of this suite is to make sure that generic FlatSpec-based scala + * tests work with a shared spark session + */ +class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { + import testImplicits._ + + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" should "have the specified number of elements" in { + assert(8 === ds.count) + } + it should "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + it should "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it should "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala new file mode 100644 index 0000000000000..e8971e36d112d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FunSpec + +import org.apache.spark.sql.Dataset + +/** + * The purpose of this suite is to make sure that generic FunSpec-based scala + * tests work with a shared spark session + */ +class GenericFunSpecSuite extends FunSpec with SharedSparkSession { + import testImplicits._ + + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + describe("Simple Dataset") { + it("should have the specified number of elements") { + assert(8 === ds.count) + } + it("should have the specified number of unique elements") { + assert(8 === ds.distinct.count) + } + it("should have the specified number of elements in each column") { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it("should have the correct number of distinct elements in each column") { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala new file mode 100644 index 0000000000000..44655a5345ca4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.WordSpec + +import org.apache.spark.sql.Dataset + +/** + * The purpose of this suite is to make sure that generic WordSpec-based scala + * tests work with a shared spark session + */ +class GenericWordSpecSuite extends WordSpec with SharedSparkSession { + import testImplicits._ + + private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" when { + "looked at as complete rows" should { + "have the specified number of elements" in { + assert(8 === ds.count) + } + "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + } + "refined to specific columns" should { + "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index a14a1441a4313..b4248b74f50ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.Utils /** - * Helper trait that should be extended by all SQL test suites. + * Helper trait that should be extended by all SQL test suites within the Spark + * code base. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. 
@@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils - extends SparkFunSuite with Eventually +private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Materialize the test data immediately after the `SQLContext` is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point to wait for the thread to termniate, and + // we rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } +} + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually with BeforeAndAfterAll with SQLTestData - with PlanTest { self => + with PlanTestBase { self: Suite => protected def sparkContext = spark.sparkContext - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils protected override def _sqlContext: SQLContext = self.spark.sqlContext } - /** - * Materialize the test data immediately after the `SQLContext` is set up. 
- * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils Dataset.ofRows(spark, plan) } - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. - */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index cd8d0708d8a32..4d578e21f5494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,86 +17,4 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - -import org.scalatest.BeforeAndAfterEach -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. 
- */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. - */ - protected override def beforeAll(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala new file mode 100644 index 0000000000000..e0568a3c5c99f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSparkSession + extends SQLTestUtilsBase + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. 
+ * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in tests using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} From 0025ddeb1dd4fd6951ecd8456457f6b94124f84e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Nov 2017 21:56:20 -0800 Subject: [PATCH 217/712] [SPARK-22472][SQL] add null check for top-level primitive values ## What changes were proposed in this pull request? One powerful feature of `Dataset` is that we can easily map SQL rows to Scala/Java objects and do runtime null checks automatically. For example, let's say we have a parquet file whose schema has an int column `a` and a string column `b`, and we have a `case class Data(a: Int, b: String)`. Users can easily read this parquet file into `Data` objects, and Spark will throw an NPE if column `a` has null values. However, the null checking is missing for top-level primitive values. For example, let's say we have a parquet file with a single int column `a`, and we read it into a Scala `Int`. If column `a` has null values, we will get some weird results. ``` scala> val ds = spark.read.parquet(...).as[Int] scala> ds.show() +----+ |v | +----+ |null| |1 | +----+ scala> ds.collect res0: Array[Long] = Array(0, 1) scala> ds.map(_ * 2).show +-----+ |value| +-----+ |-2 | |2 | +-----+ ``` This is because internally Spark uses some special default values for primitive types, but never expects users to see or operate on these default values directly. This PR adds a null check for top-level primitive values. ## How was this patch tested? new test Author: Wenchen Fan Closes #19707 from cloud-fan/bug.
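A minimal sketch of the behaviour after this change (illustrative only; it assumes a local `SparkSession` bound to a stable value named `spark` and mirrors the test added to DatasetSuite below):
```
import org.apache.spark.SparkException
import spark.implicits._

val ds = Seq(Some(1), None).toDS().as[Int] // a None row has no valid primitive representation

// Before this change, collect() quietly returned the encoder's default value (0) for the
// missing row; with the new top-level null check the decode fails fast instead.
try ds.collect() catch {
  case _: NullPointerException => println("null check triggered on the driver")
}

// When the null is hit inside a task, the NPE surfaces as the cause of a SparkException.
try ds.map(_ * 2).collect() catch {
  case e: SparkException if e.getCause.isInstanceOf[NullPointerException] =>
    println("null check triggered inside a task")
}
```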
--- .../spark/sql/catalyst/ScalaReflection.scala | 8 +++++++- .../sql/catalyst/ScalaReflectionSuite.scala | 7 ++++++- .../org/apache/spark/sql/DatasetSuite.scala | 18 ++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f62553ddd3971..4e47a5890db9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -134,7 +134,13 @@ object ScalaReflection extends ScalaReflection { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - deserializerFor(tpe, None, walkedTypePath) + val expr = deserializerFor(tpe, None, walkedTypePath) + val Schema(_, nullable) = schemaFor(tpe) + if (nullable) { + expr + } else { + AssertNotNull(expr, walkedTypePath) + } } private def deserializerFor( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index f77af5db3279b..23e866cdf4917 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -351,4 +351,9 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(argumentsFields(0) == Seq("field.1")) assert(argumentsFields(1) == Seq("field 2")) } + + test("SPARK-22472: add null check for top-level primitive values") { + assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c67165c7abca6..6e13a5d491e0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide @@ -1408,6 +1409,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(ds, SpecialCharClass("1", "2")) } } + + test("SPARK-22472: add null check for top-level primitive values") { + // If the primitive values are from Option, we need to do runtime null check. 
+ val ds = Seq(Some(1), None).toDS().as[Int] + intercept[NullPointerException](ds.collect()) + val e = intercept[SparkException](ds.map(_ * 2).collect()) + assert(e.getCause.isInstanceOf[NullPointerException]) + + withTempPath { path => + Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath) + // If the primitive values are from files, we need to do runtime null check. + val ds = spark.read.parquet(path.getCanonicalPath).as[Int] + intercept[NullPointerException](ds.collect()) + val e = intercept[SparkException](ds.map(_ * 2).collect()) + assert(e.getCause.isInstanceOf[NullPointerException]) + } + } } case class SingleData(id: Int) From 28ab5bf59766096bb7e1bdab32b05cb5f0799a48 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 10 Nov 2017 12:01:02 +0100 Subject: [PATCH 218/712] [SPARK-22487][SQL][HIVE] Remove the unused HIVE_EXECUTION_VERSION property ## What changes were proposed in this pull request? At the beginning, https://github.com/apache/spark/pull/2843 added `spark.sql.hive.version` to reveal the underlying Hive version for JDBC connections. For some time afterwards, it was used as a version identifier for the execution Hive client. Actually, there is no Hive client for execution in Spark now, and there are no usages of HIVE_EXECUTION_VERSION left in the whole Spark project. HIVE_EXECUTION_VERSION is set by `spark.sql.hive.version`, which is still set internally in some places or by users; this is easily confused with HIVE_METASTORE_VERSION (spark.sql.hive.metastore.version). It is better to remove it. ## How was this patch tested? Modified some existing unit tests. cc cloud-fan gatorsmile Author: Kent Yao Closes #19712 from yaooqinn/SPARK-22487. --- .../sql/hive/thriftserver/SparkSQLEnv.scala | 1 - .../thriftserver/SparkSQLSessionManager.scala | 1 - .../HiveThriftServer2Suites.scala | 27 +++++-------------- .../org/apache/spark/sql/hive/HiveUtils.scala | 25 +++++++---------- .../spark/sql/hive/client/VersionsSuite.scala | 4 +-- 5 files changed, 19 insertions(+), 39 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 01c4eb131a564..5db93b26f550e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -55,7 +55,6 @@ private[hive] object SparkSQLEnv extends Logging { metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - sparkSession.conf.set("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 7adaafe5ad5c1..00920c297d493 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -77,7 +77,6 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: } else { sqlContext.newSession() } - ctx.setConf("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) if (sessionConf != null &&
sessionConf.containsKey("use:database")) { ctx.sql(s"use ${sessionConf.get("use:database")}") } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 4997d7f96afa2..b80596f55bdea 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -155,10 +155,10 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.version") - assert(resultSet.getString(2) === HiveUtils.hiveExecutionVersion) + assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } @@ -521,20 +521,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { conf += resultSet.getString(1) -> resultSet.getString(2) } - assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) - } - } - - test("Checks Hive version via SET") { - withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET") - - val conf = mutable.Map.empty[String, String] - while (resultSet.next()) { - conf += resultSet.getString(1) -> resultSet.getString(2) - } - - assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + assert(conf.get("spark.sql.hive.metastore.version") === Some("1.2.1")) } } @@ -721,10 +708,10 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.version") - assert(resultSet.getString(2) === HiveUtils.hiveExecutionVersion) + assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 80b9a3dc9605d..d8e08f1f6df50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -58,28 +58,23 @@ private[spark] object HiveUtils extends Logging { } /** The version of hive used internally by Spark SQL. */ - val hiveExecutionVersion: String = "1.2.1" + val builtinHiveVersion: String = "1.2.1" val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. 
Available options are " + - s"0.12.0 through $hiveExecutionVersion.") + s"0.12.0 through 2.1.1.") .stringConf - .createWithDefault(hiveExecutionVersion) - - val HIVE_EXECUTION_VERSION = buildConf("spark.sql.hive.version") - .doc("Version of Hive used internally by Spark SQL.") - .stringConf - .createWithDefault(hiveExecutionVersion) + .createWithDefault(builtinHiveVersion) val HIVE_METASTORE_JARS = buildConf("spark.sql.hive.metastore.jars") .doc(s""" | Location of the jars that should be used to instantiate the HiveMetastoreClient. | This property can be one of three options: " | 1. "builtin" - | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly when + | Use Hive ${builtinHiveVersion}, which is bundled with the Spark assembly when | -Phive is enabled. When this option is chosen, | spark.sql.hive.metastore.version must be either - | ${hiveExecutionVersion} or not defined. + | ${builtinHiveVersion} or not defined. | 2. "maven" | Use Hive jars of specified version downloaded from Maven repositories. | 3. A classpath in the standard format for both Hive and Hadoop. @@ -259,9 +254,9 @@ private[spark] object HiveUtils extends Logging { protected[hive] def newClientForExecution( conf: SparkConf, hadoopConf: Configuration): HiveClientImpl = { - logInfo(s"Initializing execution hive, version $hiveExecutionVersion") + logInfo(s"Initializing execution hive, version $builtinHiveVersion") val loader = new IsolatedClientLoader( - version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), + version = IsolatedClientLoader.hiveVersion(builtinHiveVersion), sparkConf = conf, execJars = Seq.empty, hadoopConf = hadoopConf, @@ -297,12 +292,12 @@ private[spark] object HiveUtils extends Logging { val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) val isolatedLoader = if (hiveMetastoreJars == "builtin") { - if (hiveExecutionVersion != hiveMetastoreVersion) { + if (builtinHiveVersion != hiveMetastoreVersion) { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + - s"Execution: $hiveExecutionVersion != Metastore: $hiveMetastoreVersion. " + + s"Execution: $builtinHiveVersion != Metastore: $hiveMetastoreVersion. 
" + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") + s"or change ${HIVE_METASTORE_VERSION.key} to $builtinHiveVersion.") } // We recursively find all jars in the class loader chain, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index edb9a9ffbaaf6..9ed39cc80f50d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -73,7 +73,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) + val badClient = buildClient(HiveUtils.builtinHiveVersion, new Configuration()) val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) badClient.createDatabase(db, ignoreIfExists = true) } @@ -81,7 +81,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test("hadoop configuration preserved") { val hadoopConf = new Configuration() hadoopConf.set("test", "success") - val client = buildClient(HiveUtils.hiveExecutionVersion, hadoopConf) + val client = buildClient(HiveUtils.builtinHiveVersion, hadoopConf) assert("success" === client.getConf("test", null)) } From 9b9827759af2ca3eea146a6032f9165f640ce152 Mon Sep 17 00:00:00 2001 From: Pralabh Kumar Date: Fri, 10 Nov 2017 13:17:25 +0200 Subject: [PATCH 219/712] [SPARK-20199][ML] : Provided featureSubsetStrategy to GBTClassifier and GBTRegressor ## What changes were proposed in this pull request? (Provided featureSubset Strategy to GBTClassifier a) Moved featureSubsetStrategy to TreeEnsembleParams b) Changed GBTClassifier to pass featureSubsetStrategy val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)) ## How was this patch tested? a) Tested GradientBoostedTreeClassifierExample by adding .setFeatureSubsetStrategy with GBTClassifier b)Added test cases in GBTClassifierSuite and GBTRegressorSuite Author: Pralabh Kumar Closes #18118 from pralabhkumar/develop. 
--- ...GradientBoostedTreeClassifierExample.scala | 1 + .../ml/classification/GBTClassifier.scala | 9 +- .../RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 8 +- .../spark/ml/regression/GBTRegressor.scala | 9 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../ml/tree/impl/DecisionTreeMetadata.scala | 4 +- .../ml/tree/impl/GradientBoostedTrees.scala | 25 ++++-- .../org/apache/spark/ml/tree/treeParams.scala | 82 ++++++++++--------- .../mllib/tree/GradientBoostedTrees.scala | 4 +- .../spark/mllib/tree/RandomForest.scala | 2 +- .../classification/GBTClassifierSuite.scala | 29 +++++++ .../ml/regression/GBTRegressorSuite.scala | 29 +++++++ .../tree/impl/GradientBoostedTreesSuite.scala | 4 +- 14 files changed, 146 insertions(+), 64 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 9a39acfbf37e5..3656773c8b817 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -59,6 +59,7 @@ object GradientBoostedTreeClassifierExample { .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures") .setMaxIter(10) + .setFeatureSubsetStrategy("auto") // Convert indexed labels back to original labels. val labelConverter = new IndexToString() diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3da809ce5f77c..f11bc1d8fe415 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -135,6 +135,11 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setStepSize(value: Double): this.type = set(stepSize, value) + /** @group setParam */ + @Since("2.3.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + // Parameters from GBTClassifierParams: /** @group setParam */ @@ -167,12 +172,12 @@ class GBTClassifier @Since("1.4.0") ( val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed)) + $(seed), $(featureSubsetStrategy)) val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ab4c235209289..78a4972adbdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -158,7 +158,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi /** Accessor for supported featureSubsetStrategy 
settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = - RandomForestParams.supportedFeatureSubsetStrategies + TreeEnsembleParams.supportedFeatureSubsetStrategies @Since("2.0.0") override def load(path: String): RandomForestClassifier = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 01c5cc1c7efa9..0291a57487c47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,12 +117,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S } /** (private[ml]) Train a decision tree on an RDD */ - private[ml] def train(data: RDD[LabeledPoint], - oldStrategy: OldStrategy): DecisionTreeRegressionModel = { + private[ml] def train( + data: RDD[LabeledPoint], + oldStrategy: OldStrategy, + featureSubsetStrategy: String): DecisionTreeRegressionModel = { val instr = Instrumentation.create(this, data) instr.logParams(params: _*) - val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 08d175cb94442..f41d15b62dddd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -140,6 +140,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.3.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -150,11 +155,11 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed)) + $(seed), $(featureSubsetStrategy)) val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a58da50fad972..200b234b79978 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -149,7 +149,7 @@ object RandomForestRegressor extends 
DefaultParamsReadable[RandomForestRegressor /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = - RandomForestParams.supportedFeatureSubsetStrategies + TreeEnsembleParams.supportedFeatureSubsetStrategies @Since("2.0.0") override def load(path: String): RandomForestRegressor = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 8a9dcb486b7bf..53189e0797b6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -22,7 +22,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.tree.RandomForestParams +import org.apache.spark.ml.tree.TreeEnsembleParams import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.Strategy @@ -200,7 +200,7 @@ private[spark] object DecisionTreeMetadata extends Logging { Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match { case Some(value) => math.ceil(value * numFeatures).toInt case _ => throw new IllegalArgumentException(s"Supported values:" + - s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}," + s" (0.0-1.0], [1-n].") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index e32447a79abb8..bd8c9afb5e209 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -42,16 +42,18 @@ private[spark] object GradientBoostedTrees extends Logging { def run( input: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, + seed, featureSubsetStrategy) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. 
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, - seed) + seed, featureSubsetStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") } @@ -73,11 +75,13 @@ private[spark] object GradientBoostedTrees extends Logging { input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, + validate = true, seed, featureSubsetStrategy) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -85,7 +89,7 @@ private[spark] object GradientBoostedTrees extends Logging { val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true, seed) + validate = true, seed, featureSubsetStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -245,7 +249,8 @@ private[spark] object GradientBoostedTrees extends Logging { validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, validate: Boolean, - seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + seed: Long, + featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -258,6 +263,7 @@ private[spark] object GradientBoostedTrees extends Logging { val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate + // Prepare strategy for individual trees, which use regression with variance impurity. 
val treeStrategy = boostingStrategy.treeStrategy.copy val validationTol = boostingStrategy.validationTol @@ -288,7 +294,7 @@ private[spark] object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") val firstTree = new DecisionTreeRegressor().setSeed(seed) - val firstTreeModel = firstTree.train(input, treeStrategy) + val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight @@ -319,8 +325,9 @@ private[spark] object GradientBoostedTrees extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") + val dt = new DecisionTreeRegressor().setSeed(seed + m) - val model = dt.train(data, treeStrategy) + val model = dt.train(data, treeStrategy, featureSubsetStrategy) timer.stop(s"building tree $m") // Update partial model baseLearners(m) = model diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 47079d9c6bb1c..81b6222acc7ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -320,6 +320,12 @@ private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams } } +private[spark] object TreeEnsembleParams { + // These options should be lowercase. + final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) +} + /** * Parameters for Decision Tree-based ensemble algorithms. * @@ -359,38 +365,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { oldImpurity: OldImpurity): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) } -} - -/** - * Parameters for Random Forest algorithms. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * - * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) - * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms - * are a bit different. - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) - - setDefault(numTrees -> 20) - - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) - - /** @group getParam */ - final def getNumTrees: Int = $(numTrees) /** * The number of features to consider for splits at each tree node. @@ -420,10 +394,10 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." 
+ - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + + s" Supported options: ${TreeEnsembleParams.supportedFeatureSubsetStrategies.mkString(", ")}" + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains( + TreeEnsembleParams.supportedFeatureSubsetStrategies.contains( value.toLowerCase(Locale.ROOT)) || Try(value.toInt).filter(_ > 0).isSuccess || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) @@ -431,7 +405,7 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(featureSubsetStrategy -> "auto") /** - * @deprecated This method is deprecated and will be removed in 3.0.0. + * @deprecated This method is deprecated and will be removed in 3.0.0 * @group setParam */ @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") @@ -441,10 +415,38 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } -private[spark] object RandomForestParams { - // These options should be lowercase. - final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) + + +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * + * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) + * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms + * are a bit different. + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) + + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group getParam */ + final def getNumTrees: Int = $(numTrees) } private[ml] trait RandomForestClassifierParams @@ -497,6 +499,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(featureSubsetStrategy -> "all") + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. 
*/ private[ml] def getOldBoostingStrategy( categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index df2c1b02f4f40..d24d8da0dab48 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -69,7 +69,7 @@ class GradientBoostedTrees private[spark] ( val algo = boostingStrategy.treeStrategy.algo val (trees, treeWeights) = NewGBT.run(input.map { point => NewLabeledPoint(point.label, point.features.asML) - }, boostingStrategy, seed.toLong) + }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } @@ -101,7 +101,7 @@ class GradientBoostedTrees private[spark] ( NewLabeledPoint(point.label, point.features.asML) }, validationInput.map { point => NewLabeledPoint(point.label, point.features.asML) - }, boostingStrategy, seed.toLong) + }, boostingStrategy, seed.toLong, "all") new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index d1331a57de27b..a8c5286f3dc10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,7 +23,7 @@ import scala.util.Try import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams} +import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams} import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 8000143d4d142..978f89c459f0a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -83,6 +83,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(gbt.getPredictionCol === "prediction") assert(gbt.getRawPredictionCol === "rawPrediction") assert(gbt.getProbabilityCol === "probability") + assert(gbt.getFeatureSubsetStrategy === "all") val df = trainData.toDF() val model = gbt.fit(df) model.transform(df) @@ -95,6 +96,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.getPredictionCol === "prediction") assert(model.getRawPredictionCol === "rawPrediction") assert(model.getProbabilityCol === "probability") + assert(model.getFeatureSubsetStrategy === "all") assert(model.hasParent) MLTestingUtils.checkCopyAndUids(gbt, model) @@ -356,6 +358,33 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(importances.toArray.forall(_ >= 0.0)) } + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature subset strategy + ///////////////////////////////////////////////////////////////////////////// + test("Tests of feature subset strategy") { 
+ val numClasses = 2 + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(3) + .setMaxIter(5) + .setFeatureSubsetStrategy("all") + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + + // GBT with different featureSubsetStrategy + val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") + val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances + val mostIF = importanceFeatures.argmax + assert(mostImportantFeature !== mostIF) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 2da25f7e0100a..ecbb57126d759 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -165,6 +165,35 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext assert(importances.toArray.forall(_ >= 0.0)) } + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature subset strategy + ///////////////////////////////////////////////////////////////////////////// + test("Tests of feature subset strategy") { + val numClasses = 2 + val gbt = new GBTRegressor() + .setMaxDepth(3) + .setMaxIter(5) + .setSeed(123) + .setFeatureSubsetStrategy("all") + + // In this data, feature 1 is very important. 
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+
+    // GBT with different featureSubsetStrategy
+    val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
+    val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
+    val mostIF = importanceFeatures.argmax
+    assert(mostImportantFeature !== mostIF)
+  }
+
+
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
index 4109a299091dc..366d5ec3a53fb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -50,12 +50,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
     val boostingStrategy =
       new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
     val (validateTrees, validateTreeWeights) = GradientBoostedTrees
-      .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
+      .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all")
     val numTrees = validateTrees.length
     assert(numTrees !== numIterations)

     // Test that it performs better on the validation dataset.
-    val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
+    val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all")
     val (errorWithoutValidation, errorWithValidation) = {
       if (algo == Classification) {
         val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))

From 1c923d7d65dd94996f0fe2cf9851a1ae738c5c0c Mon Sep 17 00:00:00 2001
From: Xianyang Liu
Date: Fri, 10 Nov 2017 12:43:29 +0100
Subject: [PATCH 220/712] [SPARK-22450][CORE][MLLIB] safely register class for mllib

## What changes were proposed in this pull request?

There are still some algorithms based on mllib, such as KMeans. For now, many common mllib classes
(such as Vector, DenseVector, SparseVector, Matrix, DenseMatrix and SparseMatrix) are not registered
in Kryo, so serializing or deserializing those objects carries avoidable overhead.

Previously discussed: https://github.com/apache/spark/pull/19586

## How was this patch tested?

New test case.

Author: Xianyang Liu

Closes #19661 from ConeyLiu/register_vector.
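A minimal, self-contained sketch of what this registration enables (illustrative only, not part of the patch): with `spark.kryo.registrationRequired` enabled, Kryo refuses to serialize unregistered classes, so round-tripping an mllib `Vector` only succeeds once these classes are registered inside `KryoSerializer`, which is what this change does. The object name and values below are made up.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.serializer.KryoSerializer

object KryoVectorRoundTrip {
  def main(args: Array[String]): Unit = {
    // Fail fast on any class that has not been registered with Kryo.
    val conf = new SparkConf(false)
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrationRequired", "true")

    val ser = new KryoSerializer(conf).newInstance()
    val v: Vector = Vectors.dense(1.0, 2.0, 3.0)

    // Round-trip through Kryo: this throws an unregistered-class error before the
    // patch and succeeds once the mllib linalg classes are registered.
    assert(ser.deserialize[Vector](ser.serialize(v)) == v)
  }
}
```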
--- .../spark/serializer/KryoSerializer.scala | 26 ++++++++++ .../spark/ml/feature/InstanceSuit.scala | 47 +++++++++++++++++++ .../spark/mllib/linalg/MatricesSuite.scala | 28 ++++++++++- .../spark/mllib/linalg/VectorsSuite.scala | 19 +++++++- 4 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 58483c9577d29..2259d1a2d555d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -25,6 +25,7 @@ import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import scala.util.control.NonFatal import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} @@ -178,6 +179,31 @@ class KryoSerializer(conf: SparkConf) kryo.register(Utils.classForName("scala.collection.immutable.Map$EmptyMap$")) kryo.register(classOf[ArrayBuffer[Any]]) + // We can't load those class directly in order to avoid unnecessary jar dependencies. + // We load them safely, ignore it if the class not found. + Seq("org.apache.spark.mllib.linalg.Vector", + "org.apache.spark.mllib.linalg.DenseVector", + "org.apache.spark.mllib.linalg.SparseVector", + "org.apache.spark.mllib.linalg.Matrix", + "org.apache.spark.mllib.linalg.DenseMatrix", + "org.apache.spark.mllib.linalg.SparseMatrix", + "org.apache.spark.ml.linalg.Vector", + "org.apache.spark.ml.linalg.DenseVector", + "org.apache.spark.ml.linalg.SparseVector", + "org.apache.spark.ml.linalg.Matrix", + "org.apache.spark.ml.linalg.DenseMatrix", + "org.apache.spark.ml.linalg.SparseMatrix", + "org.apache.spark.ml.feature.Instance", + "org.apache.spark.ml.feature.OffsetInstance" + ).foreach { name => + try { + val clazz = Utils.classForName(name) + kryo.register(clazz) + } catch { + case NonFatal(_) => // do nothing + } + } + kryo.setClassLoader(classLoader) kryo } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala new file mode 100644 index 0000000000000..88c85a9425e78 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.reflect.ClassTag + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.serializer.KryoSerializer + +class InstanceSuit extends SparkFunSuite{ + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf) + val serInstance = new KryoSerializer(conf).newInstance() + + def check[T: ClassTag](t: T) { + assert(serInstance.deserialize[T](serInstance.serialize(t)) === t) + } + + val instance1 = Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)) + val instance2 = Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse) + val oInstance1 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)) + val oInstance2 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse) + check(instance1) + check(instance2) + check(oInstance1) + check(oInstance2) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index c8ac92eecf40b..d76edb940b2bd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,16 +20,42 @@ package org.apache.spark.mllib.linalg import java.util.Random import scala.collection.mutable.{Map => MutableMap} +import scala.reflect.ClassTag import breeze.linalg.{CSCMatrix, Matrix => BM} import org.mockito.Mockito.when import org.scalatest.mockito.MockitoSugar._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.serializer.KryoSerializer class MatricesSuite extends SparkFunSuite { + test("kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + + val m = 3 + val n = 2 + val denseValues = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0) + val denseMat = Matrices.dense(m, n, denseValues).asInstanceOf[DenseMatrix] + + val sparseValues = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val sparseMat = + Matrices.sparse(m, n, colPtrs, rowIndices, sparseValues).asInstanceOf[SparseMatrix] + check(denseMat) + check(sparseMat) + } + test("dense matrix construction") { val m = 3 val n = 2 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index a1e3ee54b49ff..4074bead421e6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.mllib.linalg +import scala.reflect.ClassTag import scala.util.Random import breeze.linalg.{squaredDistance => breezeSquaredDistance, DenseMatrix => BDM} import org.json4s.jackson.JsonMethods.{parse => parseJson} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.mllib.util.TestingUtils._ +import 
org.apache.spark.serializer.KryoSerializer class VectorsSuite extends SparkFunSuite with Logging { @@ -34,6 +36,21 @@ class VectorsSuite extends SparkFunSuite with Logging { val indices = Array(0, 2, 3) val values = Array(0.1, 0.3, 0.4) + test("kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + + val desVec = Vectors.dense(arr).asInstanceOf[DenseVector] + val sparVec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector] + check(desVec) + check(sparVec) + } + test("dense vector construction with varargs") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.size === arr.length) From 5b41cbf13b6d6f47b8f8f1ffcc7a3348018627ca Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 10 Nov 2017 11:24:24 -0600 Subject: [PATCH 221/712] [SPARK-22473][TEST] Replace deprecated AsyncAssertions.Waiter and methods of java.sql.Date ## What changes were proposed in this pull request? In `spark-sql` module tests there are deprecations warnings caused by the usage of deprecated methods of `java.sql.Date` and the usage of the deprecated `AsyncAssertions.Waiter` class. This PR replace the deprecated methods of `java.sql.Date` with non-deprecated ones (using `Calendar` where needed). It replaces also the deprecated `org.scalatest.concurrent.AsyncAssertions.Waiter` with `org.scalatest.concurrent.Waiters._`. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #19696 from mgaido91/SPARK-22473. --- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 4 ++-- .../sql/execution/streaming/HDFSMetadataLogSuite.scala | 4 ++-- .../spark/sql/streaming/EventTimeWatermarkSuite.scala | 8 ++++++-- .../spark/sql/streaming/StreamingQueryListenerSuite.scala | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6e13a5d491e0f..b02db7721aa7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1274,8 +1274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(spark.range(1).map { x => scala.math.BigDecimal(1, 18) }.head == scala.math.BigDecimal(1, 18)) - assert(spark.range(1).map { x => new java.sql.Date(2016, 12, 12) }.head == - new java.sql.Date(2016, 12, 12)) + assert(spark.range(1).map { x => java.sql.Date.valueOf("2016-12-12") }.head == + java.sql.Date.valueOf("2016-12-12")) assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == new java.sql.Timestamp(100000)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 48e70e48b1799..4677769c12a35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -26,10 +26,10 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.scalatest.concurrent.AsyncAssertions._ +import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkFunSuite} +import 
org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index f3e8cf950a5a4..47bc452bda0d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.{util => ju} import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Calendar, Date} import org.scalatest.{BeforeAndAfter, Matchers} @@ -218,7 +218,11 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .agg(count("*") as 'count) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - def monthsSinceEpoch(date: Date): Int = { date.getYear * 12 + date.getMonth } + def monthsSinceEpoch(date: Date): Int = { + val cal = Calendar.getInstance() + cal.setTime(date) + cal.get(Calendar.YEAR) * 12 + cal.get(Calendar.MONTH) + } testStream(aggWithWatermark)( AddData(input, currentTimeMs / 1000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 1fe639fcf2840..9ff02dee288fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -25,8 +25,8 @@ import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.PrivateMethodTester._ -import org.scalatest.concurrent.AsyncAssertions.Waiter import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.concurrent.Waiters.Waiter import org.apache.spark.SparkException import org.apache.spark.scheduler._ From b70aa9e08b4476746e912c2c2a8b7bdd102305e8 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 10 Nov 2017 10:22:42 -0800 Subject: [PATCH 222/712] [SPARK-22344][SPARKR] clean up install dir if running test as source package ## What changes were proposed in this pull request? remove spark if spark downloaded & installed ## How was this patch tested? manually by building package Jenkins, AppVeyor Author: Felix Cheung Closes #19657 from felixcheung/rinstalldir. --- R/pkg/R/install.R | 37 +++++++++++++++++++++++++++- R/pkg/R/utils.R | 13 ++++++++++ R/pkg/tests/fulltests/test_utils.R | 25 +++++++++++++++++++ R/pkg/tests/run-all.R | 4 ++- R/pkg/vignettes/sparkr-vignettes.Rmd | 6 ++++- 5 files changed, 82 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 492dee68e164d..04dc7562e5346 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -152,6 +152,11 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, }) if (!tarExists || overwrite || !success) { unlink(packageLocalPath) + if (success) { + # if tar file was not there before (or it was, but we are told to overwrite it), + # and untar is successful - set a flag that we have downloaded (and untar) Spark package. 
+ assign(".sparkDownloaded", TRUE, envir = .sparkREnv) + } } if (!success) stop("Extract archive failed.") message("DONE.") @@ -266,6 +271,7 @@ hadoopVersionName <- function(hadoopVersion) { # The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and # adapt to Spark context +# see also sparkCacheRelPathLength() sparkCachePath <- function() { if (is_windows()) { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) @@ -282,7 +288,7 @@ sparkCachePath <- function() { } } else if (.Platform$OS.type == "unix") { if (Sys.info()["sysname"] == "Darwin") { - path <- file.path(Sys.getenv("HOME"), "Library/Caches", "spark") + path <- file.path(Sys.getenv("HOME"), "Library", "Caches", "spark") } else { path <- file.path( Sys.getenv("XDG_CACHE_HOME", file.path(Sys.getenv("HOME"), ".cache")), "spark") @@ -293,6 +299,16 @@ sparkCachePath <- function() { normalizePath(path, mustWork = FALSE) } +# Length of the Spark cache specific relative path segments for each platform +# eg. "Apache\Spark\Cache" is 3 in Windows, or "spark" is 1 in unix +# Must match sparkCachePath() exactly. +sparkCacheRelPathLength <- function() { + if (is_windows()) { + 3 + } else { + 1 + } +} installInstruction <- function(mode) { if (mode == "remote") { @@ -310,3 +326,22 @@ installInstruction <- function(mode) { stop(paste0("No instruction found for ", mode, " mode.")) } } + +uninstallDownloadedSpark <- function() { + # clean up if Spark was downloaded + sparkDownloaded <- getOne(".sparkDownloaded", + envir = .sparkREnv, + inherits = TRUE, + ifnotfound = FALSE) + sparkDownloadedDir <- Sys.getenv("SPARK_HOME") + if (sparkDownloaded && nchar(sparkDownloadedDir) > 0) { + unlink(sparkDownloadedDir, recursive = TRUE, force = TRUE) + + dirs <- traverseParentDirs(sparkCachePath(), sparkCacheRelPathLength()) + lapply(dirs, function(d) { + if (length(list.files(d, all.files = TRUE, include.dirs = TRUE, no.. = TRUE)) == 0) { + unlink(d, recursive = TRUE, force = TRUE) + } + }) + } +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fa4099231ca8d..164cd6d01a347 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -910,3 +910,16 @@ hadoop_home_set <- function() { windows_with_hadoop <- function() { !is_windows() || hadoop_home_set() } + +# get0 not supported before R 3.2.0 +getOne <- function(x, envir, inherits = TRUE, ifnotfound = NULL) { + mget(x[1L], envir = envir, inherits = inherits, ifnotfound = list(ifnotfound))[[1L]] +} + +# Returns a vector of parent directories, traversing up count times, starting with a full path +# eg. 
traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1) should return +# this "/Users/user/Library/Caches/spark/spark2.2" +# and "/Users/user/Library/Caches/spark" +traverseParentDirs <- function(x, count) { + if (dirname(x) == x || count <= 0) x else c(x, Recall(dirname(x), count - 1)) +} diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index fb394b8069c1c..f0292ab335592 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -228,4 +228,29 @@ test_that("basenameSansExtFromUrl", { expect_equal(basenameSansExtFromUrl(z), "spark-2.1.0--hive") }) +test_that("getOne", { + dummy <- getOne(".dummyValue", envir = new.env(), ifnotfound = FALSE) + expect_equal(dummy, FALSE) +}) + +test_that("traverseParentDirs", { + if (is_windows()) { + # original path is included as-is, otherwise dirname() replaces \\ with / on windows + dirs <- traverseParentDirs("c:\\Users\\user\\AppData\\Local\\Apache\\Spark\\Cache\\spark2.2", 3) + expect <- c("c:\\Users\\user\\AppData\\Local\\Apache\\Spark\\Cache\\spark2.2", + "c:/Users/user/AppData/Local/Apache/Spark/Cache", + "c:/Users/user/AppData/Local/Apache/Spark", + "c:/Users/user/AppData/Local/Apache") + expect_equal(dirs, expect) + } else { + dirs <- traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1) + expect <- c("/Users/user/Library/Caches/spark/spark2.2", "/Users/user/Library/Caches/spark") + expect_equal(dirs, expect) + + dirs <- traverseParentDirs("/home/u/.cache/spark/spark2.2", 1) + expect <- c("/home/u/.cache/spark/spark2.2", "/home/u/.cache/spark") + expect_equal(dirs, expect) + } +}) + sparkR.session.stop() diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index a7f913e5fad11..63812ba70bb50 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -46,7 +46,7 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { tmpDir <- tempdir() tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, - spark.executor.extraJavaOptions = tmpArg) + spark.executor.extraJavaOptions = tmpArg) } test_package("SparkR") @@ -60,3 +60,5 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { NULL, "summary") } + +SparkR:::uninstallDownloadedSpark() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 907bbb3d66018..8c4ea2f2db188 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -37,7 +37,7 @@ opts_hooks$set(eval = function(options) { options }) r_tmp_dir <- tempdir() -tmp_arg <- paste("-Djava.io.tmpdir=", r_tmp_dir, sep = "") +tmp_arg <- paste0("-Djava.io.tmpdir=", r_tmp_dir) sparkSessionConfig <- list(spark.driver.extraJavaOptions = tmp_arg, spark.executor.extraJavaOptions = tmp_arg) old_java_opt <- Sys.getenv("_JAVA_OPTIONS") @@ -1183,3 +1183,7 @@ env | map ```{r, echo=FALSE} sparkR.session.stop() ``` + +```{r cleanup, include=FALSE} +SparkR:::uninstallDownloadedSpark() +``` From 5ebdcd185f2108a90e37a1aa4214c3b6c69a97a4 Mon Sep 17 00:00:00 2001 From: Santiago Saavedra Date: Fri, 10 Nov 2017 10:57:58 -0800 Subject: [PATCH 223/712] [SPARK-22294][DEPLOY] Reset spark.driver.bindAddress when starting a Checkpoint ## What changes were proposed in this pull request? It seems that recovering from a checkpoint can replace the old driver and executor IP addresses, as the workload can now be taking place in a different cluster configuration. It follows that the bindAddress for the master may also have changed. 
Thus we should not be keeping the old one, and instead be added to the list of properties to reset and recreate from the new environment. ## How was this patch tested? This patch was tested via manual testing on AWS, using the experimental (not yet merged) Kubernetes scheduler, which uses bindAddress to bind to a Kubernetes service (and thus was how I first encountered the bug too), but it is not a code-path related to the scheduler and this may have slipped through when merging SPARK-4563. Author: Santiago Saavedra Closes #19427 from ssaavedra/fix-checkpointing-master. --- .../src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 3cfbcedd519d6..aed67a5027433 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -51,6 +51,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.app.id", "spark.yarn.app.attemptId", "spark.driver.host", + "spark.driver.bindAddress", "spark.driver.port", "spark.master", "spark.yarn.jars", @@ -64,6 +65,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") + .remove("spark.driver.bindAddress") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From 24ea781cd30fbc611c714b5c6931f7d993f3a08d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 10 Nov 2017 11:27:28 -0800 Subject: [PATCH 224/712] [SPARK-19644][SQL] Clean up Scala reflection garbage after creating Encoder ## What changes were proposed in this pull request? Because of the memory leak issue in `scala.reflect.api.Types.TypeApi.<:<` (https://github.com/scala/bug/issues/8302), creating an encoder may leak memory. This PR adds `cleanUpReflectionObjects` to clean up these leaking objects for methods calling `scala.reflect.api.Types.TypeApi.<:<`. ## How was this patch tested? The updated unit tests. Author: Shixiong Zhu Closes #19687 from zsxwing/SPARK-19644. --- .../spark/sql/catalyst/ScalaReflection.scala | 25 +++++++++---- .../encoders/ExpressionEncoderSuite.scala | 35 +++++++++++++++++-- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4e47a5890db9c..65040f1af4b04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -61,7 +61,7 @@ object ScalaReflection extends ScalaReflection { */ def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - private def dataTypeFor(tpe: `Type`): DataType = { + private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType @@ -93,7 +93,7 @@ object ScalaReflection extends ScalaReflection { * Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. 
*/ - private def arrayClassFor(tpe: `Type`): ObjectType = { + private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects { val cls = tpe.dealias match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -146,7 +146,7 @@ object ScalaReflection extends ScalaReflection { private def deserializerFor( tpe: `Type`, path: Option[Expression], - walkedTypePath: Seq[String]): Expression = { + walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { @@ -441,7 +441,7 @@ object ScalaReflection extends ScalaReflection { inputObject: Expression, tpe: `Type`, walkedTypePath: Seq[String], - seenTypeSet: Set[`Type`] = Set.empty): Expression = { + seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { @@ -648,7 +648,7 @@ object ScalaReflection extends ScalaReflection { * Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that, * we also treat [[DefinedByConstructorParams]] as product type. */ - def optionOfProductType(tpe: `Type`): Boolean = { + def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects { tpe.dealias match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -710,7 +710,7 @@ object ScalaReflection extends ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = { + def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { tpe.dealias match { case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() @@ -780,7 +780,7 @@ object ScalaReflection extends ScalaReflection { /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ - def definedByConstructorParams(tpe: Type): Boolean = { + def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] } @@ -809,6 +809,17 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + /** + * Any codes calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to + * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to + * `scala.reflect.runtime.JavaUniverse.undoLog`. + * + * @see https://github.com/scala/bug/issues/8302 + */ + def cleanUpReflectionObjects[T](func: => T): T = { + universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func) + } + /** * Return the Scala Type for `T` in the current classloader mirror. 
* diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index bb1955a1ae242..e6d09bdae67d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ClosureCleaner case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -114,7 +115,9 @@ object ReferenceValueClass { class ExpressionEncoderSuite extends PlanTest with AnalysisTest { OuterScopes.addOuterScope(this) - implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = verifyNotLeakingReflectionObjects { + ExpressionEncoder() + } // test flat encoders encodeDecodeTest(false, "primitive boolean") @@ -370,8 +373,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { - test(s"encode/decode for $testName: $input") { + testAndVerifyNotLeakingReflectionObjects(s"encode/decode for $testName: $input") { val encoder = implicitly[ExpressionEncoder[T]] + + // Make sure encoder is serializable. + ClosureCleaner.clean((s: String) => encoder.getClass.getName) + val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolveAndBind() @@ -441,4 +448,28 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { } } } + + /** + * Verify the size of scala.reflect.runtime.JavaUniverse.undoLog before and after `func` to + * ensure we don't leak Scala reflection garbage. + * + * @see org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects + */ + private def verifyNotLeakingReflectionObjects[T](func: => T): T = { + def undoLogSize: Int = { + scala.reflect.runtime.universe + .asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.log.size + } + + val previousUndoLogSize = undoLogSize + val r = func + assert(previousUndoLogSize == undoLogSize) + r + } + + private def testAndVerifyNotLeakingReflectionObjects(testName: String)(testFun: => Any) { + test(testName) { + verifyNotLeakingReflectionObjects(testFun) + } + } } From f2da738c76810131045e6c32533a2d13526cdaf6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Nov 2017 21:17:49 +0100 Subject: [PATCH 225/712] [SPARK-22284][SQL] Fix 64KB JVM bytecode limit problem in calculating hash for nested structs ## What changes were proposed in this pull request? This PR avoids to generate a huge method for calculating a murmur3 hash for nested structs. This PR splits a huge method (e.g. `apply_4`) into multiple smaller methods. 
Sample program ``` val structOfString = new StructType().add("str", StringType) var inner = new StructType() for (_ <- 0 until 800) { inner = inner1.add("structOfString", structOfString) } var schema = new StructType() for (_ <- 0 until 50) { schema = schema.add("structOfStructOfStrings", inner) } GenerateMutableProjection.generate(Seq(Murmur3Hash(exprs, 42))) ``` Without this PR ``` /* 005 */ class SpecificMutableProjection extends org.apache.spark.sql.catalyst.expressions.codegen.BaseMutableProjection { /* 006 */ /* 007 */ private Object[] references; /* 008 */ private InternalRow mutableRow; /* 009 */ private int value; /* 010 */ private int value_0; ... /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ value = 42; /* 040 */ apply_0(i); /* 041 */ apply_1(i); /* 042 */ apply_2(i); /* 043 */ apply_3(i); /* 044 */ apply_4(i); /* 045 */ nestedClassInstance.apply_5(i); ... /* 089 */ nestedClassInstance8.apply_49(i); /* 090 */ value_0 = value; /* 091 */ /* 092 */ // copy all the results into MutableRow /* 093 */ mutableRow.setInt(0, value_0); /* 094 */ return mutableRow; /* 095 */ } /* 096 */ /* 097 */ /* 098 */ private void apply_4(InternalRow i) { /* 099 */ /* 100 */ boolean isNull5 = i.isNullAt(4); /* 101 */ InternalRow value5 = isNull5 ? null : (i.getStruct(4, 800)); /* 102 */ if (!isNull5) { /* 103 */ /* 104 */ if (!value5.isNullAt(0)) { /* 105 */ /* 106 */ final InternalRow element6400 = value5.getStruct(0, 1); /* 107 */ /* 108 */ if (!element6400.isNullAt(0)) { /* 109 */ /* 110 */ final UTF8String element6401 = element6400.getUTF8String(0); /* 111 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6401.getBaseObject(), element6401.getBaseOffset(), element6401.numBytes(), value); /* 112 */ /* 113 */ } /* 114 */ /* 115 */ /* 116 */ } /* 117 */ /* 118 */ /* 119 */ if (!value5.isNullAt(1)) { /* 120 */ /* 121 */ final InternalRow element6402 = value5.getStruct(1, 1); /* 122 */ /* 123 */ if (!element6402.isNullAt(0)) { /* 124 */ /* 125 */ final UTF8String element6403 = element6402.getUTF8String(0); /* 126 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6403.getBaseObject(), element6403.getBaseOffset(), element6403.numBytes(), value); /* 127 */ /* 128 */ } /* 128 */ } /* 129 */ /* 130 */ /* 131 */ } /* 132 */ /* 133 */ /* 134 */ if (!value5.isNullAt(2)) { /* 135 */ /* 136 */ final InternalRow element6404 = value5.getStruct(2, 1); /* 137 */ /* 138 */ if (!element6404.isNullAt(0)) { /* 139 */ /* 140 */ final UTF8String element6405 = element6404.getUTF8String(0); /* 141 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6405.getBaseObject(), element6405.getBaseOffset(), element6405.numBytes(), value); /* 142 */ /* 143 */ } /* 144 */ /* 145 */ /* 146 */ } /* 147 */ ... 
/* 12074 */ if (!value5.isNullAt(798)) { /* 12075 */ /* 12076 */ final InternalRow element7996 = value5.getStruct(798, 1); /* 12077 */ /* 12078 */ if (!element7996.isNullAt(0)) { /* 12079 */ /* 12080 */ final UTF8String element7997 = element7996.getUTF8String(0); /* 12083 */ } /* 12084 */ /* 12085 */ /* 12086 */ } /* 12087 */ /* 12088 */ /* 12089 */ if (!value5.isNullAt(799)) { /* 12090 */ /* 12091 */ final InternalRow element7998 = value5.getStruct(799, 1); /* 12092 */ /* 12093 */ if (!element7998.isNullAt(0)) { /* 12094 */ /* 12095 */ final UTF8String element7999 = element7998.getUTF8String(0); /* 12096 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element7999.getBaseObject(), element7999.getBaseOffset(), element7999.numBytes(), value); /* 12097 */ /* 12098 */ } /* 12099 */ /* 12100 */ /* 12101 */ } /* 12102 */ /* 12103 */ } /* 12104 */ /* 12105 */ } /* 12106 */ /* 12106 */ /* 12107 */ /* 12108 */ private void apply_1(InternalRow i) { ... ``` With this PR ``` /* 005 */ class SpecificMutableProjection extends org.apache.spark.sql.catalyst.expressions.codegen.BaseMutableProjection { /* 006 */ /* 007 */ private Object[] references; /* 008 */ private InternalRow mutableRow; /* 009 */ private int value; /* 010 */ private int value_0; /* 011 */ ... /* 034 */ public java.lang.Object apply(java.lang.Object _i) { /* 035 */ InternalRow i = (InternalRow) _i; /* 036 */ /* 037 */ /* 038 */ /* 039 */ value = 42; /* 040 */ nestedClassInstance11.apply50_0(i); /* 041 */ nestedClassInstance11.apply50_1(i); ... /* 088 */ nestedClassInstance11.apply50_48(i); /* 089 */ nestedClassInstance11.apply50_49(i); /* 090 */ value_0 = value; /* 091 */ /* 092 */ // copy all the results into MutableRow /* 093 */ mutableRow.setInt(0, value_0); /* 094 */ return mutableRow; /* 095 */ } /* 096 */ ... /* 37717 */ private void apply4_0(InternalRow value5, InternalRow i) { /* 37718 */ /* 37719 */ if (!value5.isNullAt(0)) { /* 37720 */ /* 37721 */ final InternalRow element6400 = value5.getStruct(0, 1); /* 37722 */ /* 37723 */ if (!element6400.isNullAt(0)) { /* 37724 */ /* 37725 */ final UTF8String element6401 = element6400.getUTF8String(0); /* 37726 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6401.getBaseObject(), element6401.getBaseOffset(), element6401.numBytes(), value); /* 37727 */ /* 37728 */ } /* 37729 */ /* 37730 */ /* 37731 */ } /* 37732 */ /* 37733 */ if (!value5.isNullAt(1)) { /* 37734 */ /* 37735 */ final InternalRow element6402 = value5.getStruct(1, 1); /* 37736 */ /* 37737 */ if (!element6402.isNullAt(0)) { /* 37738 */ /* 37739 */ final UTF8String element6403 = element6402.getUTF8String(0); /* 37740 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6403.getBaseObject(), element6403.getBaseOffset(), element6403.numBytes(), value); /* 37741 */ /* 37742 */ } /* 37743 */ /* 37744 */ /* 37745 */ } /* 37746 */ /* 37747 */ if (!value5.isNullAt(2)) { /* 37748 */ /* 37749 */ final InternalRow element6404 = value5.getStruct(2, 1); /* 37750 */ /* 37751 */ if (!element6404.isNullAt(0)) { /* 37752 */ /* 37753 */ final UTF8String element6405 = element6404.getUTF8String(0); /* 37754 */ value = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(element6405.getBaseObject(), element6405.getBaseOffset(), element6405.numBytes(), value); /* 37755 */ /* 37756 */ } /* 37757 */ /* 37758 */ /* 37759 */ } /* 37760 */ /* 37761 */ } ... 
/* 218470 */ /* 218471 */ private void apply50_4(InternalRow i) { /* 218472 */ /* 218473 */ boolean isNull5 = i.isNullAt(4); /* 218474 */ InternalRow value5 = isNull5 ? null : (i.getStruct(4, 800)); /* 218475 */ if (!isNull5) { /* 218476 */ apply4_0(value5, i); /* 218477 */ apply4_1(value5, i); /* 218478 */ apply4_2(value5, i); ... /* 218742 */ nestedClassInstance.apply4_266(value5, i); /* 218743 */ } /* 218744 */ /* 218745 */ } ``` ## How was this patch tested? Added new test to `HashExpressionsSuite` Author: Kazuaki Ishizaki Closes #19563 from kiszk/SPARK-22284. --- .../spark/sql/catalyst/expressions/hash.scala | 5 ++-- .../expressions/HashExpressionsSuite.scala | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 85a5f7fb2c6c3..eb3c49f5cf30e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -389,9 +389,10 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { - fields.zipWithIndex.map { case (field, index) => + val hashes = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) - }.mkString("\n") + } + ctx.splitExpressions(input, hashes) } @tailrec diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 59fc8eaf73d61..112a4a09728ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -639,6 +639,35 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval) } + test("SPARK-22284: Compute hash for nested structs") { + val M = 80 + val N = 10 + val L = M * N + val O = 50 + val seed = 42 + + val wideRow = new GenericInternalRow(Seq.tabulate(O)(k => + new GenericInternalRow(Seq.tabulate(M)(j => + new GenericInternalRow(Seq.tabulate(N)(i => + new GenericInternalRow(Array[Any]( + UTF8String.fromString((k * L + j * N + i).toString)))) + .toArray[Any])).toArray[Any])).toArray[Any]) + val inner = new StructType( + (0 until N).map(_ => StructField("structOfString", structOfString)).toArray) + val outer = new StructType( + (0 until M).map(_ => StructField("structOfStructOfString", inner)).toArray) + val schema = new StructType( + (0 until O).map(_ => StructField("structOfStructOfStructOfString", outer)).toArray) + val exprs = schema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val murmur3HashExpr = Murmur3Hash(exprs, 42) + val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) + + val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow) + assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval) + } + private def testHash(inputSchema: StructType): Unit = { val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get val encoder = RowEncoder(inputSchema) From 808e886b9638ab2981dac676b594f09cda9722fe Mon Sep 17 00:00:00 2001 From: 
Rekha Joshi Date: Fri, 10 Nov 2017 15:18:11 -0800 Subject: [PATCH 226/712] [SPARK-21667][STREAMING] ConsoleSink should not fail streaming query with checkpointLocation option ## What changes were proposed in this pull request? Fix to allow recovery on console , avoid checkpoint exception ## How was this patch tested? existing tests manual tests [ Replicating error and seeing no checkpoint error after fix] Author: Rekha Joshi Author: rjoshi2 Closes #19407 from rekhajoshm/SPARK-21667. --- .../apache/spark/sql/streaming/DataStreamWriter.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 14e7df672cc58..0be69b98abc8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -267,12 +267,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val (useTempCheckpointLocation, recoverFromCheckpointLocation) = - if (source == "console") { - (true, false) - } else { - (false, true) - } val dataSource = DataSource( df.sparkSession, @@ -285,8 +279,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { df, dataSource.createSink(outputMode), outputMode, - useTempCheckpointLocation = useTempCheckpointLocation, - recoverFromCheckpointLocation = recoverFromCheckpointLocation, + useTempCheckpointLocation = source == "console", + recoverFromCheckpointLocation = true, trigger = trigger) } } From 3eb315d7141d69ac040dcba498dd863b6d217775 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 11 Nov 2017 04:10:54 -0600 Subject: [PATCH 227/712] [SPARK-19759][ML] not using blas in ALSModel.predict for optimization ## What changes were proposed in this pull request? In `ALS.predict` currently we are using `blas.sdot` function to perform a dot product on two `Seq`s. It turns out that this is not the most efficient way. I used the following code to compare the implementations: ``` def time[R](block: => R): Unit = { val t0 = System.nanoTime() block val t1 = System.nanoTime() println("Elapsed time: " + (t1 - t0) + "ns") } val r = new scala.util.Random(100) val input = (1 to 500000).map(_ => (1 to 100).map(_ => r.nextFloat).toSeq) def f(a:Seq[Float], b:Seq[Float]): Float = { var r = 0.0f for(i <- 0 until a.length) { r+=a(i)*b(i) } r } import com.github.fommil.netlib.BLAS.{getInstance => blas} val b = (1 to 100).map(_ => r.nextFloat).toSeq time { input.foreach(a=>blas.sdot(100, a.toArray, 1, b.toArray, 1)) } // on average it takes 2968718815 ns time { input.foreach(a=>f(a,b)) } // on average it takes 515510185 ns ``` Thus this PR proposes the old-style for loop implementation for performance reasons. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #19685 from mgaido91/SPARK-19759. 
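To make the proposed change concrete, below is a small, self-contained sketch (illustrative, with a made-up `rank` value) of the hand-rolled dot product that replaces the `blas.sdot` call in `ALSModel.predict`: it avoids the `Seq` to `Array` copies and iterates with a plain while loop over the two factor vectors.

```scala
// Sketch of the loop-based dot product; featuresA and featuresB are user/item factors.
def dot(featuresA: Seq[Float], featuresB: Seq[Float], rank: Int): Float = {
  var acc = 0.0f
  var i = 0
  while (i < rank) {
    acc += featuresA(i) * featuresB(i)
    i += 1
  }
  acc
}

// Hypothetical usage with rank = 3:
// dot(Seq(1.0f, 2.0f, 3.0f), Seq(0.5f, 0.5f, 0.5f), 3)  // == 3.0f
```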
--- .../scala/org/apache/spark/ml/recommendation/ALS.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index a8843661c873b..81a8f50761e0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -289,9 +289,13 @@ class ALSModel private[ml] ( private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) => if (featuresA != null && featuresB != null) { - // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for - // potential optimization. - blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1) + var dotProduct = 0.0f + var i = 0 + while (i < rank) { + dotProduct += featuresA(i) * featuresB(i) + i += 1 + } + dotProduct } else { Float.NaN } From 223d83ee93e604009afea4af3d13a838d08625a4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 11 Nov 2017 19:16:31 +0900 Subject: [PATCH 228/712] [SPARK-22476][R] Add dayofweek function to R ## What changes were proposed in this pull request? This PR adds `dayofweek` to R API: ```r data <- list(list(d = as.Date("2012-12-13")), list(d = as.Date("2013-12-14")), list(d = as.Date("2014-12-15"))) df <- createDataFrame(data) collect(select(df, dayofweek(df$d))) ``` ``` dayofweek(d) 1 5 2 7 3 2 ``` ## How was this patch tested? Manual tests and unit tests in `R/pkg/tests/fulltests/test_sparkSQL.R` Author: hyukjinkwon Closes #19706 from HyukjinKwon/add-dayofweek. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 17 ++++++++++++++++- R/pkg/R/generics.R | 5 +++++ R/pkg/tests/fulltests/test_sparkSQL.R | 1 + 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3fc756b9ef40c..57838f52eac3f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -232,6 +232,7 @@ exportMethods("%<=>%", "date_sub", "datediff", "dayofmonth", + "dayofweek", "dayofyear", "decode", "dense_rank", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 0143a3e63ba61..237ef061e8071 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -696,7 +696,7 @@ setMethod("hash", #' #' \dontrun{ #' head(select(df, df$time, year(df$time), quarter(df$time), month(df$time), -#' dayofmonth(df$time), dayofyear(df$time), weekofyear(df$time))) +#' dayofmonth(df$time), dayofweek(df$time), dayofyear(df$time), weekofyear(df$time))) #' head(agg(groupBy(df, year(df$time)), count(df$y), avg(df$y))) #' head(agg(groupBy(df, month(df$time)), avg(df$y)))} #' @note dayofmonth since 1.5.0 @@ -707,6 +707,21 @@ setMethod("dayofmonth", column(jc) }) +#' @details +#' \code{dayofweek}: Extracts the day of the week as an integer from a +#' given date/timestamp/string. +#' +#' @rdname column_datetime_functions +#' @aliases dayofweek dayofweek,Column-method +#' @export +#' @note dayofweek since 2.3.0 +setMethod("dayofweek", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofweek", x@jc) + column(jc) + }) + #' @details #' \code{dayofyear}: Extracts the day of the year as an integer from a #' given date/timestamp/string. 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8312d417b99d2..8fcf269087c7d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1048,6 +1048,11 @@ setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) #' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) +#' @rdname column_datetime_functions +#' @export +#' @name NULL +setGeneric("dayofweek", function(x) { standardGeneric("dayofweek") }) + #' @rdname column_datetime_functions #' @export #' @name NULL diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a0dbd475f78e6..8a7fb124bd9db 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1699,6 +1699,7 @@ test_that("date functions on a DataFrame", { list(a = 2L, b = as.Date("2013-12-14")), list(a = 3L, b = as.Date("2014-12-15"))) df <- createDataFrame(l) + expect_equal(collect(select(df, dayofweek(df$b)))[, 1], c(5, 7, 2)) expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) From 154351e6dbd24c4254094477e3f7defcba979b1a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 11 Nov 2017 12:34:30 +0100 Subject: [PATCH 229/712] [SPARK-22462][SQL] Make rdd-based actions in Dataset trackable in SQL UI ## What changes were proposed in this pull request? For the few Dataset actions such as `foreach`, currently no SQL metrics are visible in the SQL tab of SparkUI. It is because it binds wrongly to Dataset's `QueryExecution`. As the actions directly evaluate on the RDD which has individual `QueryExecution`, to show correct SQL metrics on UI, we should bind to RDD's `QueryExecution`. ## How was this patch tested? Manually test. Screenshot is attached in the PR. Author: Liang-Chi Hsieh Closes #19689 from viirya/SPARK-22462. --- .../scala/org/apache/spark/sql/Dataset.scala | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5eb2affa0bd8f..1620ab3aa2094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2594,7 +2594,7 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def foreach(f: T => Unit): Unit = withNewExecutionId { + def foreach(f: T => Unit): Unit = withNewRDDExecutionId { rdd.foreach(f) } @@ -2613,7 +2613,7 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId { + def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId { rdd.foreachPartition(f) } @@ -2851,6 +2851,12 @@ class Dataset[T] private[sql]( */ def unpersist(): this.type = unpersist(blocking = false) + // Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`. + @transient private lazy val rddQueryExecution: QueryExecution = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + sparkSession.sessionState.executePlan(deserialized) + } + /** * Represents the content of the Dataset as an `RDD` of `T`. 
* @@ -2859,8 +2865,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](logicalPlan) - sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => + rddQueryExecution.toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } } @@ -3113,6 +3118,20 @@ class Dataset[T] private[sql]( SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } + /** + * Wrap an action of the Dataset's RDD to track all Spark jobs in the body so that we can connect + * them with an execution. Before performing the action, the metrics of the executed plan will be + * reset. + */ + private def withNewRDDExecutionId[U](body: => U): U = { + SQLExecution.withNewExecutionId(sparkSession, rddQueryExecution) { + rddQueryExecution.executedPlan.foreach { plan => + plan.resetMetrics() + } + body + } + } + /** * Wrap a Dataset action to track the QueryExecution and time cost, then report to the * user-registered callback functions. From d6ee69e7761a62e47e21e4c2ce1bb20038d745b6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 11 Nov 2017 18:20:11 +0100 Subject: [PATCH 230/712] [SPARK-22488][SQL] Fix the view resolution issue in the SparkSession internal table() API ## What changes were proposed in this pull request? The current internal `table()` API of `SparkSession` bypasses the Analyzer and directly calls `sessionState.catalog.lookupRelation` API. This skips the view resolution logics in our Analyzer rule `ResolveRelations`. This internal API is widely used by various DDL commands, public and internal APIs. Users might get the strange error caused by view resolution when the default database is different. ``` Table or view not found: t1; line 1 pos 14 org.apache.spark.sql.AnalysisException: Table or view not found: t1; line 1 pos 14 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) ``` This PR is to fix it by enforcing it to use `ResolveRelations` to resolve the table. ## How was this patch tested? Added a test case and modified the existing test cases Author: gatorsmile Closes #19713 from gatorsmile/viewResolution. 
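As a concrete illustration of the scenario being fixed, the sketch below mirrors the new test case; it assumes an existing `SparkSession` named `spark`, and the table, view and database names are only illustrative.

```scala
spark.sql("USE default")
spark.sql("CREATE TABLE t1 USING parquet AS SELECT 1 AS c0")
spark.sql("CREATE VIEW v1 AS SELECT * FROM t1")
spark.sql("CREATE DATABASE IF NOT EXISTS db2")
spark.sql("USE db2")

// Before the fix, internal callers of SparkSession.table() could fail here with
// "Table or view not found: t1" because the view text was not re-resolved through
// the Analyzer rule ResolveRelations; after the fix this returns Row(1).
spark.table("default.v1").show()
```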
--- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- .../org/apache/spark/sql/SparkSession.scala | 3 ++- .../spark/sql/execution/command/cache.scala | 4 +--- .../sql/execution/GlobalTempViewSuite.scala | 16 +++++++++++----- .../spark/sql/execution/SQLViewSuite.scala | 15 +++++++++++++++ .../spark/sql/execution/command/DDLSuite.scala | 5 +++-- .../spark/sql/sources/FilteredScanSuite.scala | 2 +- .../apache/spark/sql/hive/CachedTableSuite.scala | 14 +++++++++----- 8 files changed, 43 insertions(+), 18 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 8a7fb124bd9db..00217c892fc8d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -733,7 +733,7 @@ test_that("test cache, uncache and clearCache", { expect_true(dropTempView("table1")) expect_error(uncacheTable("foo"), - "Error in uncacheTable : no such table - Table or view 'foo' not found in database 'default'") + "Error in uncacheTable : analysis error - Table or view not found: foo") }) test_that("insertInto() on a registered table", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d5ab53ad8fe29..2821f5ee7feee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} @@ -621,7 +622,7 @@ class SparkSession private( } private[sql] def table(tableIdent: TableIdentifier): DataFrame = { - Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent)) + Dataset.ofRows(self, UnresolvedRelation(tableIdent)) } /* ----------------- * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 687994d82a003..6b00426d2fa91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -54,10 +54,8 @@ case class UncacheTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val tableId = tableIdent.quotedString - try { + if (!ifExists || sparkSession.catalog.tableExists(tableId)) { sparkSession.catalog.uncacheTable(tableId) - } catch { - case _: NoSuchTableException if ifExists => // don't throw } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index a3d75b221ec3e..cc943e0356f2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -35,23 +35,27 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { private var globalTempDB: String = _ test("basic semantic") { + val expectedErrorMsg = "not found" try { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") // If there is no database in 
table name, we should try local temp view first, if not found, // try table/view in current database, which is "default" in this case. So we expect // NoSuchTableException here. - intercept[NoSuchTableException](spark.table("src")) + var e = intercept[AnalysisException](spark.table("src")).getMessage + assert(e.contains(expectedErrorMsg)) // Use qualified name to refer to the global temp view explicitly. checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) // Table name without database will never refer to a global temp view. - intercept[NoSuchTableException](sql("DROP VIEW src")) + e = intercept[AnalysisException](sql("DROP VIEW src")).getMessage + assert(e.contains(expectedErrorMsg)) sql(s"DROP VIEW $globalTempDB.src") // The global temp view should be dropped successfully. - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + e = intercept[AnalysisException](spark.table(s"$globalTempDB.src")).getMessage + assert(e.contains(expectedErrorMsg)) // We can also use Dataset API to create global temp view Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") @@ -59,7 +63,8 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // Use qualified name to rename a global temp view. sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + e = intercept[AnalysisException](spark.table(s"$globalTempDB.src")).getMessage + assert(e.contains(expectedErrorMsg)) checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) // Use qualified name to alter a global temp view. @@ -68,7 +73,8 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // We can also use Catalog API to drop global temp view spark.catalog.dropGlobalTempView("src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) + e = intercept[AnalysisException](spark.table(s"$globalTempDB.src2")).getMessage + assert(e.contains(expectedErrorMsg)) // We can also use Dataset API to replace global temp view Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 6761f05bb462a..08a4a21b20f61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -679,4 +679,19 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { assert(spark.table("v").schema.head.name == "cBa") } } + + test("sparkSession API view resolution with different default database") { + withDatabase("db2") { + withView("v1") { + withTable("t1") { + sql("USE default") + sql("CREATE TABLE t1 USING parquet AS SELECT 1 AS c0") + sql("CREATE VIEW v1 AS SELECT * FROM t1") + sql("CREATE DATABASE IF NOT EXISTS db2") + sql("USE db2") + checkAnswer(spark.table("default.v1"), Row(1)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 21a2c62929146..878f435c75cb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -825,10 +825,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { spark.range(10).createOrReplaceTempView("tab1") sql("ALTER TABLE tab1 RENAME TO tab2") 
checkAnswer(spark.table("tab2"), spark.range(10).toDF()) - intercept[NoSuchTableException] { spark.table("tab1") } + val e = intercept[AnalysisException](spark.table("tab1")).getMessage + assert(e.contains("Table or view not found")) sql("ALTER VIEW tab2 RENAME TO tab1") checkAnswer(spark.table("tab1"), spark.range(10).toDF()) - intercept[NoSuchTableException] { spark.table("tab2") } + intercept[AnalysisException] { spark.table("tab2") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index c45b507d2b489..a538b9458177e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -326,7 +326,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic assert(ColumnsRequired.set === requiredColumnNames) val table = spark.table("oneToTenFiltered") - val relation = table.queryExecution.logical.collectFirst { + val relation = table.queryExecution.analyzed.collectFirst { case LogicalRelation(r, _, _, _) => r }.get diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index d3cbf898e2439..48ab4eb9a6178 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -102,14 +102,18 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } test("uncache of nonexistant tables") { + val expectedErrorMsg = "Table or view not found: nonexistantTable" // make sure table doesn't exist - intercept[NoSuchTableException](spark.table("nonexistantTable")) - intercept[NoSuchTableException] { + var e = intercept[AnalysisException](spark.table("nonexistantTable")).getMessage + assert(e.contains(expectedErrorMsg)) + e = intercept[AnalysisException] { spark.catalog.uncacheTable("nonexistantTable") - } - intercept[NoSuchTableException] { + }.getMessage + assert(e.contains(expectedErrorMsg)) + e = intercept[AnalysisException] { sql("UNCACHE TABLE nonexistantTable") - } + }.getMessage + assert(e.contains(expectedErrorMsg)) sql("UNCACHE TABLE IF EXISTS nonexistantTable") } From 21a7bfd5c324e6c82152229f1394f26afeae771c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 11 Nov 2017 22:40:26 +0100 Subject: [PATCH 231/712] [SPARK-10365][SQL] Support Parquet logical type TIMESTAMP_MICROS ## What changes were proposed in this pull request? This PR enables Spark to read Parquet TIMESTAMP_MICROS values, and adds a new config that allows Spark to write timestamp values to Parquet as the TIMESTAMP_MICROS type. ## How was this patch tested? New tests added. Author: Wenchen Fan Closes #19702 from cloud-fan/parquet.
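For illustration only, a minimal usage sketch of the option this patch introduces. The session setup, the demo values, and the /tmp/ts_micros output path are assumptions made for the example, not part of the patch; only the config key and its TIMESTAMP_MICROS value come from the change itself (the default stays INT96 to keep the previous behavior).

    // Illustrative sketch; the session setup and output path are assumed, not taken from the patch.
    import java.sql.Timestamp
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("ts-micros-sketch").getOrCreate()
    import spark.implicits._

    // Opt in to the standard 64-bit microsecond layout instead of the default INT96.
    spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")

    val df = Seq((1, Timestamp.valueOf("2017-11-11 22:40:26.123456"))).toDF("id", "ts")
    df.write.mode("overwrite").parquet("/tmp/ts_micros")

    // Values written as TIMESTAMP_MICROS read back with microsecond precision preserved.
    spark.read.parquet("/tmp/ts_micros").show(false)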
--- .../apache/spark/sql/internal/SQLConf.scala | 30 ++++- .../SpecificParquetRecordReaderBase.java | 4 +- .../parquet/VectorizedColumnReader.java | 26 ++-- .../VectorizedParquetRecordReader.java | 11 +- .../parquet/ParquetFileFormat.scala | 51 +++----- .../parquet/ParquetReadSupport.scala | 4 +- .../parquet/ParquetRecordMaterializer.scala | 4 +- .../parquet/ParquetRowConverter.scala | 15 ++- .../parquet/ParquetSchemaConverter.scala | 92 ++++++++------- .../parquet/ParquetWriteSupport.scala | 48 ++++---- .../datasources/parquet/ParquetIOSuite.scala | 5 +- .../parquet/ParquetQuerySuite.scala | 22 ++++ .../parquet/ParquetSchemaSuite.scala | 111 +++++++----------- .../datasources/parquet/ParquetTest.scala | 2 +- .../spark/sql/internal/SQLConfSuite.scala | 28 +++++ 15 files changed, 249 insertions(+), 204 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a04f8778079de..831ef62d74c3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -285,8 +285,24 @@ object SQLConf { .booleanConf .createWithDefault(true) + object ParquetOutputTimestampType extends Enumeration { + val INT96, TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value + } + + val PARQUET_OUTPUT_TIMESTAMP_TYPE = buildConf("spark.sql.parquet.outputTimestampType") + .doc("Sets which Parquet timestamp type to use when Spark writes data to Parquet files. " + + "INT96 is a non-standard but commonly used timestamp type in Parquet. TIMESTAMP_MICROS " + + "is a standard timestamp type in Parquet, which stores number of microseconds from the " + + "Unix epoch. TIMESTAMP_MILLIS is also standard, but with millisecond precision, which " + + "means Spark has to truncate the microsecond portion of its timestamp value.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(ParquetOutputTimestampType.values.map(_.toString)) + .createWithDefault(ParquetOutputTimestampType.INT96.toString) + val PARQUET_INT64_AS_TIMESTAMP_MILLIS = buildConf("spark.sql.parquet.int64AsTimestampMillis") - .doc("When true, timestamp values will be stored as INT64 with TIMESTAMP_MILLIS as the " + + .doc(s"(Deprecated since Spark 2.3, please set ${PARQUET_OUTPUT_TIMESTAMP_TYPE.key}.) " + + "When true, timestamp values will be stored as INT64 with TIMESTAMP_MILLIS as the " + "extended type. In this mode, the microsecond portion of the timestamp value will be" + "truncated.") .booleanConf @@ -1143,6 +1159,18 @@ class SQLConf extends Serializable with Logging { def isParquetINT64AsTimestampMillis: Boolean = getConf(PARQUET_INT64_AS_TIMESTAMP_MILLIS) + def parquetOutputTimestampType: ParquetOutputTimestampType.Value = { + val isOutputTimestampTypeSet = settings.containsKey(PARQUET_OUTPUT_TIMESTAMP_TYPE.key) + if (!isOutputTimestampTypeSet && isParquetINT64AsTimestampMillis) { + // If PARQUET_OUTPUT_TIMESTAMP_TYPE is not set and PARQUET_INT64_AS_TIMESTAMP_MILLIS is set, + // respect PARQUET_INT64_AS_TIMESTAMP_MILLIS and use TIMESTAMP_MILLIS. Otherwise, + // PARQUET_OUTPUT_TIMESTAMP_TYPE has higher priority. 
+ ParquetOutputTimestampType.TIMESTAMP_MILLIS + } else { + ParquetOutputTimestampType.withName(getConf(PARQUET_OUTPUT_TIMESTAMP_TYPE)) + } + } + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 5a810cae1e184..80c2f491b48ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -197,8 +197,6 @@ protected void initialize(String path, List columns) throws IOException Configuration config = new Configuration(); config.set("spark.sql.parquet.binaryAsString", "false"); config.set("spark.sql.parquet.int96AsTimestamp", "false"); - config.set("spark.sql.parquet.writeLegacyFormat", "false"); - config.set("spark.sql.parquet.int64AsTimestampMillis", "false"); this.file = new Path(path); long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); @@ -224,7 +222,7 @@ protected void initialize(String path, List columns) throws IOException this.requestedSchema = ParquetSchemaConverter.EMPTY_MESSAGE(); } } - this.sparkSchema = new ParquetSchemaConverter(config).convert(requestedSchema); + this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); for (BlockMetaData block : blocks) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 3c8d766ffad30..0f1f470dc597e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -26,6 +26,7 @@ import org.apache.parquet.column.page.*; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; @@ -91,11 +92,15 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; + private final OriginalType originalType; - public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader) - throws IOException { + public VectorizedColumnReader( + ColumnDescriptor descriptor, + OriginalType originalType, + PageReader pageReader) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; + this.originalType = originalType; this.maxDefLevel = descriptor.getMaxDefinitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); @@ -158,12 +163,12 @@ void readBatch(int total, WritableColumnVector column) throws IOException { defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - // Timestamp values encoded as INT64 can't be lazily decoded as we need to post process + // TIMESTAMP_MILLIS encoded as INT64 can't 
be lazily decoded as we need to post process // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && - column.dataType() != DataTypes.TimestampType) || + originalType != OriginalType.TIMESTAMP_MILLIS) || descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { @@ -253,21 +258,21 @@ private void decodeDictionaryIds( case INT64: if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { + DecimalType.is64BitDecimalType(column.dataType()) || + originalType == OriginalType.TIMESTAMP_MICROS) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i))); } } - } else if (column.dataType() == DataTypes.TimestampType) { + } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putLong(i, DateTimeUtils.fromMillis(dictionary.decodeToLong(dictionaryIds.getDictId(i)))); } } - } - else { + } else { throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } break; @@ -377,10 +382,11 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { private void readLongBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { + DecimalType.is64BitDecimalType(column.dataType()) || + originalType == OriginalType.TIMESTAMP_MICROS) { defColumn.readLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else if (column.dataType() == DataTypes.TimestampType) { + } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, DateTimeUtils.fromMillis(dataColumn.readLong())); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 0cacf0c9c93a5..e827229dceef8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -165,8 +165,10 @@ public float getProgress() throws IOException, InterruptedException { // Columns 0,1: data columns // Column 2: partitionValues[0] // Column 3: partitionValues[1] - public void initBatch(MemoryMode memMode, StructType partitionColumns, - InternalRow partitionValues) { + public void initBatch( + MemoryMode memMode, + StructType partitionColumns, + InternalRow partitionValues) { StructType batchSchema = new StructType(); for (StructField f: sparkSchema.fields()) { batchSchema = batchSchema.add(f); @@ -281,11 +283,12 @@ private void checkEndOfRowGroup() throws IOException { + rowsReturned + " out of " + totalRowCount); } List columns = requestedSchema.getColumns(); + List types = requestedSchema.asGroupType().getFields(); columnReaders = new VectorizedColumnReader[columns.size()]; for (int 
i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; - columnReaders[i] = new VectorizedColumnReader(columns.get(i), - pages.getPageReader(columns.get(i))); + columnReaders[i] = new VectorizedColumnReader( + columns.get(i), types.get(i).getOriginalType(), pages.getPageReader(columns.get(i))); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 61bd65dd48144..a48f8d517b6ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -111,23 +111,15 @@ class ParquetFileFormat // This metadata is only useful for detecting optional columns when pushdowning filters. ParquetWriteSupport.setSchema(dataSchema, conf) - // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) - // and `CatalystWriteSupport` (writing actual rows to Parquet files). - conf.set( - SQLConf.PARQUET_BINARY_AS_STRING.key, - sparkSession.sessionState.conf.isParquetBinaryAsString.toString) - - conf.set( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sparkSession.sessionState.conf.isParquetINT96AsTimestamp.toString) - + // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet + // schema and writes actual rows to Parquet files. conf.set( SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) conf.set( - SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, - sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis.toString) + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + sparkSession.sessionState.conf.parquetOutputTimestampType.toString) // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) @@ -312,16 +304,13 @@ class ParquetFileFormat ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) - // Sets flags for `CatalystSchemaConverter` + // Sets flags for `ParquetToSparkSchemaConverter` hadoopConf.setBoolean( SQLConf.PARQUET_BINARY_AS_STRING.key, sparkSession.sessionState.conf.isParquetBinaryAsString) hadoopConf.setBoolean( SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) - hadoopConf.setBoolean( - SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, - sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) // Try to push down filters when filter push-down is enabled. 
val pushed = @@ -428,15 +417,9 @@ object ParquetFileFormat extends Logging { private[parquet] def readSchema( footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = { - def parseParquetSchema(schema: MessageType): StructType = { - val converter = new ParquetSchemaConverter( - sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.writeLegacyParquetFormat, - sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) - - converter.convert(schema) - } + val converter = new ParquetToSparkSchemaConverter( + sparkSession.sessionState.conf.isParquetBinaryAsString, + sparkSession.sessionState.conf.isParquetINT96AsTimestamp) val seen = mutable.HashSet[String]() val finalSchemas: Seq[StructType] = footers.flatMap { footer => @@ -447,7 +430,7 @@ object ParquetFileFormat extends Logging { .get(ParquetReadSupport.SPARK_METADATA_KEY) if (serializedSchema.isEmpty) { // Falls back to Parquet schema if no Spark SQL schema found. - Some(parseParquetSchema(metadata.getSchema)) + Some(converter.convert(metadata.getSchema)) } else if (!seen.contains(serializedSchema.get)) { seen += serializedSchema.get @@ -470,7 +453,7 @@ object ParquetFileFormat extends Logging { .map(_.asInstanceOf[StructType]) .getOrElse { // Falls back to Parquet schema if Spark SQL schema can't be parsed. - parseParquetSchema(metadata.getSchema) + converter.convert(metadata.getSchema) }) } else { None @@ -538,8 +521,6 @@ object ParquetFileFormat extends Logging { sparkSession: SparkSession): Option[StructType] = { val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp - val writeTimestampInMillis = sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis - val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) // !! HACK ALERT !! @@ -579,13 +560,9 @@ object ParquetFileFormat extends Logging { serializedConf.value, fakeFileStatuses, ignoreCorruptFiles) // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` - val converter = - new ParquetSchemaConverter( - assumeBinaryIsString = assumeBinaryIsString, - assumeInt96IsTimestamp = assumeInt96IsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat, - writeTimestampInMillis = writeTimestampInMillis) - + val converter = new ParquetToSparkSchemaConverter( + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp) if (footers.isEmpty) { Iterator.empty } else { @@ -625,7 +602,7 @@ object ParquetFileFormat extends Logging { * a [[StructType]] converted from the [[MessageType]] stored in this footer. 
*/ def readSchemaFromFooter( - footer: Footer, converter: ParquetSchemaConverter): StructType = { + footer: Footer, converter: ParquetToSparkSchemaConverter): StructType = { val fileMetaData = footer.getParquetMetadata.getFileMetaData fileMetaData .getKeyValueMetaData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f1a35dd8a6200..2854cb1bc0c25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -95,7 +95,7 @@ private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Lo new ParquetRecordMaterializer( parquetRequestedSchema, ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetSchemaConverter(conf)) + new ParquetToSparkSchemaConverter(conf)) } } @@ -270,7 +270,7 @@ private[parquet] object ParquetReadSupport { private def clipParquetGroupFields( parquetRecord: GroupType, structType: StructType): Seq[Type] = { val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - val toParquet = new ParquetSchemaConverter(writeLegacyParquetFormat = false) + val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) structType.map { f => parquetFieldMap .get(f.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index 4e49a0dac97c0..793755e9aaeb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -31,7 +31,9 @@ import org.apache.spark.sql.types.StructType * @param schemaConverter A Parquet-Catalyst schema converter that helps initializing row converters */ private[parquet] class ParquetRecordMaterializer( - parquetSchema: MessageType, catalystSchema: StructType, schemaConverter: ParquetSchemaConverter) + parquetSchema: MessageType, + catalystSchema: StructType, + schemaConverter: ParquetToSparkSchemaConverter) extends RecordMaterializer[UnsafeRow] { private val rootConverter = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 32e6c60cd9766..10f6c3b4f15e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -27,7 +27,7 @@ import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96} import org.apache.spark.internal.Logging import 
org.apache.spark.sql.catalyst.InternalRow @@ -120,7 +120,7 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class ParquetRowConverter( - schemaConverter: ParquetSchemaConverter, + schemaConverter: ParquetToSparkSchemaConverter, parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) @@ -252,6 +252,13 @@ private[parquet] class ParquetRowConverter( case StringType => new ParquetStringConverter(updater) + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(value) + } + } + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { @@ -259,8 +266,8 @@ private[parquet] class ParquetRowConverter( } } - case TimestampType => - // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + // INT96 timestamp doesn't have a logical type, here we check the physical type instead. + case TimestampType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT96 => new ParquetPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index cd384d17d0cda..c61be077d309f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -30,49 +30,31 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ + /** - * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and - * vice versa. + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]]. * * Parquet format backwards-compatibility rules are respected when converting Parquet * [[MessageType]] schemas. * * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md - * @constructor + * * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL - * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL - * [[StructType]]. This argument only affects Parquet read path. + * [[StringType]] fields. * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL - * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL - * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which - * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` - * described in Parquet format spec. This argument only affects Parquet read path. - * @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4 - * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. - * When set to false, use standard format defined in parquet-format spec. 
This argument only - * affects Parquet write path. - * @param writeTimestampInMillis Whether to write timestamp values as INT64 annotated by logical - * type TIMESTAMP_MILLIS. - * + * [[TimestampType]] fields. */ -private[parquet] class ParquetSchemaConverter( +class ParquetToSparkSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, - assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, - writeTimestampInMillis: Boolean = SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.defaultValue.get) { + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, - assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, - writeLegacyParquetFormat = conf.writeLegacyParquetFormat, - writeTimestampInMillis = conf.isParquetINT64AsTimestampMillis) + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, - assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, - writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean, - writeTimestampInMillis = conf.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean) + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean) /** @@ -165,6 +147,7 @@ private[parquet] class ParquetSchemaConverter( case INT_64 | null => LongType case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() + case TIMESTAMP_MICROS => TimestampType case TIMESTAMP_MILLIS => TimestampType case _ => illegalType() } @@ -310,6 +293,30 @@ private[parquet] class ParquetSchemaConverter( repeatedType.getName == s"${parentName}_tuple" } } +} + +/** + * This converter class is used to convert Spark SQL [[StructType]] to Parquet [[MessageType]]. + * + * @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4 + * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. + * When set to false, use standard format defined in parquet-format spec. This argument only + * affects Parquet write path. + * @param outputTimestampType which parquet timestamp type to use when writing. + */ +class SparkToParquetSchemaConverter( + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96) { + + def this(conf: SQLConf) = this( + writeLegacyParquetFormat = conf.writeLegacyParquetFormat, + outputTimestampType = conf.parquetOutputTimestampType) + + def this(conf: Configuration) = this( + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean, + outputTimestampType = SQLConf.ParquetOutputTimestampType.withName( + conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key))) /** * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. @@ -363,7 +370,9 @@ private[parquet] class ParquetSchemaConverter( case DateType => Types.primitive(INT32, repetition).as(DATE).named(field.name) - // NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec. 
+ // NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or + // TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the + // behavior same as before. // // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond // timestamp in Impala for some historical reasons. It's not recommended to be used for any @@ -372,23 +381,18 @@ private[parquet] class ParquetSchemaConverter( // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`. // // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting - // from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store - // a timestamp into a `Long`. This design decision is subject to change though, for example, - // we may resort to microsecond precision in the future. - // - // For Parquet, we plan to write all `TimestampType` value as `TIMESTAMP_MICROS`, but it's - // currently not implemented yet because parquet-mr 1.8.1 (the version we're currently using) - // hasn't implemented `TIMESTAMP_MICROS` yet, however it supports TIMESTAMP_MILLIS. We will - // encode timestamp values as TIMESTAMP_MILLIS annotating INT64 if - // 'spark.sql.parquet.int64AsTimestampMillis' is set. - // - // TODO Converts `TIMESTAMP_MICROS` once parquet-mr implements that. - - case TimestampType if writeTimestampInMillis => - Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) - + // from Spark 1.5.0, we resort to a timestamp type with microsecond precision so that we can + // store a timestamp into a `Long`. This design decision is subject to change though, for + // example, we may resort to nanosecond precision in the future. case TimestampType => - Types.primitive(INT96, repetition).named(field.name) + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + Types.primitive(INT96, repetition).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MICROS).named(field.name) + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + } case BinaryType => Types.primitive(BINARY, repetition).named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 63a8666f0d774..af4e1433c876f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -66,8 +66,8 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions private var writeLegacyParquetFormat: Boolean = _ - // Whether to write timestamp value with milliseconds precision. - private var writeTimestampInMillis: Boolean = _ + // Which parquet timestamp type to use when writing. 
+ private var outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = _ // Reusable byte array used to write timestamps as Parquet INT96 values private val timestampBuffer = new Array[Byte](12) @@ -84,15 +84,15 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean } - this.writeTimestampInMillis = { - assert(configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key) != null) - configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean + this.outputTimestampType = { + val key = SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key + assert(configuration.get(key) != null) + SQLConf.ParquetOutputTimestampType.withName(configuration.get(key)) } - this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter] - val messageType = new ParquetSchemaConverter(configuration).convert(schema) + val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava logInfo( @@ -163,25 +163,23 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit recordConsumer.addBinary( Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes)) - case TimestampType if writeTimestampInMillis => - (row: SpecializedGetters, ordinal: Int) => - val millis = DateTimeUtils.toMillis(row.getLong(ordinal)) - recordConsumer.addLong(millis) - case TimestampType => - (row: SpecializedGetters, ordinal: Int) => { - // TODO Writes `TimestampType` values as `TIMESTAMP_MICROS` once parquet-mr implements it - // Currently we only support timestamps stored as INT96, which is compatible with Hive - // and Impala. However, INT96 is to be deprecated. We plan to support `TIMESTAMP_MICROS` - // defined in the parquet-format spec. But up until writing, the most recent parquet-mr - // version (1.8.1) hasn't implemented it yet. - - // NOTE: Starting from Spark 1.5, Spark SQL `TimestampType` only has microsecond - // precision. Nanosecond parts of timestamp values read from INT96 are simply stripped. 
- val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) - val buf = ByteBuffer.wrap(timestampBuffer) - buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) - recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) + outputTimestampType match { + case SQLConf.ParquetOutputTimestampType.INT96 => + (row: SpecializedGetters, ordinal: Int) => + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) + val buf = ByteBuffer.wrap(timestampBuffer) + buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) + recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => + (row: SpecializedGetters, ordinal: Int) => + recordConsumer.addLong(row.getLong(ordinal)) + + case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => + (row: SpecializedGetters, ordinal: Int) => + val millis = DateTimeUtils.toMillis(row.getLong(ordinal)) + recordConsumer.addLong(millis) } case BinaryType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index d76990b482db2..633cfde6ab941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -110,12 +110,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { | required binary h(DECIMAL(32,0)); | required fixed_len_byte_array(32) i(DECIMAL(32,0)); | required int64 j(TIMESTAMP_MILLIS); + | required int64 k(TIMESTAMP_MICROS); |} """.stripMargin) val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0), - TimestampType) + TimestampType, TimestampType) withTempPath { location => val path = new Path(location.getCanonicalPath) @@ -380,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val expectedSchema = new ParquetSchemaConverter().convert(schema) + val expectedSchema = new SparkToParquetSchemaConverter().convert(schema) val actualSchema = readFooter(path, hadoopConf).getFileMetaData.getSchema actualSchema.checkContains(expectedSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e822e40b146ee..4c8c9ef6e0432 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -235,6 +235,28 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-10365 timestamp written and read as INT64 - TIMESTAMP_MICROS") { + val data = (1 to 10).map { i => + val ts = new java.sql.Timestamp(i) + ts.setNanos(2000) + Row(i, ts) + } + val schema = StructType(List(StructField("d", IntegerType, false), + StructField("time", TimestampType, false)).toArray) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> 
"TIMESTAMP_MICROS") { + withTempPath { file => + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + val df2 = spark.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + } + } + } + test("Enabling/disabling merging partfiles when merging parquet schema") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index ce992674d719f..2cd2a600f2b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -24,6 +24,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -52,14 +53,10 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { sqlSchema: StructType, parquetSchema: String, binaryAsString: Boolean, - int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean, - int64AsTimestampMillis: Boolean = false): Unit = { - val converter = new ParquetSchemaConverter( + int96AsTimestamp: Boolean): Unit = { + val converter = new ParquetToSparkSchemaConverter( assumeBinaryIsString = binaryAsString, - assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat, - writeTimestampInMillis = int64AsTimestampMillis) + assumeInt96IsTimestamp = int96AsTimestamp) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -77,15 +74,12 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean, - int96AsTimestamp: Boolean, writeLegacyParquetFormat: Boolean, - int64AsTimestampMillis: Boolean = false): Unit = { - val converter = new ParquetSchemaConverter( - assumeBinaryIsString = binaryAsString, - assumeInt96IsTimestamp = int96AsTimestamp, + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96): Unit = { + val converter = new SparkToParquetSchemaConverter( writeLegacyParquetFormat = writeLegacyParquetFormat, - writeTimestampInMillis = int64AsTimestampMillis) + outputTimestampType = outputTimestampType) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -102,25 +96,22 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { binaryAsString: Boolean, int96AsTimestamp: Boolean, writeLegacyParquetFormat: Boolean, - int64AsTimestampMillis: Boolean = false): Unit = { + outputTimestampType: SQLConf.ParquetOutputTimestampType.Value = + SQLConf.ParquetOutputTimestampType.INT96): Unit = { testCatalystToParquet( testName, sqlSchema, parquetSchema, - binaryAsString, - int96AsTimestamp, writeLegacyParquetFormat, - int64AsTimestampMillis) + outputTimestampType) testParquetToCatalyst( testName, sqlSchema, parquetSchema, 
binaryAsString, - int96AsTimestamp, - writeLegacyParquetFormat, - int64AsTimestampMillis) + int96AsTimestamp) } } @@ -411,8 +402,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with nullable element type - 2", @@ -430,8 +420,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", @@ -446,8 +435,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 2", @@ -462,8 +450,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 3", @@ -476,8 +463,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 4", @@ -500,8 +486,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", @@ -522,8 +507,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", @@ -544,8 +528,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 7 - " + @@ -557,8 +540,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 8 - " + @@ -580,8 +562,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) // ======================================================= // Tests for converting Catalyst ArrayType to Parquet LIST @@ -602,8 +583,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -621,8 +600,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) 
testCatalystToParquet( @@ -640,8 +617,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -657,8 +632,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) // ==================================================== @@ -682,8 +655,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 2", @@ -702,8 +674,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", @@ -722,8 +693,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -742,8 +712,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 2", @@ -762,8 +731,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", @@ -782,8 +750,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin, binaryAsString = true, - int96AsTimestamp = true, - writeLegacyParquetFormat = true) + int96AsTimestamp = true) // ==================================================== // Tests for converting Catalyst MapType to Parquet Map @@ -805,8 +772,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -825,8 +790,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) testCatalystToParquet( @@ -845,8 +808,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = false) testCatalystToParquet( @@ -865,8 +826,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - binaryAsString = true, - int96AsTimestamp = true, writeLegacyParquetFormat = true) // ================================= @@ -982,7 +941,19 @@ class ParquetSchemaSuite extends ParquetSchemaTest { binaryAsString = true, int96AsTimestamp = false, writeLegacyParquetFormat = true, - int64AsTimestampMillis = true) + outputTimestampType = SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS) + + testSchema( + "Timestamp written and read as INT64 with TIMESTAMP_MICROS", + StructType(Seq(StructField("f1", TimestampType))), + """message root { + | optional 
INT64 f1 (TIMESTAMP_MICROS); |} """.stripMargin, binaryAsString = true, int96AsTimestamp = false, writeLegacyParquetFormat = true, outputTimestampType = SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS) private def testSchemaClipping( testName: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 85efca3c4b24d..f05f5722af51a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -124,7 +124,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def writeMetadata( schema: StructType, path: Path, configuration: Configuration): Unit = { - val parquetSchema = new ParquetSchemaConverter().convert(schema) + val parquetSchema = new SparkToParquetSchemaConverter().convert(schema) val extraMetadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schema.json).asJava val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, createdBy) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 205c303b6cc4b..f9d75fc1788db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -281,4 +281,32 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { assert(null == spark.conf.get("spark.sql.nonexistent", null)) assert("" == spark.conf.get("spark.sql.nonexistent", "")) } + + test("SPARK-10365: PARQUET_OUTPUT_TIMESTAMP_TYPE") { + spark.sessionState.conf.clear() + + // check default value + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.INT96) + + // PARQUET_INT64_AS_TIMESTAMP_MILLIS should be respected. + spark.sessionState.conf.setConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS, true) + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS) + + // PARQUET_OUTPUT_TIMESTAMP_TYPE has higher priority over PARQUET_INT64_AS_TIMESTAMP_MILLIS + spark.sessionState.conf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "timestamp_micros") + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS) + spark.sessionState.conf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "int96") + assert(spark.sessionState.conf.parquetOutputTimestampType == + SQLConf.ParquetOutputTimestampType.INT96) + + // test invalid conf value + intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, "invalid") + } + + spark.sessionState.conf.clear() + } } From b3f9dbf48ec0938ff5c98833bb6b6855c620ef57 Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Sun, 12 Nov 2017 11:21:23 -0800 Subject: [PATCH 232/712] [SPARK-19606][MESOS] Support constraints in spark-dispatcher ## What changes were proposed in this pull request? As discussed in SPARK-19606, this adds a new config property named "spark.mesos.driver.constraints" for constraining drivers running on a Mesos cluster. ## How was this patch tested? 
Corresponding unit test added also tested locally on a Mesos cluster Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Paul Mackles Closes #19543 from pmackles/SPARK-19606. --- docs/running-on-mesos.md | 17 ++++++- .../apache/spark/deploy/mesos/config.scala | 7 +++ .../cluster/mesos/MesosClusterScheduler.scala | 15 ++++-- .../mesos/MesosClusterSchedulerSuite.scala | 47 +++++++++++++++++++ .../spark/scheduler/cluster/mesos/Utils.scala | 31 ++++++++---- 5 files changed, 100 insertions(+), 17 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 7a443ffddc5f0..19ec7c1e0aeee 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -263,7 +263,10 @@ resource offers will be accepted. conf.set("spark.mesos.constraints", "os:centos7;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `os:centos7;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `os:centos7;us-east-1:false`, then the resource offers will +be checked to see if they meet both these constraints and only then will be accepted to start new executors. + +To constrain where driver tasks are run, use `spark.mesos.driver.constraints` # Mesos Docker Support @@ -447,7 +450,9 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints (none) - Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting + applies only to executors. Refer to Mesos + Attributes & Resources for more information on attributes.
    • Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
    • Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
@@ -457,6 +462,14 @@ See the [configuration page](configuration.html) for information on Spark config
    + + spark.mesos.driver.constraints + (none) + + Same as spark.mesos.constraints except applied to drivers when launched through the dispatcher. By default, + all offers with sufficient resources will be accepted. + + spark.mesos.containerizer docker diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 821534eb4fc38..d134847dc74d2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -122,4 +122,11 @@ package object config { "Example: key1:val1,key2:val2") .stringConf .createOptional + + private[spark] val DRIVER_CONSTRAINTS = + ConfigBuilder("spark.mesos.driver.constraints") + .doc("Attribute based constraints on mesos resource offers. Applied by the dispatcher " + + "when launching drivers. Default is to accept all offers with sufficient resources.") + .stringConf + .createWithDefault("") } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index de846c85d53a6..c41283e4a3e39 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -556,9 +556,10 @@ private[spark] class MesosClusterScheduler( private class ResourceOffer( val offer: Offer, - var remainingResources: JList[Resource]) { + var remainingResources: JList[Resource], + var attributes: JList[Attribute]) { override def toString(): String = { - s"Offer id: ${offer.getId}, resources: ${remainingResources}" + s"Offer id: ${offer.getId}, resources: ${remainingResources}, attributes: ${attributes}" } } @@ -601,10 +602,14 @@ private[spark] class MesosClusterScheduler( for (submission <- candidates) { val driverCpu = submission.cores val driverMem = submission.mem - logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val driverConstraints = + parseConstraintString(submission.conf.get(config.DRIVER_CONSTRAINTS)) + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem, " + + s"driverConstraints: $driverConstraints") val offerOption = currentOffers.find { offer => getResource(offer.remainingResources, "cpus") >= driverCpu && - getResource(offer.remainingResources, "mem") >= driverMem + getResource(offer.remainingResources, "mem") >= driverMem && + matchesAttributeRequirements(driverConstraints, toAttributeMap(offer.attributes)) } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + @@ -652,7 +657,7 @@ private[spark] class MesosClusterScheduler( val currentTime = new Date() val currentOffers = offers.asScala.map { - offer => new ResourceOffer(offer, offer.getResourcesList) + offer => new ResourceOffer(offer, offer.getResourcesList, offer.getAttributesList) }.toList stateLock.synchronized { diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 77acee608f25f..e534b9d7e3ed9 100644 --- 
a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -254,6 +254,53 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(networkInfos.get(0).getLabels.getLabels(1).getValue == "val2") } + test("accept/decline offers with driver constraints") { + setScheduler() + + val mem = 1000 + val cpu = 1 + val s2Attributes = List(Utils.createTextAttribute("c1", "a")) + val s3Attributes = List( + Utils.createTextAttribute("c1", "a"), + Utils.createTextAttribute("c2", "b")) + val offers = List( + Utils.createOffer("o1", "s1", mem, cpu, None, 0), + Utils.createOffer("o2", "s2", mem, cpu, None, 0, s2Attributes), + Utils.createOffer("o3", "s3", mem, cpu, None, 0, s3Attributes)) + + def submitDriver(driverConstraints: String): Unit = { + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", mem, cpu, true, + command, + Map("spark.mesos.executor.home" -> "test", + "spark.app.name" -> "test", + config.DRIVER_CONSTRAINTS.key -> driverConstraints), + "s1", + new Date())) + assert(response.success) + } + + submitDriver("c1:x") + scheduler.resourceOffers(driver, offers.asJava) + offers.foreach(o => Utils.verifyTaskNotLaunched(driver, o.getId.getValue)) + + submitDriver("c1:y;c2:z") + scheduler.resourceOffers(driver, offers.asJava) + offers.foreach(o => Utils.verifyTaskNotLaunched(driver, o.getId.getValue)) + + submitDriver("") + scheduler.resourceOffers(driver, offers.asJava) + Utils.verifyTaskLaunched(driver, "o1") + + submitDriver("c1:a") + scheduler.resourceOffers(driver, offers.asJava) + Utils.verifyTaskLaunched(driver, "o2") + + submitDriver("c1:a;c2:b") + scheduler.resourceOffers(driver, offers.asJava) + Utils.verifyTaskLaunched(driver, "o3") + } + test("supports spark.mesos.driver.labels") { setScheduler() diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 5636ac52bd4a7..c9f47471cd75e 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -43,12 +43,13 @@ object Utils { .build() def createOffer( - offerId: String, - slaveId: String, - mem: Int, - cpus: Int, - ports: Option[(Long, Long)] = None, - gpus: Int = 0): Offer = { + offerId: String, + slaveId: String, + mem: Int, + cpus: Int, + ports: Option[(Long, Long)] = None, + gpus: Int = 0, + attributes: List[Attribute] = List.empty): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -63,7 +64,7 @@ object Utils { .setName("ports") .setType(Value.Type.RANGES) .setRanges(Ranges.newBuilder().addRange(MesosRange.newBuilder() - .setBegin(resourcePorts._1).setEnd(resourcePorts._2).build())) + .setBegin(resourcePorts._1).setEnd(resourcePorts._2).build())) } if (gpus > 0) { builder.addResourcesBuilder() @@ -73,9 +74,10 @@ object Utils { } builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() - .setValue("f1")) + .setValue("f1")) .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) .setHostname(s"host${slaveId}") + .addAllAttributes(attributes.asJava) .build() } @@ -125,7 +127,7 @@ object Utils { .getVariablesList .asScala assert(envVars - 
.count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars val variableOne = envVars.filter(_.getName == "SECRET_ENV_KEY").head assert(variableOne.getSecret.isInitialized) assert(variableOne.getSecret.getType == Secret.Type.REFERENCE) @@ -154,7 +156,7 @@ object Utils { .getVariablesList .asScala assert(envVars - .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars + .count(!_.getName.startsWith("SPARK_")) == 2) // user-defined secret env vars val variableOne = envVars.filter(_.getName == "USER").head assert(variableOne.getSecret.isInitialized) assert(variableOne.getSecret.getType == Secret.Type.VALUE) @@ -212,4 +214,13 @@ object Utils { assert(secretVolTwo.getSource.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes)) } + + def createTextAttribute(name: String, value: String): Attribute = { + Attribute.newBuilder() + .setName(name) + .setType(Value.Type.TEXT) + .setText(Value.Text.newBuilder().setValue(value)) + .build() + } } + From 9bf696dbece6b1993880efba24a6d32c54c4d11c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 12 Nov 2017 22:44:47 +0100 Subject: [PATCH 233/712] [SPARK-21720][SQL] Fix 64KB JVM bytecode limit problem with AND or OR ## What changes were proposed in this pull request? This PR changes `AND` or `OR` code generation to place condition and then expressions' generated code into separated methods if these size could be large. When the method is newly generated, variables for `isNull` and `value` are declared as an instance variable to pass these values (e.g. `isNull1409` and `value1409`) to the callers of the generated method. This PR resolved two cases: * large code size of left expression * large code size of right expression ## How was this patch tested? Added a new test case into `CodeGenerationSuite` Author: Kazuaki Ishizaki Closes #18972 from kiszk/SPARK-21720. --- .../expressions/codegen/CodeGenerator.scala | 30 +++++++ .../expressions/conditionalExpressions.scala | 29 +------ .../sql/catalyst/expressions/predicates.scala | 82 ++++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 39 +++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 +- 5 files changed, 157 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 98eda2a1ba92c..3dc3f8e4adac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -908,6 +908,36 @@ class CodegenContext { } } + /** + * Wrap the generated code of expression, which was created from a row object in INPUT_ROW, + * by a function. ev.isNull and ev.value are passed by global variables + * + * @param ev the code to evaluate expressions. + * @param dataType the data type of ev.value. + * @param baseFuncName the split function name base. 
+ */ + def createAndAddFunction( + ev: ExprCode, + dataType: DataType, + baseFuncName: String): (String, String, String) = { + val globalIsNull = freshName("isNull") + addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val globalValue = freshName("value") + addMutableState(javaType(dataType), globalValue, + s"$globalValue = ${defaultValue(dataType)};") + val funcName = freshName(baseFuncName) + val funcBody = + s""" + |private void $funcName(InternalRow ${INPUT_ROW}) { + | ${ev.code.trim} + | $globalIsNull = ${ev.isNull}; + | $globalValue = ${ev.value}; + |} + """.stripMargin + val fullFuncName = addNewFunction(funcName, funcBody) + (fullFuncName, globalIsNull, globalValue) + } + /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d95b59d5ec423..c41a10c7b0f87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -72,11 +72,11 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi (ctx.INPUT_ROW != null && ctx.currentVars == null)) { val (condFuncName, condGlobalIsNull, condGlobalValue) = - createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr") + ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr") val (trueFuncName, trueGlobalIsNull, trueGlobalValue) = - createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr") + ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr") val (falseFuncName, falseGlobalIsNull, falseGlobalValue) = - createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr") + ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr") s""" $condFuncName(${ctx.INPUT_ROW}); boolean ${ev.isNull} = false; @@ -112,29 +112,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi ev.copy(code = generatedCode) } - private def createAndAddFunction( - ctx: CodegenContext, - ev: ExprCode, - dataType: DataType, - baseFuncName: String): (String, String, String) = { - val globalIsNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") - val globalValue = ctx.freshName("value") - ctx.addMutableState(ctx.javaType(dataType), globalValue, - s"$globalValue = ${ctx.defaultValue(dataType)};") - val funcName = ctx.freshName(baseFuncName) - val funcBody = - s""" - |private void $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${ev.code.trim} - | $globalIsNull = ${ev.isNull}; - | $globalValue = ${ev.value}; - |} - """.stripMargin - val fullFuncName = ctx.addNewFunction(funcName, funcBody) - (fullFuncName, globalIsNull, globalValue) - } - override def toString: String = s"if ($predicate) $trueValue else $falseValue" override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index efcd45fad779c..61df5e053a374 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -368,7 +368,46 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with val eval2 = right.genCode(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. - if (!left.nullable && !right.nullable) { + + // place generated code of eval1 and eval2 in separate methods if their code combined is large + val combinedLength = eval1.code.length + eval2.code.length + if (combinedLength > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + + val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = + ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") + val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = + ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") + if (!left.nullable && !right.nullable) { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.value} = false; + if (${eval1GlobalValue}) { + $eval2FuncName(${ctx.INPUT_ROW}); + ${ev.value} = ${eval2GlobalValue}; + } + """ + ev.copy(code = generatedCode, isNull = "false") + } else { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = false; + boolean ${ev.value} = false; + if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) { + } else { + $eval2FuncName(${ctx.INPUT_ROW}); + if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) { + } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { + ${ev.value} = true; + } else { + ${ev.isNull} = true; + } + } + """ + ev.copy(code = generatedCode) + } + } else if (!left.nullable && !right.nullable) { ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = false; @@ -431,7 +470,46 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P val eval2 = right.genCode(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. 
- if (!left.nullable && !right.nullable) { + + // place generated code of eval1 and eval2 in separate methods if their code combined is large + val combinedLength = eval1.code.length + eval2.code.length + if (combinedLength > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + + val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = + ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") + val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = + ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") + if (!left.nullable && !right.nullable) { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.value} = true; + if (!${eval1GlobalValue}) { + $eval2FuncName(${ctx.INPUT_ROW}); + ${ev.value} = ${eval2GlobalValue}; + } + """ + ev.copy(code = generatedCode, isNull = "false") + } else { + val generatedCode = s""" + $eval1FuncName(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = false; + boolean ${ev.value} = true; + if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) { + } else { + $eval2FuncName(${ctx.INPUT_ROW}); + if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) { + } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { + ${ev.value} = false; + } else { + ${ev.isNull} = true; + } + } + """ + ev.copy(code = generatedCode) + } + } else if (!left.nullable && !right.nullable) { ev.isNull = "false" ev.copy(code = s""" ${eval1.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1e6f7b65e7e72..8f6289f00571c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -341,4 +341,43 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { // should not throw exception projection(row) } + + test("SPARK-21720: split large predications into blocks due to JVM code size limit") { + val length = 600 + + val input = new GenericInternalRow(length) + val utf8Str = UTF8String.fromString(s"abc") + for (i <- 0 until length) { + input.update(i, utf8Str) + } + + var exprOr: Expression = Literal(false) + for (i <- 0 until length) { + exprOr = Or(EqualTo(BoundReference(i, StringType, true), Literal(s"c$i")), exprOr) + } + + val planOr = GenerateMutableProjection.generate(Seq(exprOr)) + val actualOr = planOr(input).toSeq(Seq(exprOr.dataType)) + assert(actualOr.length == 1) + val expectedOr = false + + if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) { + fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr, expected: $expectedOr") + } + + var exprAnd: Expression = Literal(true) + for (i <- 0 until length) { + exprAnd = And(EqualTo(BoundReference(i, StringType, true), Literal(s"c$i")), exprAnd) + } + + val planAnd = GenerateMutableProjection.generate(Seq(exprAnd)) + val actualAnd = planAnd(input).toSeq(Seq(exprAnd.dataType)) + assert(actualAnd.length == 1) + val expectedAnd = false + + if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) { + fail( + s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 
31bfa77e76329..644e72c893ceb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2097,7 +2097,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .count } - testQuietly("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { + // The fix of SPARK-21720 avoid an exception regarding JVM code size limit + // TODO: When we make a threshold of splitting statements (1024) configurable, + // we will re-enable this with max threshold to cause an exception + // See https://github.com/apache/spark/pull/18972/files#r150223463 + ignore("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { val N = 400 val rows = Seq(Row.fromSeq(Seq.fill(N)("string"))) val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType))) From 3d90b2cb384affe8ceac9398615e9e21b8c8e0b0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 12 Nov 2017 14:37:20 -0800 Subject: [PATCH 234/712] [SPARK-21693][R][ML] Reduce max iterations in Linear SVM test in R to speed up AppVeyor build ## What changes were proposed in this pull request? This PR proposes to reduce max iteration in Linear SVM test in SparkR. This particular test elapses roughly 5 mins on my Mac and over 20 mins on Windows. The root cause appears, it triggers 2500ish jobs by the default 100 max iterations. In Linux, `daemon.R` is forked but on Windows another process is launched, which is extremely slow. So, given my observation, there are many processes (not forked) ran on Windows, which makes the differences of elapsed time. After reducing the max iteration to 10, the total jobs in this single test is reduced to 550ish. After reducing the max iteration to 5, the total jobs in this single test is reduced to 360ish. ## How was this patch tested? Manually tested the elapsed times. Author: hyukjinkwon Closes #19722 from HyukjinKwon/SPARK-21693-test. 
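For context on why a smaller `maxIter` shrinks the job count: `spark.svmLinear` in R is a thin wrapper over the JVM `LinearSVC` estimator, and each iteration of its optimizer schedules another round of Spark jobs. A rough Scala sketch of the equivalent cap (illustrative only; `training` is an assumed DataFrame of labels and feature vectors, not part of this patch):

```scala
import org.apache.spark.ml.classification.LinearSVC

// Each optimizer iteration runs additional Spark jobs, so capping iterations
// bounds the test's total job count (and its runtime on slow Windows workers).
val svc = new LinearSVC()
  .setRegParam(0.1)
  .setMaxIter(5) // mirrors the maxIter = 5 used in the R test below

// val model = svc.fit(training) // `training` is an assumed DataFrame("label", "features")
```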
--- R/pkg/tests/fulltests/test_mllib_classification.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index a4d0397236d17..ad47717ddc12f 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -66,7 +66,7 @@ test_that("spark.svmLinear", { feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) data <- as.data.frame(cbind(label, feature)) df <- createDataFrame(data) - model <- spark.svmLinear(df, label ~ feature, regParam = 0.1) + model <- spark.svmLinear(df, label ~ feature, regParam = 0.1, maxIter = 5) prediction <- collect(select(predict(model, df), "prediction")) expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) @@ -77,10 +77,11 @@ test_that("spark.svmLinear", { trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) traindf <- as.DataFrame(data[trainidxs, ]) testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) - model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1) + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, maxIter = 5) predictions <- predict(model, testdf) expect_error(collect(predictions)) - model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip") + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, + handleInvalid = "skip", maxIter = 5) predictions <- predict(model, testdf) expect_equal(class(collect(predictions)$clicked[1]), "list") From 209b9361ac8a4410ff797cff1115e1888e2f7e66 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 13 Nov 2017 13:16:01 +0900 Subject: [PATCH 235/712] [SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas ## What changes were proposed in this pull request? This change uses Arrow to optimize the creation of a Spark DataFrame from a Pandas DataFrame. The input df is sliced according to the default parallelism. The optimization is enabled with the existing conf "spark.sql.execution.arrow.enabled" and is disabled by default. ## How was this patch tested? Added new unit test to create DataFrame with and without the optimization enabled, then compare results. Author: Bryan Cutler Author: Takuya UESHIN Closes #19459 from BryanCutler/arrow-createDataFrame-from_pandas-SPARK-20791. --- python/pyspark/context.py | 28 +++--- python/pyspark/java_gateway.py | 1 + python/pyspark/serializers.py | 10 ++- python/pyspark/sql/session.py | 88 ++++++++++++++---- python/pyspark/sql/tests.py | 89 ++++++++++++++++--- python/pyspark/sql/types.py | 49 ++++++++++ .../spark/sql/api/python/PythonSQLUtils.scala | 18 ++++ .../sql/execution/arrow/ArrowConverters.scala | 14 +++ 8 files changed, 254 insertions(+), 43 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a33f6dcf31fc0..24905f1c97b21 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -475,24 +475,30 @@ def f(split, iterator): return xrange(getStart(split), getStart(split + 1), step) return self.parallelize([], numSlices).mapPartitionsWithIndex(f) - # Calling the Java parallelize() method with an ArrayList is too slow, - # because it sends O(n) Py4J commands. As an alternative, serialized - # objects are written to a file and loaded through textFile(). 
+ + # Make sure we distribute data evenly if it's smaller than self.batchSize + if "__len__" not in dir(c): + c = list(c) # Make it a list so we can compute its length + batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) + jrdd = self._serialize_to_jvm(c, numSlices, serializer) + return RDD(jrdd, self, serializer) + + def _serialize_to_jvm(self, data, parallelism, serializer): + """ + Calling the Java parallelize() method with an ArrayList is too slow, + because it sends O(n) Py4J commands. As an alternative, serialized + objects are written to a file and loaded through textFile(). + """ tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) try: - # Make sure we distribute data evenly if it's smaller than self.batchSize - if "__len__" not in dir(c): - c = list(c) # Make it a list so we can compute its length - batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) - serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - serializer.dump_stream(c, tempFile) + serializer.dump_stream(data, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + return readRDDFromFile(self._jsc, tempFile.name, parallelism) finally: # readRDDFromFile eagerily reads the file so we can delete right after. os.unlink(tempFile.name) - return RDD(jrdd, self, serializer) def pickleFile(self, name, minPartitions=None): """ diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3c783ae541a1f..3e704fe9bf6ec 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -121,6 +121,7 @@ def killChild(): java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") # TODO(davies): move into sql java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.api.python.*") java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d7979f095da76..e0afdafbfcd62 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -214,6 +214,13 @@ def __repr__(self): def _create_batch(series): + """ + Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. + + :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :return: Arrow RecordBatch + """ + from pyspark.sql.types import _check_series_convert_timestamps_internal import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] 
@@ -229,7 +236,8 @@ def cast_series(s, t): # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 return _check_series_convert_timestamps_internal(s.fillna(0))\ .values.astype('datetime64[us]', copy=False) - elif t == pa.date32(): + # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1 + elif t is not None and t == pa.date32(): # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8 return s.dt.date elif t is None or s.dtype == t.to_pandas_dtype(): diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index d1d0b8b8fe5d9..589365b083012 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -25,7 +25,7 @@ basestring = unicode = str xrange = range else: - from itertools import imap as map + from itertools import izip as zip, imap as map from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix @@ -417,12 +417,12 @@ def _createFromLocal(self, data, schema): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema - def _get_numpy_record_dtypes(self, rec): + def _get_numpy_record_dtype(self, rec): """ Used when converting a pandas.DataFrame to Spark using to_records(), this will correct - the dtypes of records so they can be properly loaded into Spark. - :param rec: a numpy record to check dtypes - :return corrected dtypes for a numpy.record or None if no correction needed + the dtypes of fields in a record so they can be properly loaded into Spark. + :param rec: a numpy record to check field dtypes + :return corrected dtype for a numpy.record or None if no correction needed """ import numpy as np cur_dtypes = rec.dtype @@ -438,28 +438,70 @@ def _get_numpy_record_dtypes(self, rec): curr_type = 'datetime64[us]' has_rec_fix = True record_type_list.append((str(col_names[i]), curr_type)) - return record_type_list if has_rec_fix else None + return np.dtype(record_type_list) if has_rec_fix else None - def _convert_from_pandas(self, pdf, schema): + def _convert_from_pandas(self, pdf): """ Convert a pandas.DataFrame to list of records that can be used to make a DataFrame - :return tuple of list of records and schema + :return list of records """ - # If no schema supplied by user then get the names of columns only - if schema is None: - schema = [str(x) for x in pdf.columns] # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) # Check if any columns need to be fixed for Spark to infer properly if len(np_records) > 0: - record_type_list = self._get_numpy_record_dtypes(np_records[0]) - if record_type_list is not None: - return [r.astype(record_type_list).tolist() for r in np_records], schema + record_dtype = self._get_numpy_record_dtype(np_records[0]) + if record_dtype is not None: + return [r.astype(record_dtype).tolist() for r in np_records] # Convert list of numpy records to python lists - return [r.tolist() for r in np_records], schema + return [r.tolist() for r in np_records] + + def _create_from_pandas_with_arrow(self, pdf, schema): + """ + Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting + to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the + data types will be used to coerce the data in Pandas to Arrow conversion. 
+ """ + from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + + # Determine arrow types to coerce data when creating batches + if isinstance(schema, StructType): + arrow_types = [to_arrow_type(f.dataType) for f in schema.fields] + elif isinstance(schema, DataType): + raise ValueError("Single data type %s is not supported with Arrow" % str(schema)) + else: + # Any timestamps must be coerced to be compatible with Spark + arrow_types = [to_arrow_type(TimestampType()) + if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None + for t in pdf.dtypes] + + # Slice the DataFrame to be batched + step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up + pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) + + # Create Arrow record batches + batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]) + for pdf_slice in pdf_slices] + + # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) + if isinstance(schema, (list, tuple)): + struct = from_arrow_schema(batches[0].schema) + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + # Create the Spark DataFrame directly from the Arrow data and schema + jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) + jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( + jrdd, schema.json(), self._wrapped._jsqlContext) + df = DataFrame(jdf, self._wrapped) + df._schema = schema + return df @since(2.0) @ignore_unicode_prefix @@ -557,7 +599,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - data, schema = self._convert_from_pandas(data, schema) + + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = [str(x) for x in data.columns] + + if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ + and len(data) > 0: + try: + return self._create_from_pandas_with_arrow(data, schema) + except Exception as e: + warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) + # Fallback to create DataFrame without arrow if raise some exception + data = self._convert_from_pandas(data) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True @@ -576,7 +630,7 @@ def prepare(obj): verify_func(obj) return obj, else: - if isinstance(schema, list): + if isinstance(schema, (list, tuple)): schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] prepare = lambda obj: obj diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4819f629c5310..6356d938db26a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3127,9 +3127,9 @@ def setUpClass(cls): StructField("5_double_t", DoubleType(), True), StructField("6_date_t", DateType(), True), StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), 
datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3145,6 +3145,17 @@ def assertFramesEqual(self, df_with_arrow, df_without): ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + def create_pandas_data_frame(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + return pd.DataFrame(data=data_dict) + def test_unsupported_datatype(self): schema = StructType([StructField("decimal", DecimalType(), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) @@ -3161,21 +3172,15 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + try: + pdf = df.toPandas() + finally: + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) def test_pandas_round_trip(self): - import pandas as pd - import numpy as np - data_dict = {} - for j, name in enumerate(self.schema.names): - data_dict[name] = [self.data[i][j] for i in range(len(self.data))] - # need to convert these to numpy types first - data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) - data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - pdf = pd.DataFrame(data=data_dict) + pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) @@ -3187,6 +3192,62 @@ def test_filtered_frame(self): self.assertEqual(pdf.columns[0], "i") self.assertTrue(pdf.empty) + def test_createDataFrame_toggle(self): + pdf = self.create_pandas_data_frame() + self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") + try: + df_no_arrow = self.spark.createDataFrame(pdf) + finally: + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow = self.spark.createDataFrame(pdf) + self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) + + def test_createDataFrame_with_schema(self): + pdf = self.create_pandas_data_frame() + df = self.spark.createDataFrame(pdf, schema=self.schema) + self.assertEquals(self.schema, df.schema) + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_createDataFrame_with_incorrect_schema(self): + pdf = self.create_pandas_data_frame() + wrong_schema = StructType(list(reversed(self.schema))) + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + self.spark.createDataFrame(pdf, schema=wrong_schema) + + def test_createDataFrame_with_names(self): + pdf = self.create_pandas_data_frame() + # Test that schema as a list of column names gets applied + df = self.spark.createDataFrame(pdf, schema=list('abcdefg')) + self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + # Test that schema as tuple of column names gets applied + df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) + 
self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + + def test_createDataFrame_with_single_data_type(self): + import pandas as pd + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") + + def test_createDataFrame_does_not_modify_input(self): + # Some series get converted for Spark to consume, this makes sure input is unchanged + pdf = self.create_pandas_data_frame() + # Use a nanosecond value to make sure it is not truncated + pdf.ix[0, '7_timestamp_t'] = 1 + # Integers with nulls will get NaNs filled with 0 and will be casted + pdf.ix[1, '2_int_t'] = None + pdf_copy = pdf.copy(deep=True) + self.spark.createDataFrame(pdf, schema=self.schema) + self.assertTrue(pdf.equals(pdf_copy)) + + def test_schema_conversion_roundtrip(self): + from pyspark.sql.types import from_arrow_schema, to_arrow_schema + arrow_schema = to_arrow_schema(self.schema) + schema_rt = from_arrow_schema(arrow_schema) + self.assertEquals(self.schema, schema_rt) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 7dd8fa04160e0..fe62f60dd6d0e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1629,6 +1629,55 @@ def to_arrow_type(dt): return arrow_type +def to_arrow_schema(schema): + """ Convert a schema from Spark to Arrow + """ + import pyarrow as pa + fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + for field in schema] + return pa.schema(fields) + + +def from_arrow_type(at): + """ Convert pyarrow type to Spark data type. + """ + # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type + import pyarrow as pa + if at == pa.bool_(): + spark_type = BooleanType() + elif at == pa.int8(): + spark_type = ByteType() + elif at == pa.int16(): + spark_type = ShortType() + elif at == pa.int32(): + spark_type = IntegerType() + elif at == pa.int64(): + spark_type = LongType() + elif at == pa.float32(): + spark_type = FloatType() + elif at == pa.float64(): + spark_type = DoubleType() + elif type(at) == pa.DecimalType: + spark_type = DecimalType(precision=at.precision, scale=at.scale) + elif at == pa.string(): + spark_type = StringType() + elif at == pa.date32(): + spark_type = DateType() + elif type(at) == pa.TimestampType: + spark_type = TimestampType() + else: + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + return spark_type + + +def from_arrow_schema(arrow_schema): + """ Convert schema from Arrow to Spark. 
+ """ + return StructType( + [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) + for field in arrow_schema]) + + def _check_dataframe_localize_timestamps(pdf): """ Convert timezone aware timestamps to timezone-naive in local time diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 4d5ce0bb60c0b..b33760b1edbc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.api.python +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.DataType private[sql] object PythonSQLUtils { @@ -29,4 +32,19 @@ private[sql] object PythonSQLUtils { def listBuiltinFunctionInfos(): Array[ExpressionInfo] = { FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray } + + /** + * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * + * @param payloadRDD A JavaRDD of ArrowPayloads. + * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param sqlContext The active [[SQLContext]]. + * @return The converted [[DataFrame]]. + */ + def arrowPayloadToDataFrame( + payloadRDD: JavaRDD[Array[Byte]], + schemaString: String, + sqlContext: SQLContext): DataFrame = { + ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 05ea1517fcac9..3cafb344ef553 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -29,6 +29,8 @@ import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ @@ -204,4 +206,16 @@ private[sql] object ArrowConverters { reader.close() } } + + private[sql] def toDataFrame( + payloadRDD: JavaRDD[Array[Byte]], + schemaString: String, + sqlContext: SQLContext): DataFrame = { + val rdd = payloadRDD.rdd.mapPartitions { iter => + val context = TaskContext.get() + ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + } + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + sqlContext.internalCreateDataFrame(rdd, schema) + } } From 176ae4d53e0269cfc2cfa62d3a2991e28f5a9182 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 13 Nov 2017 06:19:13 -0600 Subject: [PATCH 236/712] [MINOR][CORE] Using bufferedInputStream for dataDeserializeStream ## What changes were proposed in this pull request? Small fix. Using bufferedInputStream for dataDeserializeStream. 
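For readers unfamiliar with the motivation, a minimal sketch of what buffering the stream handed to the deserializer buys (plain JDK classes here, not the actual `SerializerManager` internals):

```scala
import java.io.{BufferedInputStream, InputStream}

// A deserialization stream issues many small reads; a BufferedInputStream serves
// them from an in-memory buffer instead of hitting the underlying file or socket
// stream once per record.
def buffered(raw: InputStream, bufferSize: Int = 64 * 1024): InputStream =
  new BufferedInputStream(raw, bufferSize)
```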
## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19735 from ConeyLiu/smallfix. --- .../scala/org/apache/spark/serializer/SerializerManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 311383e7ea2bd..1d4b05caaa143 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -206,7 +206,7 @@ private[spark] class SerializerManager( val autoPick = !blockId.isInstanceOf[StreamBlockId] getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapForCompression(blockId, inputStream)) + .deserializeStream(wrapForCompression(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] } } From f7534b37ee91be14e511ab29259c3f83c7ad50af Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 13 Nov 2017 13:10:13 -0800 Subject: [PATCH 237/712] [SPARK-22487][SQL][FOLLOWUP] still keep spark.sql.hive.version ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/19712 , adds back the `spark.sql.hive.version`, so that if users try to read this config, they can still get a default value instead of null. ## How was this patch tested? N/A Author: Wenchen Fan Closes #19719 from cloud-fan/minor. --- .../sql/hive/thriftserver/SparkSQLEnv.scala | 1 + .../thriftserver/SparkSQLSessionManager.scala | 1 + .../HiveThriftServer2Suites.scala | 23 +++++++++++++++---- .../org/apache/spark/sql/hive/HiveUtils.scala | 8 +++++++ 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 5db93b26f550e..6b19f971b73bb 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -55,6 +55,7 @@ private[hive] object SparkSQLEnv extends Logging { metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + sparkSession.conf.set(HiveUtils.FAKE_HIVE_VERSION.key, HiveUtils.builtinHiveVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 00920c297d493..48c0ebef3e0ce 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -77,6 +77,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: } else { sqlContext.newSession() } + ctx.setConf(HiveUtils.FAKE_HIVE_VERSION.key, HiveUtils.builtinHiveVersion) if (sessionConf != null && sessionConf.containsKey("use:database")) { ctx.sql(s"use ${sessionConf.get("use:database")}") } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b80596f55bdea..7289da71a3365 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -155,9 +155,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(1) === "spark.sql.hive.version") assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } @@ -521,7 +521,20 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { conf += resultSet.getString(1) -> resultSet.getString(2) } - assert(conf.get("spark.sql.hive.metastore.version") === Some("1.2.1")) + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + } + } + + test("Checks Hive version via SET") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SET") + + val conf = mutable.Map.empty[String, String] + while (resultSet.next()) { + conf += resultSet.getString(1) -> resultSet.getString(2) + } + + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) } } @@ -708,9 +721,9 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { test("Checks Hive version") { withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.metastore.version") + val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === "spark.sql.hive.metastore.version") + assert(resultSet.getString(1) === "spark.sql.hive.version") assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index d8e08f1f6df50..f5e6720f6a510 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -66,6 +66,14 @@ private[spark] object HiveUtils extends Logging { .stringConf .createWithDefault(builtinHiveVersion) + // A fake config which is only here for backward compatibility reasons. This config has no effect + // to Spark, just for reporting the builtin Hive version of Spark to existing applications that + // already rely on this config. + val FAKE_HIVE_VERSION = buildConf("spark.sql.hive.version") + .doc(s"deprecated, please use ${HIVE_METASTORE_VERSION.key} to get the Hive version in Spark.") + .stringConf + .createWithDefault(builtinHiveVersion) + val HIVE_METASTORE_JARS = buildConf("spark.sql.hive.metastore.jars") .doc(s""" | Location of the jars that should be used to instantiate the HiveMetastoreClient. From c8b7f97b8a58bf4a9f6e3a07dd6e5b0f646d8d99 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 14 Nov 2017 08:28:13 +0900 Subject: [PATCH 238/712] [SPARK-22377][BUILD] Use /usr/sbin/lsof if lsof does not exists in release-build.sh ## What changes were proposed in this pull request? This PR proposes to use `/usr/sbin/lsof` if `lsof` is missing in the path to fix nightly snapshot jenkins jobs. 
Please refer https://github.com/apache/spark/pull/19359#issuecomment-340139557: > Looks like some of the snapshot builds are having lsof issues: > > https://amplab.cs.berkeley.edu/jenkins/view/Spark%20Packaging/job/spark-branch-2.1-maven-snapshots/182/console > >https://amplab.cs.berkeley.edu/jenkins/view/Spark%20Packaging/job/spark-branch-2.2-maven-snapshots/134/console > >spark-build/dev/create-release/release-build.sh: line 344: lsof: command not found >usage: kill [ -s signal | -p ] [ -a ] pid ... >kill -l [ signal ] Up to my knowledge, the full path of `lsof` is required for non-root user in few OSs. ## How was this patch tested? Manually tested as below: ```bash #!/usr/bin/env bash LSOF=lsof if ! hash $LSOF 2>/dev/null; then echo "a" LSOF=/usr/sbin/lsof fi $LSOF -P | grep "a" ``` Author: hyukjinkwon Closes #19695 from HyukjinKwon/SPARK-22377. --- dev/create-release/release-build.sh | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 7e8d5c7075195..5b43f9bab7505 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -130,6 +130,13 @@ else fi fi +# This is a band-aid fix to avoid the failure of Maven nightly snapshot in some Jenkins +# machines by explicitly calling /usr/sbin/lsof. Please see SPARK-22377 and the discussion +# in its pull request. +LSOF=lsof +if ! hash $LSOF 2>/dev/null; then + LSOF=/usr/sbin/lsof +fi if [ -z "$SPARK_PACKAGE_VERSION" ]; then SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" @@ -345,7 +352,7 @@ if [[ "$1" == "publish-snapshot" ]]; then # -DskipTests $SCALA_2_12_PROFILES $PUBLISH_PROFILES clean deploy # Clean-up Zinc nailgun process - lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + $LSOF -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill rm $tmp_settings cd .. @@ -382,7 +389,7 @@ if [[ "$1" == "publish-release" ]]; then # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install # Clean-up Zinc nailgun process - lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + $LSOF -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill #./dev/change-scala-version.sh 2.11 From d8741b2b0fe8b8da74f120859e969326fb170629 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 13 Nov 2017 17:00:51 -0800 Subject: [PATCH 239/712] [SPARK-21911][ML][FOLLOW-UP] Fix doc for parallel ML Tuning in PySpark ## What changes were proposed in this pull request? Fix doc issue mentioned here: https://github.com/apache/spark/pull/19122#issuecomment-340111834 ## How was this patch tested? N/A Author: WeichenXu Closes #19641 from WeichenXu123/fix_doc. --- docs/ml-tuning.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 64dc46cf0c0e7..54d9cd21909df 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -55,7 +55,7 @@ for multiclass problems. The default metric used to choose the best `ParamMap` c method in each of these evaluators. To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. -By default, sets of parameters from the parameter grid are evaluated in serial. 
Parameter evaluation can be done in parallel by setting `parallelism` with a value of 2 or more (a value of 1 will be serial) before running model selection with `CrossValidator` or `TrainValidationSplit` (NOTE: this is not yet supported in Python). +By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `parallelism` with a value of 2 or more (a value of 1 will be serial) before running model selection with `CrossValidator` or `TrainValidationSplit`. The value of `parallelism` should be chosen carefully to maximize parallelism without exceeding cluster resources, and larger values may not always lead to improved performance. Generally speaking, a value up to 10 should be sufficient for most clusters. # Cross-Validation From 673c67046598d33b9ecf864024ca7a937c1998d6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 14 Nov 2017 12:34:21 +0100 Subject: [PATCH 240/712] [SPARK-17310][SQL] Add an option to disable record-level filter in Parquet-side ## What changes were proposed in this pull request? There is a concern that Spark-side codegen row-by-row filtering might be faster than Parquet's one in general due to type-boxing and additional fuction calls which Spark's one tries to avoid. So, this PR adds an option to disable/enable record-by-record filtering in Parquet side. It sets the default to `false` to take the advantage of the improvement. This was also discussed in https://github.com/apache/spark/pull/14671. ## How was this patch tested? Manually benchmarks were performed. I generated a billion (1,000,000,000) records and tested equality comparison concatenated with `OR`. This filter combinations were made from 5 to 30. It seem indeed Spark-filtering is faster in the test case and the gap increased as the filter tree becomes larger. The details are as below: **Code** ``` scala test("Parquet-side filter vs Spark-side filter - record by record") { withTempPath { path => val N = 1000 * 1000 * 1000 val df = spark.range(N).toDF("a") df.write.parquet(path.getAbsolutePath) val benchmark = new Benchmark("Parquet-side vs Spark-side", N) Seq(5, 10, 20, 30).foreach { num => val filterExpr = (0 to num).map(i => s"a = $i").mkString(" OR ") benchmark.addCase(s"Parquet-side filter - number of filters [$num]", 3) { _ => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString, SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> true.toString) { // We should strip Spark-side filter to compare correctly. 
stripSparkFilter( spark.read.parquet(path.getAbsolutePath).filter(filterExpr)).count() } } benchmark.addCase(s"Spark-side filter - number of filters [$num]", 3) { _ => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString, SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> false.toString) { spark.read.parquet(path.getAbsolutePath).filter(filterExpr).count() } } } benchmark.run() } } ``` **Result** ``` Parquet-side vs Spark-side: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Parquet-side filter - number of filters [5] 4268 / 4367 234.3 4.3 0.8X Spark-side filter - number of filters [5] 3709 / 3741 269.6 3.7 0.9X Parquet-side filter - number of filters [10] 5673 / 5727 176.3 5.7 0.6X Spark-side filter - number of filters [10] 3588 / 3632 278.7 3.6 0.9X Parquet-side filter - number of filters [20] 8024 / 8440 124.6 8.0 0.4X Spark-side filter - number of filters [20] 3912 / 3946 255.6 3.9 0.8X Parquet-side filter - number of filters [30] 11936 / 12041 83.8 11.9 0.3X Spark-side filter - number of filters [30] 3929 / 3978 254.5 3.9 0.8X ``` Author: hyukjinkwon Closes #15049 from HyukjinKwon/SPARK-17310. --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../parquet/ParquetFileFormat.scala | 14 ++--- .../parquet/ParquetFilterSuite.scala | 51 ++++++++++++++++++- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 831ef62d74c3c..0cb58fab47ac7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -327,6 +327,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled") + .doc("If true, enables Parquet's native record-level filtering using the pushed down " + + "filters. This configuration only has an effect when 'spark.sql.parquet.filterPushdown' " + + "is enabled.") + .booleanConf + .createWithDefault(false) + val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. 
Typically, it's also a subclass " + @@ -1173,6 +1180,8 @@ class SQLConf extends Serializable with Logging { def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) + def parquetRecordFilterEnabled: Boolean = getConf(PARQUET_RECORD_FILTER_ENABLED) + def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index a48f8d517b6ab..044b1a89d57c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -335,6 +335,8 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + val enableRecordFilter: Boolean = + sparkSession.sessionState.conf.parquetRecordFilterEnabled // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -374,13 +376,11 @@ class ParquetFileFormat } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow - val reader = pushed match { - case Some(filter) => - new ParquetRecordReader[UnsafeRow]( - new ParquetReadSupport, - FilterCompat.get(filter, null)) - case _ => - new ParquetRecordReader[UnsafeRow](new ParquetReadSupport) + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[UnsafeRow](new ParquetReadSupport, parquetFilter) + } else { + new ParquetRecordReader[UnsafeRow](new ParquetReadSupport) } reader.initialize(split, hadoopAttemptContext) reader diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 90f6620d990cb..33801954ebd51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -45,8 +45,29 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} * * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. + * + * NOTE: + * + * This file intendedly enables record-level filtering explicitly. If new test cases are + * dependent on this configuration, don't forget you better explicitly set this configuration + * within the test. */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + + override def beforeEach(): Unit = { + super.beforeEach() + // Note that there are many tests here that require record-level filtering set to be true. 
+ spark.conf.set(SQLConf.PARQUET_RECORD_FILTER_ENABLED.key, "true") + } + + override def afterEach(): Unit = { + try { + spark.conf.unset(SQLConf.PARQUET_RECORD_FILTER_ENABLED.key) + } finally { + super.afterEach() + } + } + private def checkFilterPredicate( df: DataFrame, predicate: Predicate, @@ -369,7 +390,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex test("Filter applied on merged Parquet schema with new column should work") { import testImplicits._ - Seq("true", "false").map { vectorized => + Seq("true", "false").foreach { vectorized => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { @@ -491,7 +512,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("Fiters should be pushed down for vectorized Parquet reader at row group level") { + test("Filters should be pushed down for vectorized Parquet reader at row group level") { import testImplicits._ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", @@ -555,6 +576,32 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("Filters should be pushed down for Parquet readers at row group level") { + import testImplicits._ + + withSQLConf( + // Makes sure disabling 'spark.sql.parquet.recordFilter' still enables + // row group level filtering. + SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "false", + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withTempPath { path => + val data = (1 to 1024) + data.toDF("a").coalesce(1) + .write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath).filter("a == 500") + // Here, we strip the Spark side filter and check the actual results from Parquet. + val actual = stripSparkFilter(df).collect().length + // Since those are filtered at row group level, the result count should be less + // than the total length but should not be a single record. + // Note that, if record level filtering is enabled, it should be a single record. + // If no filter is pushed down to Parquet, it should be the total length of data. + assert(actual > 1 && actual < data.length) + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 11b60af737a04d931356aa74ebf3c6cf4a6b08d6 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 14 Nov 2017 16:41:43 +0100 Subject: [PATCH 241/712] [SPARK-17074][SQL] Generate equi-height histogram in column statistics ## What changes were proposed in this pull request? Equi-height histogram is effective in cardinality estimation, and more accurate than basic column stats (min, max, ndv, etc) especially in skew distribution. So we need to support it. For equi-height histogram, all buckets (intervals) have the same height (frequency). In this PR, we use a two-step method to generate an equi-height histogram: 1. use `ApproximatePercentile` to get percentiles `p(0), p(1/n), p(2/n) ... p((n-1)/n), p(1)`; 2. construct range values of buckets, e.g. `[p(0), p(1/n)], [p(1/n), p(2/n)] ... [p((n-1)/n), p(1)]`, and use `ApproxCountDistinctForIntervals` to count ndv in each bucket. Each bucket is of the form: `(lowerBound, higherBound, ndv)`. ## How was this patch tested? Added new test cases and modified some existing test cases. 
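To make the two-step construction above concrete, here is a minimal, self-contained Scala sketch of the idea. It is illustrative only: it computes exact percentile endpoints and exact per-bin distinct counts over an in-memory sequence, whereas the actual command runs the `ApproximatePercentile` and `ApproxCountDistinctForIntervals` aggregates over the table; the `Bin` case class and `buildHistogram` helper are hypothetical names, not Spark APIs.

```scala
object EquiHeightHistogramSketch {
  // Hypothetical stand-in for a histogram bucket: (lowerBound, higherBound, ndv).
  final case class Bin(lo: Double, hi: Double, ndv: Long)

  // Step 1: pick n + 1 endpoints p(0), p(1/n), ..., p(1) (exact percentiles here).
  // Step 2: count distinct values between consecutive endpoints to form the bins.
  def buildHistogram(values: Seq[Double], numBins: Int): (Double, Seq[Bin]) = {
    require(numBins > 1 && values.nonEmpty)
    val sorted = values.sorted
    val endpoints = (0 to numBins).map { i =>
      sorted(math.min((i.toLong * (sorted.length - 1) / numBins).toInt, sorted.length - 1))
    }
    val bins = endpoints.zip(endpoints.tail).map { case (lo, hi) =>
      Bin(lo, hi, sorted.filter(v => v >= lo && v <= hi).distinct.length.toLong)
    }
    // Every bin holds roughly the same number of rows, i.e. the histogram "height".
    (values.length.toDouble / numBins, bins)
  }

  def main(args: Array[String]): Unit = {
    val (height, bins) = buildHistogram((1 to 1000).map(_.toDouble), numBins = 4)
    println(s"height = $height")
    bins.foreach(println)
  }
}
```

In the patch itself, this path is exercised when `spark.sql.statistics.histogram.enabled` is set to `true` and `ANALYZE TABLE ... COMPUTE STATISTICS FOR COLUMNS ...` is run, as the new tests below show.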
Author: Zhenhua Wang Author: Zhenhua Wang Closes #19479 from wzhfy/generate_histogram. --- .../catalyst/plans/logical/Statistics.scala | 203 ++++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 34 ++- .../command/AnalyzeColumnCommand.scala | 57 +++- .../spark/sql/StatisticsCollectionSuite.scala | 15 +- .../sql/StatisticsCollectionTestBase.scala | 41 ++- .../spark/sql/hive/HiveExternalCatalog.scala | 11 +- .../spark/sql/hive/StatisticsSuite.scala | 255 ++++++++++++------ 7 files changed, 484 insertions(+), 132 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 5ae1a55a8b66f..96b199d7f20b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -17,16 +17,20 @@ package org.apache.spark.sql.catalyst.plans.logical +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.math.{MathContext, RoundingMode} import scala.util.control.NonFatal +import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -88,6 +92,7 @@ case class Statistics( * @param nullCount number of nulls * @param avgLen average length of the values. For fixed-length types, this should be a constant. * @param maxLen maximum length of the values. For fixed-length types, this should be a constant. + * @param histogram histogram of the values */ case class ColumnStat( distinctCount: BigInt, @@ -95,7 +100,8 @@ case class ColumnStat( max: Option[Any], nullCount: BigInt, avgLen: Long, - maxLen: Long) { + maxLen: Long, + histogram: Option[Histogram] = None) { // We currently don't store min/max for binary/string type. This can change in the future and // then we need to remove this require. @@ -121,6 +127,7 @@ case class ColumnStat( map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } + histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) } map.toMap } @@ -155,6 +162,7 @@ object ColumnStat extends Logging { private val KEY_NULL_COUNT = "nullCount" private val KEY_AVG_LEN = "avgLen" private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" /** Returns true iff the we support gathering column statistics on column of the given type. */ def supportsType(dataType: DataType): Boolean = dataType match { @@ -168,6 +176,16 @@ object ColumnStat extends Logging { case _ => false } + /** Returns true iff the we support gathering histogram on column of the given type. 
*/ + def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + /** * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. @@ -183,7 +201,8 @@ object ColumnStat extends Logging { .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), nullCount = BigInt(map(KEY_NULL_COUNT).toLong), avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, - maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong + maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong, + histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize) )) } catch { case NonFatal(e) => @@ -220,12 +239,16 @@ object ColumnStat extends Logging { * Constructs an expression to compute column statistics for a given column. * * The expression should create a single struct column with the following schema: - * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] * * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and * as a result should stay in sync with it. */ - def statExprs(col: Attribute, relativeSD: Double): CreateNamedStruct = { + def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } }) @@ -233,40 +256,55 @@ object ColumnStat extends Logging { // the approximate ndv (num distinct value) should never be larger than the number of rows val numNonNulls = if (col.nullable) Count(col) else Count(one) - val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls)) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) val numNulls = Subtract(Count(one), numNonNulls) val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) - def fixedLenTypeStruct(castType: DataType) = { + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } // For fixed width types, avg size should be the same as max size. 
- struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, defaultSize, - defaultSize) + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) } col.dataType match { - case dt: IntegralType => fixedLenTypeStruct(dt) - case _: DecimalType => fixedLenTypeStruct(col.dataType) - case dt @ (DoubleType | FloatType) => fixedLenTypeStruct(dt) - case BooleanType => fixedLenTypeStruct(col.dataType) - case DateType => fixedLenTypeStruct(col.dataType) - case TimestampType => fixedLenTypeStruct(col.dataType) + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct case BinaryType | StringType => - // For string and binary type, we don't store min/max. + // For string and binary type, we don't compute min, max or histogram val nullLit = Literal(null, col.dataType) struct( ndv, nullLit, nullLit, numNulls, // Set avg/max size to default size if all the values are null or there is no value. Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), - Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize))) + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) case _ => throw new AnalysisException("Analyzing column statistics is not supported for column " + - s"${col.name} of data type: ${col.dataType}.") + s"${col.name} of data type: ${col.dataType}.") } } - /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ - def rowToColumnStat(row: InternalRow, attr: Attribute): ColumnStat = { - ColumnStat( + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( distinctCount = BigInt(row.getLong(0)), // for string/binary min/max, get should return null min = Option(row.get(1, attr.dataType)), @@ -275,6 +313,129 @@ object ColumnStat extends Logging { avgLen = row.getLong(4), maxLen = row.getLong(5) ) + if (row.isNullAt(6)) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + +} + +/** + * This class is an implementation of equi-height histogram. + * Equi-height histogram represents the distribution of a column's values by a sequence of bins. + * Each bin has a value range and contains approximately the same number of rows. + * + * @param height number of rows in each bin + * @param bins equi-height histogram bins + */ +case class Histogram(height: Double, bins: Array[HistogramBin]) { + + // Only for histogram equality test. 
+ override def equals(other: Any): Boolean = other match { + case otherHgm: Histogram => + height == otherHgm.height && bins.sameElements(otherHgm.bins) + case _ => false + } + + override def hashCode(): Int = { + val temp = java.lang.Double.doubleToLongBits(height) + var result = (temp ^ (temp >>> 32)).toInt + result = 31 * result + java.util.Arrays.hashCode(bins.asInstanceOf[Array[AnyRef]]) + result } +} + +/** + * A bin in an equi-height histogram. We use double type for lower/higher bound for simplicity. + * + * @param lo lower bound of the value range in this bin + * @param hi higher bound of the value range in this bin + * @param ndv approximate number of distinct values in this bin + */ +case class HistogramBin(lo: Double, hi: Double, ndv: Long) +object HistogramSerializer { + /** + * Serializes a given histogram to a string. For advanced statistics like histograms, sketches, + * etc, we don't provide readability for their serialized formats in metastore + * (string-to-string table properties). This is because it's hard or unnatural for these + * statistics to be human readable. For example, a histogram usually cannot fit in a single, + * self-described property. And for count-min-sketch, it's essentially unnatural to make it + * a readable string. + */ + final def serialize(histogram: Histogram): String = { + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(new LZ4BlockOutputStream(bos)) + out.writeDouble(histogram.height) + out.writeInt(histogram.bins.length) + // Write data with same type together for compression. + var i = 0 + while (i < histogram.bins.length) { + out.writeDouble(histogram.bins(i).lo) + i += 1 + } + i = 0 + while (i < histogram.bins.length) { + out.writeDouble(histogram.bins(i).hi) + i += 1 + } + i = 0 + while (i < histogram.bins.length) { + out.writeLong(histogram.bins(i).ndv) + i += 1 + } + out.writeInt(-1) + out.flush() + out.close() + + org.apache.commons.codec.binary.Base64.encodeBase64String(bos.toByteArray) + } + + /** Deserializes a given string to a histogram. 
*/ + final def deserialize(str: String): Histogram = { + val bytes = org.apache.commons.codec.binary.Base64.decodeBase64(str) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(new LZ4BlockInputStream(bis)) + val height = ins.readDouble() + val numBins = ins.readInt() + + val los = new Array[Double](numBins) + var i = 0 + while (i < numBins) { + los(i) = ins.readDouble() + i += 1 + } + val his = new Array[Double](numBins) + i = 0 + while (i < numBins) { + his(i) = ins.readDouble() + i += 1 + } + val ndvs = new Array[Long](numBins) + i = 0 + while (i < numBins) { + ndvs(i) = ins.readLong() + i += 1 + } + ins.close() + + val bins = new Array[HistogramBin](numBins) + i = 0 + while (i < numBins) { + bins(i) = HistogramBin(los(i), his(i), ndvs(i)) + i += 1 + } + Histogram(height, bins) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0cb58fab47ac7..3452a1e715fb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -31,7 +31,6 @@ import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -835,6 +834,33 @@ object SQLConf { .doubleConf .createWithDefault(0.05) + val HISTOGRAM_ENABLED = + buildConf("spark.sql.statistics.histogram.enabled") + .doc("Generates histograms when computing column statistics if enabled. Histograms can " + + "provide better estimation accuracy. Currently, Spark only supports equi-height " + + "histogram. Note that collecting histograms takes extra cost. For example, collecting " + + "column statistics usually takes only one table scan, but generating equi-height " + + "histogram will cause an extra table scan.") + .booleanConf + .createWithDefault(false) + + val HISTOGRAM_NUM_BINS = + buildConf("spark.sql.statistics.histogram.numBins") + .internal() + .doc("The number of bins when generating histograms.") + .intConf + .checkValue(num => num > 1, "The number of bins must be large than 1.") + .createWithDefault(254) + + val PERCENTILE_ACCURACY = + buildConf("spark.sql.statistics.percentile.accuracy") + .internal() + .doc("Accuracy of percentile approximation when generating equi-height histograms. " + + "Larger value means better accuracy. The relative error can be deduced by " + + "1.0 / PERCENTILE_ACCURACY.") + .intConf + .createWithDefault(10000) + val AUTO_SIZE_UPDATE_ENABLED = buildConf("spark.sql.statistics.size.autoUpdate.enabled") .doc("Enables automatic update for table size once table's data is changed. 
Note that if " + @@ -1241,6 +1267,12 @@ class SQLConf extends Serializable with Logging { def ndvMaxError: Double = getConf(NDV_MAX_ERROR) + def histogramEnabled: Boolean = getConf(HISTOGRAM_ENABLED) + + def histogramNumBins: Int = getConf(HISTOGRAM_NUM_BINS) + + def percentileAccuracy: Int = getConf(PERCENTILE_ACCURACY) + def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index caf12ad745bb8..e3bb4d357b395 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.command +import scala.collection.mutable + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution @@ -68,11 +71,11 @@ case class AnalyzeColumnCommand( tableIdent: TableIdentifier, columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet - val resolver = sparkSession.sessionState.conf.resolver val attributesToAnalyze = columnNames.map { col => - val exprOption = relation.output.find(attr => resolver(attr.name, col)) + val exprOption = relation.output.find(attr => conf.resolver(attr.name, col)) exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) } @@ -86,12 +89,21 @@ case class AnalyzeColumnCommand( } // Collect statistics per column. + // If no histogram is required, we run a job to compute basic column stats such as + // min, max, ndv, etc. Otherwise, besides basic column stats, histogram will also be + // generated. Currently we only support equi-height histogram. + // To generate an equi-height histogram, we need two jobs: + // 1. compute percentiles p(0), p(1/n) ... p((n-1)/n), p(1). + // 2. use the percentiles as value intervals of bins, e.g. [p(0), p(1/n)], + // [p(1/n), p(2/n)], ..., [p((n-1)/n), p(1)], and then count ndv in each bin. + // Basic column stats will be computed together in the second job. + val attributePercentiles = computePercentiles(attributesToAnalyze, sparkSession, relation) + // The first element in the result will be the overall row count, the following elements // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. 
- val ndvMaxErr = sparkSession.sessionState.conf.ndvMaxError val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) + attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -99,9 +111,42 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 6 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 6), attr)) + // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. + (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr))) }.toMap (rowCount, columnStats) } + + /** Computes percentiles for each attribute. */ + private def computePercentiles( + attributesToAnalyze: Seq[Attribute], + sparkSession: SparkSession, + relation: LogicalPlan): AttributeMap[ArrayData] = { + val attrsToGenHistogram = if (conf.histogramEnabled) { + attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + } else { + Nil + } + val attributePercentiles = mutable.HashMap[Attribute, ArrayData]() + if (attrsToGenHistogram.nonEmpty) { + val percentiles = (0 to conf.histogramNumBins) + .map(i => i.toDouble / conf.histogramNumBins).toArray + + val namedExprs = attrsToGenHistogram.map { attr => + val aggFunc = + new ApproximatePercentile(attr, Literal(percentiles), Literal(conf.percentileAccuracy)) + val expr = aggFunc.toAggregateExpression() + Alias(expr, expr.toString)() + } + + val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation)) + .executedPlan.executeTake(1).head + attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) => + attributePercentiles += attr -> percentilesRow.getArray(i) + } + } + AttributeMap(attributePercentiles.toSeq) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 7247c3a876df3..fba5d2652d3f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -142,10 +142,12 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("column stats round trip serialization") { // Make sure we serialize and then deserialize and we will get the result data val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - stats.zip(df.schema).foreach { case ((k, v), field) => - withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) - assert(roundtrip == Some(v)) + Seq(stats, statsWithHgms).foreach { s => + s.zip(df.schema).foreach { case ((k, v), field) => + withClue(s"column $k with type ${field.dataType}") { + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + assert(roundtrip == Some(v)) + } } } } @@ -155,6 +157,11 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) checkColStats(df, stats) + + // test column stats with 
histograms + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> "true", SQLConf.HISTOGRAM_NUM_BINS.key -> "2") { + checkColStats(df, statsWithHgms) + } } test("column stats collection for null columns") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index a2f63edd786bf..f6df077ec5727 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -25,7 +25,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf @@ -46,6 +46,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils private val d2 = Date.valueOf("2016-05-09") private val t1 = Timestamp.valueOf("2016-05-08 00:00:01") private val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + private val d1Internal = DateTimeUtils.fromJavaDate(d1) + private val d2Internal = DateTimeUtils.fromJavaDate(d2) + private val t1Internal = DateTimeUtils.fromJavaTimestamp(t1) + private val t2Internal = DateTimeUtils.fromJavaTimestamp(t2) /** * Define a very simple 3 row table used for testing column serialization. @@ -73,12 +77,39 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), "cstring" -> ColumnStat(2, None, None, 1, 3, 3), "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), - Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), - Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) + "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) ) + /** + * A mapping from column to the stats collected including histograms. + * The number of bins in the histograms is 2. 
+ */ + protected val statsWithHgms = { + val colStats = mutable.LinkedHashMap(stats.toSeq: _*) + colStats.update("cbyte", stats("cbyte").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 2, 1)))))) + colStats.update("cshort", stats("cshort").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 3, 1)))))) + colStats.update("cint", stats("cint").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 4, 1)))))) + colStats.update("clong", stats("clong").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 5, 1)))))) + colStats.update("cdouble", stats("cdouble").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 6, 1)))))) + colStats.update("cfloat", stats("cfloat").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 7, 1)))))) + colStats.update("cdecimal", stats("cdecimal").copy(histogram = + Some(Histogram(1, Array(HistogramBin(1, 1, 1), HistogramBin(1, 8, 1)))))) + colStats.update("cdate", stats("cdate").copy(histogram = + Some(Histogram(1, Array(HistogramBin(d1Internal, d1Internal, 1), + HistogramBin(d1Internal, d2Internal, 1)))))) + colStats.update("ctimestamp", stats("ctimestamp").copy(histogram = + Some(Histogram(1, Array(HistogramBin(t1Internal, t1Internal, 1), + HistogramBin(t1Internal, t2Internal, 1)))))) + colStats + } + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 7cd772544a96a..44e680dbd2f93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1032,8 +1032,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat stats: CatalogStatistics, schema: StructType): Map[String, String] = { - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) + val statsProperties = new mutable.HashMap[String, String]() + statsProperties += STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString() if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } @@ -1046,7 +1046,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - statsProperties + statsProperties.toMap } private def statsFromProperties( @@ -1072,9 +1072,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => (k.drop(keyPrefix.length), v) } - - ColumnStat.fromMap(table, field, colStatMap).foreach { - colStat => colStats += field.name -> colStat + ColumnStat.fromMap(table, field, colStatMap).foreach { cs => + colStats += field.name -> cs } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9e8fc32a05471..7427948fe138b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.{File, PrintWriter} +import java.sql.Timestamp import scala.reflect.ClassTag import 
scala.util.matching.Regex @@ -28,8 +29,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ @@ -963,98 +964,174 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + 
"spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + + def checkColStatsProps(expected: Map[String, String]): Unit = { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + val table = hiveClient.getTable("default", 
tableName) + val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) + assert(props == expected) + } + withTable(tableName) { df.write.saveAsTable(tableName) - // Collect statistics - sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + // Collect and validate statistics + checkColStatsProps(expectedSerializedColStats) - // Validate statistics - val table = hiveClient.getTable("default", tableName) + withSQLConf( + SQLConf.HISTOGRAM_ENABLED.key -> "true", SQLConf.HISTOGRAM_NUM_BINS.key -> "2") { - val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) - assert(props == Map( - "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", - "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", - "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", - "spark.sql.statistics.colStats.cbinary.version" -> "1", - "spark.sql.statistics.colStats.cbool.avgLen" -> "1", - "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbool.max" -> "true", - "spark.sql.statistics.colStats.cbool.maxLen" -> "1", - "spark.sql.statistics.colStats.cbool.min" -> "false", - "spark.sql.statistics.colStats.cbool.nullCount" -> "1", - "spark.sql.statistics.colStats.cbool.version" -> "1", - "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", - "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbyte.max" -> "2", - "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", - "spark.sql.statistics.colStats.cbyte.min" -> "1", - "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", - "spark.sql.statistics.colStats.cbyte.version" -> "1", - "spark.sql.statistics.colStats.cdate.avgLen" -> "4", - "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", - "spark.sql.statistics.colStats.cdate.maxLen" -> "4", - "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", - "spark.sql.statistics.colStats.cdate.nullCount" -> "1", - "spark.sql.statistics.colStats.cdate.version" -> "1", - "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", - "spark.sql.statistics.colStats.cdecimal.version" -> "1", - "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", - "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdouble.max" -> "6.0", - "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", - "spark.sql.statistics.colStats.cdouble.min" -> "1.0", - "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", - "spark.sql.statistics.colStats.cdouble.version" -> "1", - "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", - "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", - "spark.sql.statistics.colStats.cfloat.max" -> "7.0", - "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", - "spark.sql.statistics.colStats.cfloat.min" -> "1.0", - "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", - "spark.sql.statistics.colStats.cfloat.version" -> "1", - "spark.sql.statistics.colStats.cint.avgLen" -> "4", - "spark.sql.statistics.colStats.cint.distinctCount" -> "2", - 
"spark.sql.statistics.colStats.cint.max" -> "4", - "spark.sql.statistics.colStats.cint.maxLen" -> "4", - "spark.sql.statistics.colStats.cint.min" -> "1", - "spark.sql.statistics.colStats.cint.nullCount" -> "1", - "spark.sql.statistics.colStats.cint.version" -> "1", - "spark.sql.statistics.colStats.clong.avgLen" -> "8", - "spark.sql.statistics.colStats.clong.distinctCount" -> "2", - "spark.sql.statistics.colStats.clong.max" -> "5", - "spark.sql.statistics.colStats.clong.maxLen" -> "8", - "spark.sql.statistics.colStats.clong.min" -> "1", - "spark.sql.statistics.colStats.clong.nullCount" -> "1", - "spark.sql.statistics.colStats.clong.version" -> "1", - "spark.sql.statistics.colStats.cshort.avgLen" -> "2", - "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", - "spark.sql.statistics.colStats.cshort.max" -> "3", - "spark.sql.statistics.colStats.cshort.maxLen" -> "2", - "spark.sql.statistics.colStats.cshort.min" -> "1", - "spark.sql.statistics.colStats.cshort.nullCount" -> "1", - "spark.sql.statistics.colStats.cshort.version" -> "1", - "spark.sql.statistics.colStats.cstring.avgLen" -> "3", - "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", - "spark.sql.statistics.colStats.cstring.maxLen" -> "3", - "spark.sql.statistics.colStats.cstring.nullCount" -> "1", - "spark.sql.statistics.colStats.cstring.version" -> "1", - "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", - "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", - "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", - "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", - "spark.sql.statistics.colStats.ctimestamp.version" -> "1" - )) + checkColStatsProps(expectedSerializedColStats ++ expectedSerializedHistograms) + } + } + } + + test("serialization and deserialization of histograms to/from hive metastore") { + import testImplicits._ + + def checkBinsOrder(bins: Array[HistogramBin]): Unit = { + for (i <- bins.indices) { + val b = bins(i) + assert(b.lo <= b.hi) + if (i > 0) { + val pre = bins(i - 1) + assert(pre.hi <= b.lo) + } + } + } + + val startTimestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) + val df = (1 to 5000) + .map(i => (i, DateTimeUtils.toJavaTimestamp(startTimestamp + i))) + .toDF("cint", "ctimestamp") + val tableName = "histogram_serde_test" + + withTable(tableName) { + df.write.saveAsTable(tableName) + + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> "true") { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS cint, ctimestamp") + val table = hiveClient.getTable("default", tableName) + val intHistogramProps = table.properties + .filterKeys(_.startsWith("spark.sql.statistics.colStats.cint.histogram")) + assert(intHistogramProps.size == 1) + + val tsHistogramProps = table.properties + .filterKeys(_.startsWith("spark.sql.statistics.colStats.ctimestamp.histogram")) + assert(tsHistogramProps.size == 1) + + // Validate histogram after deserialization. 
+ val cs = getCatalogStatistics(tableName).colStats + val intHistogram = cs("cint").histogram.get + val tsHistogram = cs("ctimestamp").histogram.get + assert(intHistogram.bins.length == spark.sessionState.conf.histogramNumBins) + checkBinsOrder(intHistogram.bins) + assert(tsHistogram.bins.length == spark.sessionState.conf.histogramNumBins) + checkBinsOrder(tsHistogram.bins) + } } } From 4741c07809393ab85be8b4a169d4ed3da93a4781 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 14 Nov 2017 10:34:32 -0600 Subject: [PATCH 242/712] [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend. This change is a little larger because there's a whole lot of logic behind these pages, all really tied to internal types and listeners, and some of that logic had to be implemented in the new listener and the needed data exposed through the API types. - Added missing StageData and ExecutorStageSummary fields which are used by the UI. Some json golden files needed to be updated to account for new fields. - Save RDD graph data in the store. This tries to re-use existing types as much as possible, so that the code doesn't need to be re-written. So it's probably not very optimal. - Some old classes (e.g. JobProgressListener) still remain, since they're used in other parts of the code; they're not used by the UI anymore, though, and will be cleaned up in a separate change. - Save information about active pools in the store. This data is not really used in the SHS, but it's not a lot of data so it's still recorded when replaying applications. - Because the new store sorts things slightly differently from the previous code, some json golden files had some elements within them shuffled around. - The retention unit test in UISeleniumSuite was disabled because the code to throw away old stages / tasks hasn't been added yet. - The job description field in the API tries to follow the old behavior, which makes it be empty most of the time, even though there's information to fill it in. For stages, a new field was added to hold the description (which is basically the job description), so that the UI can be rendered in the old way. - A new stage status ("SKIPPED") was added to account for the fact that the API couldn't represent that state before. Without this, the stage would show up as "PENDING" in the UI, which is now based on API types. - The API used to expose "executorRunTime" as the value of the task's duration, which wasn't really correct (also because that value was easily available from the metrics object); this change fixes that by storing the correct duration, which also means a few expectation files needed to be updated to account for the new durations and sorting differences due to the changed values. - Added changes to implement SPARK-20713 and SPARK-21922 in the new code. Tested with existing unit tests (and by using the UI a lot). Author: Marcelo Vanzin Closes #19698 from vanzin/SPARK-20648. 
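One small behavioral point worth spelling out is the `lastUpdateTime` handling referenced above (SPARK-21922): a finished task gets a real duration, while a task that was still running when the event log stopped being updated gets a duration bounded by the log's last update time. The snippet below is only a hedged sketch of that rule under those assumed semantics; `taskDuration` is a hypothetical helper, not the listener's actual method.

```scala
// Sketch of the duration rule described above (assumed semantics, not the real API):
// a finished task uses its own finish time; an unfinished task falls back to the
// event log's last update time when that is known.
def taskDuration(launchTime: Long, finishTime: Long, lastUpdateTime: Option[Long]): Option[Long] = {
  if (finishTime > 0) {
    Some(finishTime - launchTime)
  } else {
    lastUpdateTime.map(last => math.max(last - launchTime, 0L))
  }
}
```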
--- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../deploy/history/FsHistoryProvider.scala | 10 +- .../spark/status/AppStatusListener.scala | 121 +- .../apache/spark/status/AppStatusStore.scala | 135 ++- .../org/apache/spark/status/LiveEntity.scala | 57 +- .../spark/status/api/v1/AllJobsResource.scala | 70 +- .../status/api/v1/AllStagesResource.scala | 290 +---- .../spark/status/api/v1/OneJobResource.scala | 15 +- .../status/api/v1/OneStageResource.scala | 112 +- .../org/apache/spark/status/api/v1/api.scala | 21 +- .../org/apache/spark/status/config.scala | 4 + .../org/apache/spark/status/storeTypes.scala | 40 + .../scala/org/apache/spark/ui/SparkUI.scala | 30 +- .../apache/spark/ui/jobs/AllJobsPage.scala | 286 +++-- .../apache/spark/ui/jobs/AllStagesPage.scala | 189 +-- .../apache/spark/ui/jobs/ExecutorTable.scala | 158 +-- .../org/apache/spark/ui/jobs/JobPage.scala | 326 ++--- .../org/apache/spark/ui/jobs/JobsTab.scala | 38 +- .../org/apache/spark/ui/jobs/PoolPage.scala | 57 +- .../org/apache/spark/ui/jobs/PoolTable.scala | 34 +- .../org/apache/spark/ui/jobs/StagePage.scala | 1064 ++++++++--------- .../org/apache/spark/ui/jobs/StageTable.scala | 100 +- .../org/apache/spark/ui/jobs/StagesTab.scala | 28 +- .../spark/ui/scope/RDDOperationGraph.scala | 10 +- .../ui/scope/RDDOperationGraphListener.scala | 150 --- .../complete_stage_list_json_expectation.json | 21 +- .../failed_stage_list_json_expectation.json | 8 +- ...multi_attempt_app_json_1__expectation.json | 5 +- ...multi_attempt_app_json_2__expectation.json | 5 +- .../job_list_json_expectation.json | 15 +- .../one_job_json_expectation.json | 5 +- .../one_stage_attempt_json_expectation.json | 126 +- .../one_stage_json_expectation.json | 126 +- .../stage_list_json_expectation.json | 79 +- ...ist_with_accumulable_json_expectation.json | 7 +- .../stage_task_list_expectation.json | 60 +- ...multi_attempt_app_json_1__expectation.json | 24 +- ...multi_attempt_app_json_2__expectation.json | 24 +- ...k_list_w__offset___length_expectation.json | 150 ++- ...stage_task_list_w__sortBy_expectation.json | 190 +-- ...tBy_short_names___runtime_expectation.json | 190 +-- ...rtBy_short_names__runtime_expectation.json | 60 +- ...age_with_accumulable_json_expectation.json | 36 +- ...eded_failed_job_list_json_expectation.json | 15 +- .../succeeded_job_list_json_expectation.json | 10 +- .../deploy/history/HistoryServerSuite.scala | 13 +- .../spark/status/AppStatusListenerSuite.scala | 80 +- .../api/v1/AllStagesResourceSuite.scala | 62 - .../org/apache/spark/ui/StagePageSuite.scala | 64 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 60 +- .../RDDOperationGraphListenerSuite.scala | 226 ---- project/MimaExcludes.scala | 2 + 52 files changed, 2359 insertions(+), 2657 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala delete mode 100644 core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e5aaaf6c155eb..1d325e651b1d9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -426,7 +426,7 @@ class SparkContext(config: SparkConf) extends Logging { // Initialize the app status store and listener before SparkEnv is created so that it gets // all events. 
- _statusStore = AppStatusStore.createLiveStore(conf, listenerBus) + _statusStore = AppStatusStore.createLiveStore(conf, l => listenerBus.addToStatusQueue(l)) // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) @@ -449,11 +449,7 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.create(Some(this), _statusStore, _conf, - l => listenerBus.addToStatusQueue(l), - _env.securityManager, - appName, - "", + Some(SparkUI.create(Some(this), _statusStore, _conf, _env.securityManager, appName, "", startTime)) } else { // For tests, do not enable the UI diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f16dddea9f784..a6dc53321d650 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -316,7 +316,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val listener = if (needReplay) { - val _listener = new AppStatusListener(kvstore, conf, false) + val _listener = new AppStatusListener(kvstore, conf, false, + lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) replayBus.addListener(_listener) Some(_listener) } else { @@ -324,13 +325,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val loadedUI = { - val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, - l => replayBus.addListener(l), - secManager, - app.info.name, + val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, secManager, app.info.name, HistoryServer.getAttemptURI(appId, attempt.info.attemptId), attempt.info.startTime.getTime(), - appSparkVersion = attempt.info.appSparkVersion) + attempt.info.appSparkVersion) LoadedAppUI(ui) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 7f2c00c09d43d..f2d8e0a5480ba 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -28,16 +28,21 @@ import org.apache.spark.scheduler._ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.scope._ import org.apache.spark.util.kvstore.KVStore /** * A Spark listener that writes application information to a data store. The types written to the * store are defined in the `storeTypes.scala` file and are based on the public REST API. + * + * @param lastUpdateTime When replaying logs, the log's last update time, so that the duration of + * unfinished tasks can be more accurately calculated (see SPARK-21922). */ private[spark] class AppStatusListener( kvstore: KVStore, conf: SparkConf, - live: Boolean) extends SparkListener with Logging { + live: Boolean, + lastUpdateTime: Option[Long] = None) extends SparkListener with Logging { import config._ @@ -50,6 +55,8 @@ private[spark] class AppStatusListener( // operations that we can live without when rapidly processing incoming task events. 
private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + private val maxGraphRootNodes = conf.get(MAX_RETAINED_ROOT_NODES) + // Keep track of live entities, so that task metrics can be efficiently updated (without // causing too many writes to the underlying store, and other expensive operations). private val liveStages = new HashMap[(Int, Int), LiveStage]() @@ -57,6 +64,7 @@ private[spark] class AppStatusListener( private val liveExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() + private val pools = new HashMap[String, SchedulerPool]() override def onOtherEvent(event: SparkListenerEvent): Unit = event match { case SparkListenerLogStart(version) => sparkVersion = version @@ -210,16 +218,15 @@ private[spark] class AppStatusListener( missingStages.map(_.numTasks).sum } - val lastStageInfo = event.stageInfos.lastOption + val lastStageInfo = event.stageInfos.sortBy(_.stageId).lastOption val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val jobGroup = Option(event.properties) .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } val job = new LiveJob( event.jobId, lastStageName, - Some(new Date(event.time)), + if (event.time > 0) Some(new Date(event.time)) else None, event.stageIds, jobGroup, numTasks) @@ -234,17 +241,51 @@ private[spark] class AppStatusListener( stage.jobIds += event.jobId liveUpdate(stage, now) } + + // Create the graph data for all the job's stages. + event.stageInfos.foreach { stage => + val graph = RDDOperationGraph.makeOperationGraph(stage, maxGraphRootNodes) + val uigraph = new RDDOperationGraphWrapper( + stage.stageId, + graph.edges, + graph.outgoingEdges, + graph.incomingEdges, + newRDDOperationCluster(graph.rootCluster)) + kvstore.write(uigraph) + } + } + + private def newRDDOperationCluster(cluster: RDDOperationCluster): RDDOperationClusterWrapper = { + new RDDOperationClusterWrapper( + cluster.id, + cluster.name, + cluster.childNodes, + cluster.childClusters.map(newRDDOperationCluster)) } override def onJobEnd(event: SparkListenerJobEnd): Unit = { liveJobs.remove(event.jobId).foreach { job => + val now = System.nanoTime() + + // Check if there are any pending stages that match this job; mark those as skipped. 
+ job.stageIds.foreach { sid => + val pending = liveStages.filter { case ((id, _), _) => id == sid } + pending.foreach { case (key, stage) => + stage.status = v1.StageStatus.SKIPPED + job.skippedStages += stage.info.stageId + job.skippedTasks += stage.info.numTasks + liveStages.remove(key) + update(stage, now) + } + } + job.status = event.jobResult match { case JobSucceeded => JobExecutionStatus.SUCCEEDED case JobFailed(_) => JobExecutionStatus.FAILED } - job.completionTime = Some(new Date(event.time)) - update(job, System.nanoTime()) + job.completionTime = if (event.time > 0) Some(new Date(event.time)) else None + update(job, now) } } @@ -262,12 +303,24 @@ private[spark] class AppStatusListener( .toSeq stage.jobIds = stage.jobs.map(_.jobId).toSet + stage.schedulingPool = Option(event.properties).flatMap { p => + Option(p.getProperty("spark.scheduler.pool")) + }.getOrElse(SparkUI.DEFAULT_POOL_NAME) + + stage.description = Option(event.properties).flatMap { p => + Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + stage.jobs.foreach { job => job.completedStages = job.completedStages - event.stageInfo.stageId job.activeStages += 1 liveUpdate(job, now) } + val pool = pools.getOrElseUpdate(stage.schedulingPool, new SchedulerPool(stage.schedulingPool)) + pool.stageIds = pool.stageIds + event.stageInfo.stageId + update(pool, now) + event.stageInfo.rddInfos.foreach { info => if (info.storageLevel.isValid) { liveUpdate(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info)), now) @@ -279,7 +332,7 @@ private[spark] class AppStatusListener( override def onTaskStart(event: SparkListenerTaskStart): Unit = { val now = System.nanoTime() - val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) + val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId, lastUpdateTime) liveTasks.put(event.taskInfo.taskId, task) liveUpdate(task, now) @@ -318,6 +371,8 @@ private[spark] class AppStatusListener( val now = System.nanoTime() val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => + task.info = event.taskInfo + val errorMessage = event.reason match { case Success => None @@ -337,11 +392,15 @@ private[spark] class AppStatusListener( delta }.orNull - val (completedDelta, failedDelta) = event.reason match { + val (completedDelta, failedDelta, killedDelta) = event.reason match { case Success => - (1, 0) + (1, 0, 0) + case _: TaskKilled => + (0, 0, 1) + case _: TaskCommitDenied => + (0, 0, 1) case _ => - (0, 1) + (0, 1, 0) } liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => @@ -350,13 +409,29 @@ private[spark] class AppStatusListener( } stage.activeTasks -= 1 stage.completedTasks += completedDelta + if (completedDelta > 0) { + stage.completedIndices.add(event.taskInfo.index) + } stage.failedTasks += failedDelta + stage.killedTasks += killedDelta + if (killedDelta > 0) { + stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) + } maybeUpdate(stage, now) + // Store both stage ID and task index in a single long variable for tracking at job level. 
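As an aside, the hunk that follows packs the stage ID into the high 32 bits of a Long and the task index into the low 32 bits, so the job-level OpenHashSet[Long] can count distinct completed (stage, index) pairs without allocating tuples. A sketch of the packing together with the inverse helpers, which are illustrative and do not appear in the patch:

    object TaskIndexPackingSketch {
      // High 32 bits: stage ID; low 32 bits: task index within the stage.
      def pack(stageId: Int, taskIndex: Int): Long =
        (stageId.toLong << Integer.SIZE) | (taskIndex & 0xFFFFFFFFL)

      def stageIdOf(packed: Long): Int = (packed >>> Integer.SIZE).toInt
      def taskIndexOf(packed: Long): Int = packed.toInt

      // e.g. pack(3, 7) == 0x0000000300000007L; stageIdOf/taskIndexOf recover 3 and 7.
    }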
+ val taskIndex = (event.stageId.toLong << Integer.SIZE) | event.taskInfo.index stage.jobs.foreach { job => job.activeTasks -= 1 job.completedTasks += completedDelta + if (completedDelta > 0) { + job.completedIndices.add(taskIndex) + } job.failedTasks += failedDelta + job.killedTasks += killedDelta + if (killedDelta > 0) { + job.killedSummary = killedTasksSummary(event.reason, job.killedSummary) + } maybeUpdate(job, now) } @@ -364,6 +439,7 @@ private[spark] class AppStatusListener( esummary.taskTime += event.taskInfo.duration esummary.succeededTasks += completedDelta esummary.failedTasks += failedDelta + esummary.killedTasks += killedDelta if (metricsDelta != null) { esummary.metrics.update(metricsDelta) } @@ -422,6 +498,11 @@ private[spark] class AppStatusListener( liveUpdate(job, now) } + pools.get(stage.schedulingPool).foreach { pool => + pool.stageIds = pool.stageIds - event.stageInfo.stageId + update(pool, now) + } + stage.executorSummaries.values.foreach(update(_, now)) update(stage, now) } @@ -482,11 +563,15 @@ private[spark] class AppStatusListener( /** Flush all live entities' data to the underlying store. */ def flush(): Unit = { val now = System.nanoTime() - liveStages.values.foreach(update(_, now)) + liveStages.values.foreach { stage => + update(stage, now) + stage.executorSummaries.values.foreach(update(_, now)) + } liveJobs.values.foreach(update(_, now)) liveExecutors.values.foreach(update(_, now)) liveTasks.values.foreach(update(_, now)) liveRDDs.values.foreach(update(_, now)) + pools.values.foreach(update(_, now)) } private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { @@ -628,6 +713,20 @@ private[spark] class AppStatusListener( stage } + private def killedTasksSummary( + reason: TaskEndReason, + oldSummary: Map[String, Int]): Map[String, Int] = { + reason match { + case k: TaskKilled => + oldSummary.updated(k.reason, oldSummary.getOrElse(k.reason, 0) + 1) + case denied: TaskCommitDenied => + val reason = denied.toErrorString + oldSummary.updated(reason, oldSummary.getOrElse(reason, 0) + 1) + case _ => + oldSummary + } + } + private def update(entity: LiveEntity, now: Long): Unit = { entity.write(kvstore, now) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 80c8d7d11a3c2..9b42f55605755 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -23,8 +23,9 @@ import java.util.{Arrays, List => JList} import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.scheduler.SparkListener import org.apache.spark.status.api.v1 +import org.apache.spark.ui.scope._ import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} @@ -43,8 +44,8 @@ private[spark] class AppStatusStore(store: KVStore) { } def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { - val it = store.view(classOf[JobDataWrapper]).asScala.map(_.info) - if (!statuses.isEmpty()) { + val it = store.view(classOf[JobDataWrapper]).reverse().asScala.map(_.info) + if (statuses != null && !statuses.isEmpty()) { it.filter { job => statuses.contains(job.status) }.toSeq } else { it.toSeq @@ -65,31 +66,40 @@ private[spark] class AppStatusStore(store: KVStore) { filtered.asScala.map(_.info).toSeq } - def 
executorSummary(executorId: String): Option[v1.ExecutorSummary] = { - try { - Some(store.read(classOf[ExecutorSummaryWrapper], executorId).info) - } catch { - case _: NoSuchElementException => - None - } + def executorSummary(executorId: String): v1.ExecutorSummary = { + store.read(classOf[ExecutorSummaryWrapper], executorId).info } def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { - val it = store.view(classOf[StageDataWrapper]).asScala.map(_.info) - if (!statuses.isEmpty) { + val it = store.view(classOf[StageDataWrapper]).reverse().asScala.map(_.info) + if (statuses != null && !statuses.isEmpty()) { it.filter { s => statuses.contains(s.status) }.toSeq } else { it.toSeq } } - def stageData(stageId: Int): Seq[v1.StageData] = { + def stageData(stageId: Int, details: Boolean = false): Seq[v1.StageData] = { store.view(classOf[StageDataWrapper]).index("stageId").first(stageId).last(stageId) - .asScala.map(_.info).toSeq + .asScala.map { s => + if (details) stageWithDetails(s.info) else s.info + }.toSeq + } + + def lastStageAttempt(stageId: Int): v1.StageData = { + val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) + .closeableIterator() + try { + it.next().info + } finally { + it.close() + } } - def stageAttempt(stageId: Int, stageAttemptId: Int): v1.StageData = { - store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).info + def stageAttempt(stageId: Int, stageAttemptId: Int, details: Boolean = false): v1.StageData = { + val stageKey = Array(stageId, stageAttemptId) + val stage = store.read(classOf[StageDataWrapper], stageKey).info + if (details) stageWithDetails(stage) else stage } def taskSummary( @@ -189,6 +199,12 @@ private[spark] class AppStatusStore(store: KVStore) { ) } + def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { + val stageKey = Array(stageId, stageAttemptId) + store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() + .max(maxTasks).asScala.map(_.info).toSeq.reverse + } + def taskList( stageId: Int, stageAttemptId: Int, @@ -215,6 +231,66 @@ private[spark] class AppStatusStore(store: KVStore) { }.toSeq } + /** + * Calls a closure that may throw a NoSuchElementException and returns `None` when the exception + * is thrown. 
+ */ + def asOption[T](fn: => T): Option[T] = { + try { + Some(fn) + } catch { + case _: NoSuchElementException => None + } + } + + private def stageWithDetails(stage: v1.StageData): v1.StageData = { + val tasks = taskList(stage.stageId, stage.attemptId, Int.MaxValue) + .map { t => (t.taskId, t) } + .toMap + + val stageKey = Array(stage.stageId, stage.attemptId) + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey) + .last(stageKey).closeableIterator().asScala + .map { exec => (exec.executorId -> exec.info) } + .toMap + + new v1.StageData( + stage.status, + stage.stageId, + stage.attemptId, + stage.numTasks, + stage.numActiveTasks, + stage.numCompleteTasks, + stage.numFailedTasks, + stage.numKilledTasks, + stage.numCompletedIndices, + stage.executorRunTime, + stage.executorCpuTime, + stage.submissionTime, + stage.firstTaskLaunchedTime, + stage.completionTime, + stage.failureReason, + stage.inputBytes, + stage.inputRecords, + stage.outputBytes, + stage.outputRecords, + stage.shuffleReadBytes, + stage.shuffleReadRecords, + stage.shuffleWriteBytes, + stage.shuffleWriteRecords, + stage.memoryBytesSpilled, + stage.diskBytesSpilled, + stage.name, + stage.description, + stage.details, + stage.schedulingPool, + stage.rddIds, + stage.accumulatorUpdates, + Some(tasks), + Some(execs), + stage.killedTasksSummary) + } + def rdd(rddId: Int): v1.RDDStorageInfo = { store.read(classOf[RDDStorageInfoWrapper], rddId).info } @@ -223,6 +299,27 @@ private[spark] class AppStatusStore(store: KVStore) { store.view(classOf[StreamBlockData]).asScala.toSeq } + def operationGraphForStage(stageId: Int): RDDOperationGraph = { + store.read(classOf[RDDOperationGraphWrapper], stageId).toRDDOperationGraph() + } + + def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { + val job = store.read(classOf[JobDataWrapper], jobId) + val stages = job.info.stageIds + + stages.map { id => + val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() + if (job.skippedStages.contains(id) && !g.rootCluster.name.contains("skipped")) { + g.rootCluster.setName(g.rootCluster.name + " (skipped)") + } + g + } + } + + def pool(name: String): PoolData = { + store.read(classOf[PoolData], name) + } + def close(): Unit = { store.close() } @@ -237,12 +334,12 @@ private[spark] object AppStatusStore { * Create an in-memory store for a live application. * * @param conf Configuration. - * @param bus Where to attach the listener to populate the store. + * @param addListenerFn Function to register a listener with a bus. 
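As an aside, the new @param above captures the design change: createLiveStore now takes a plain SparkListener => Unit registration function instead of a LiveListenerBus, which is what lets SparkContext route the listener to its status queue while tests or other embeddings can simply capture it. The new signature follows below; here is a small usage sketch, assuming code living under the org.apache.spark namespace since these classes are private[spark]:

    package org.apache.spark.status

    import org.apache.spark.SparkConf
    import org.apache.spark.scheduler.SparkListener

    object CreateLiveStoreSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(false)
        // Capture the listener instead of wiring up a full LiveListenerBus.
        var captured: Option[SparkListener] = None
        val store = AppStatusStore.createLiveStore(conf, l => captured = Some(l))
        // `captured` now holds the AppStatusListener; posting events to it populates `store`.
        println(captured.isDefined)
      }
    }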
*/ - def createLiveStore(conf: SparkConf, bus: LiveListenerBus): AppStatusStore = { + def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { val store = new InMemoryStore() val stateStore = new AppStatusStore(store) - bus.addToStatusQueue(new AppStatusListener(store, conf, true)) + addListenerFn(new AppStatusListener(store, conf, true)) stateStore } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 706d94c3a59b9..ef2936c9b69a4 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -28,6 +28,7 @@ import org.apache.spark.status.api.v1 import org.apache.spark.storage.RDDInfo import org.apache.spark.ui.SparkUI import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.kvstore.KVStore /** @@ -64,6 +65,12 @@ private class LiveJob( var completedTasks = 0 var failedTasks = 0 + // Holds both the stage ID and the task index, packed into a single long value. + val completedIndices = new OpenHashSet[Long]() + + var killedTasks = 0 + var killedSummary: Map[String, Int] = Map() + var skippedTasks = 0 var skippedStages = Set[Int]() @@ -89,19 +96,23 @@ private class LiveJob( completedTasks, skippedTasks, failedTasks, + killedTasks, + completedIndices.size, activeStages, completedStages.size, skippedStages.size, - failedStages) + failedStages, + killedSummary) new JobDataWrapper(info, skippedStages) } } private class LiveTask( - info: TaskInfo, + var info: TaskInfo, stageId: Int, - stageAttemptId: Int) extends LiveEntity { + stageAttemptId: Int, + lastUpdateTime: Option[Long]) extends LiveEntity { import LiveEntityHelpers._ @@ -126,6 +137,7 @@ private class LiveTask( metrics.resultSerializationTime, metrics.memoryBytesSpilled, metrics.diskBytesSpilled, + metrics.peakExecutionMemory, new v1.InputMetrics( metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead), @@ -186,6 +198,7 @@ private class LiveTask( 0L, 0L, 0L, metrics.memoryBytesSpilled - old.memoryBytesSpilled, metrics.diskBytesSpilled - old.diskBytesSpilled, + 0L, inputDelta, outputDelta, shuffleReadDelta, @@ -193,12 +206,19 @@ private class LiveTask( } override protected def doUpdate(): Any = { + val duration = if (info.finished) { + info.duration + } else { + info.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis())) + } + val task = new v1.TaskData( info.taskId, info.index, info.attemptNumber, new Date(info.launchTime), - if (info.finished) Some(info.duration) else None, + if (info.gettingResult) Some(new Date(info.gettingResultTime)) else None, + Some(duration), info.executorId, info.host, info.status, @@ -340,10 +360,15 @@ private class LiveExecutorStageSummary( taskTime, failedTasks, succeededTasks, + killedTasks, metrics.inputBytes, + metrics.inputRecords, metrics.outputBytes, + metrics.outputRecords, metrics.shuffleReadBytes, + metrics.shuffleReadRecords, metrics.shuffleWriteBytes, + metrics.shuffleWriteRecords, metrics.memoryBytesSpilled, metrics.diskBytesSpilled) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) @@ -361,11 +386,16 @@ private class LiveStage extends LiveEntity { var info: StageInfo = null var status = v1.StageStatus.PENDING + var description: Option[String] = None var schedulingPool: String = SparkUI.DEFAULT_POOL_NAME var activeTasks = 0 var completedTasks = 0 var failedTasks = 0 + val 
completedIndices = new OpenHashSet[Int]() + + var killedTasks = 0 + var killedSummary: Map[String, Int] = Map() var firstLaunchTime = Long.MaxValue @@ -384,15 +414,19 @@ private class LiveStage extends LiveEntity { info.stageId, info.attemptId, + info.numTasks, activeTasks, completedTasks, failedTasks, + killedTasks, + completedIndices.size, metrics.executorRunTime, metrics.executorCpuTime, info.submissionTime.map(new Date(_)), if (firstLaunchTime < Long.MaxValue) Some(new Date(firstLaunchTime)) else None, info.completionTime.map(new Date(_)), + info.failureReason, metrics.inputBytes, metrics.inputRecords, @@ -406,12 +440,15 @@ private class LiveStage extends LiveEntity { metrics.diskBytesSpilled, info.name, + description, info.details, schedulingPool, + info.rddInfos.map(_.id), newAccumulatorInfos(info.accumulables.values), None, - None) + None, + killedSummary) new StageDataWrapper(update, jobIds) } @@ -535,6 +572,16 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { } +private class SchedulerPool(name: String) extends LiveEntity { + + var stageIds = Set[Int]() + + override protected def doUpdate(): Any = { + new PoolData(name, stageIds) + } + +} + private object LiveEntityHelpers { def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala index d0d9ef1165e81..b4fa3e633f6c1 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala @@ -22,7 +22,6 @@ import javax.ws.rs.core.MediaType import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.ui.jobs.UIData.JobUIData @Produces(Array(MediaType.APPLICATION_JSON)) @@ -30,74 +29,7 @@ private[v1] class AllJobsResource(ui: SparkUI) { @GET def jobsList(@QueryParam("status") statuses: JList[JobExecutionStatus]): Seq[JobData] = { - val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] = - AllJobsResource.getStatusToJobs(ui) - val adjStatuses: JList[JobExecutionStatus] = { - if (statuses.isEmpty) { - Arrays.asList(JobExecutionStatus.values(): _*) - } else { - statuses - } - } - val jobInfos = for { - (status, jobs) <- statusToJobs - job <- jobs if adjStatuses.contains(status) - } yield { - AllJobsResource.convertJobData(job, ui.jobProgressListener, false) - } - jobInfos.sortBy{- _.jobId} + ui.store.jobsList(statuses) } } - -private[v1] object AllJobsResource { - - def getStatusToJobs(ui: SparkUI): Seq[(JobExecutionStatus, Seq[JobUIData])] = { - val statusToJobs = ui.jobProgressListener.synchronized { - Seq( - JobExecutionStatus.RUNNING -> ui.jobProgressListener.activeJobs.values.toSeq, - JobExecutionStatus.SUCCEEDED -> ui.jobProgressListener.completedJobs.toSeq, - JobExecutionStatus.FAILED -> ui.jobProgressListener.failedJobs.reverse.toSeq - ) - } - statusToJobs - } - - def convertJobData( - job: JobUIData, - listener: JobProgressListener, - includeStageDetails: Boolean): JobData = { - listener.synchronized { - val lastStageInfo = - if (job.stageIds.isEmpty) { - None - } else { - listener.stageIdToInfo.get(job.stageIds.max) - } - val lastStageData = lastStageInfo.flatMap { s => - listener.stageIdToData.get((s.stageId, s.attemptId)) - } - val lastStageName = lastStageInfo.map { _.name }.getOrElse("(Unknown Stage Name)") - val 
lastStageDescription = lastStageData.flatMap { _.description } - new JobData( - jobId = job.jobId, - name = lastStageName, - description = lastStageDescription, - submissionTime = job.submissionTime.map{new Date(_)}, - completionTime = job.completionTime.map{new Date(_)}, - stageIds = job.stageIds, - jobGroup = job.jobGroup, - status = job.status, - numTasks = job.numTasks, - numActiveTasks = job.numActiveTasks, - numCompletedTasks = job.numCompletedTasks, - numSkippedTasks = job.numSkippedTasks, - numFailedTasks = job.numFailedTasks, - numActiveStages = job.numActiveStages, - numCompletedStages = job.completedStageIndices.size, - numSkippedStages = job.numSkippedStages, - numFailedStages = job.numFailedStages - ) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 5f69949c618fd..e1c91cb527a51 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -16,304 +16,18 @@ */ package org.apache.spark.status.api.v1 -import java.util.{Arrays, Date, List => JList} +import java.util.{List => JList} import javax.ws.rs.{GET, Produces, QueryParam} import javax.ws.rs.core.MediaType -import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} -import org.apache.spark.ui.jobs.UIData.{InputMetricsUIData => InternalInputMetrics, OutputMetricsUIData => InternalOutputMetrics, ShuffleReadMetricsUIData => InternalShuffleReadMetrics, ShuffleWriteMetricsUIData => InternalShuffleWriteMetrics, TaskMetricsUIData => InternalTaskMetrics} -import org.apache.spark.util.Distribution @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AllStagesResource(ui: SparkUI) { @GET def stageList(@QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { - val listener = ui.jobProgressListener - val stageAndStatus = AllStagesResource.stagesAndStatus(ui) - val adjStatuses = { - if (statuses.isEmpty()) { - Arrays.asList(StageStatus.values(): _*) - } else { - statuses - } - } - for { - (status, stageList) <- stageAndStatus - stageInfo: StageInfo <- stageList if adjStatuses.contains(status) - stageUiData: StageUIData <- listener.synchronized { - listener.stageIdToData.get((stageInfo.stageId, stageInfo.attemptId)) - } - } yield { - stageUiData.lastUpdateTime = ui.lastUpdateTime - AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, includeDetails = false) - } + ui.store.stageList(statuses) } -} - -private[v1] object AllStagesResource { - def stageUiToStageData( - status: StageStatus, - stageInfo: StageInfo, - stageUiData: StageUIData, - includeDetails: Boolean): StageData = { - - val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) - - val firstTaskLaunchedTime: Option[Date] = - if (taskLaunchTimes.nonEmpty) { - Some(new Date(taskLaunchTimes.min)) - } else { - None - } - - val taskData = if (includeDetails) { - Some(stageUiData.taskData.map { case (k, v) => - k -> convertTaskData(v, stageUiData.lastUpdateTime) }.toMap) - } else { - None - } - val executorSummary = if (includeDetails) { - Some(stageUiData.executorSummary.map { case (k, summary) => - k -> new ExecutorStageSummary( - taskTime = summary.taskTime, - failedTasks = summary.failedTasks, - succeededTasks = summary.succeededTasks, - 
inputBytes = summary.inputBytes, - outputBytes = summary.outputBytes, - shuffleRead = summary.shuffleRead, - shuffleWrite = summary.shuffleWrite, - memoryBytesSpilled = summary.memoryBytesSpilled, - diskBytesSpilled = summary.diskBytesSpilled - ) - }.toMap) - } else { - None - } - - val accumulableInfo = stageUiData.accumulables.values.map { convertAccumulableInfo }.toSeq - - new StageData( - status = status, - stageId = stageInfo.stageId, - attemptId = stageInfo.attemptId, - numActiveTasks = stageUiData.numActiveTasks, - numCompleteTasks = stageUiData.numCompleteTasks, - numFailedTasks = stageUiData.numFailedTasks, - executorRunTime = stageUiData.executorRunTime, - executorCpuTime = stageUiData.executorCpuTime, - submissionTime = stageInfo.submissionTime.map(new Date(_)), - firstTaskLaunchedTime, - completionTime = stageInfo.completionTime.map(new Date(_)), - inputBytes = stageUiData.inputBytes, - inputRecords = stageUiData.inputRecords, - outputBytes = stageUiData.outputBytes, - outputRecords = stageUiData.outputRecords, - shuffleReadBytes = stageUiData.shuffleReadTotalBytes, - shuffleReadRecords = stageUiData.shuffleReadRecords, - shuffleWriteBytes = stageUiData.shuffleWriteBytes, - shuffleWriteRecords = stageUiData.shuffleWriteRecords, - memoryBytesSpilled = stageUiData.memoryBytesSpilled, - diskBytesSpilled = stageUiData.diskBytesSpilled, - schedulingPool = stageUiData.schedulingPool, - name = stageInfo.name, - details = stageInfo.details, - accumulatorUpdates = accumulableInfo, - tasks = taskData, - executorSummary = executorSummary - ) - } - - def stagesAndStatus(ui: SparkUI): Seq[(StageStatus, Seq[StageInfo])] = { - val listener = ui.jobProgressListener - listener.synchronized { - Seq( - StageStatus.ACTIVE -> listener.activeStages.values.toSeq, - StageStatus.COMPLETE -> listener.completedStages.reverse.toSeq, - StageStatus.FAILED -> listener.failedStages.reverse.toSeq, - StageStatus.PENDING -> listener.pendingStages.values.toSeq - ) - } - } - - def convertTaskData(uiData: TaskUIData, lastUpdateTime: Option[Long]): TaskData = { - new TaskData( - taskId = uiData.taskInfo.taskId, - index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attemptNumber, - launchTime = new Date(uiData.taskInfo.launchTime), - duration = uiData.taskDuration(lastUpdateTime), - executorId = uiData.taskInfo.executorId, - host = uiData.taskInfo.host, - status = uiData.taskInfo.status, - taskLocality = uiData.taskInfo.taskLocality.toString(), - speculative = uiData.taskInfo.speculative, - accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo }, - errorMessage = uiData.errorMessage, - taskMetrics = uiData.metrics.map { convertUiTaskMetrics } - ) - } - - def taskMetricDistributions( - allTaskData: Iterable[TaskUIData], - quantiles: Array[Double]): TaskMetricDistributions = { - - val rawMetrics = allTaskData.flatMap{_.metrics}.toSeq - - def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] = - Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) - - // We need to do a lot of similar munging to nested metrics here. For each one, - // we want (a) extract the values for nested metrics (b) make a distribution for each metric - // (c) shove the distribution into the right field in our return type and (d) only return - // a result if the option is defined for any of the tasks. MetricHelper is a little util - // to make it a little easier to deal w/ all of the nested options. 
Mostly it lets us just - // implement one "build" method, which just builds the quantiles for each field. - - val inputMetrics: InputMetricDistributions = - new MetricHelper[InternalInputMetrics, InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalInputMetrics = raw.inputMetrics - - def build: InputMetricDistributions = new InputMetricDistributions( - bytesRead = submetricQuantiles(_.bytesRead), - recordsRead = submetricQuantiles(_.recordsRead) - ) - }.build - - val outputMetrics: OutputMetricDistributions = - new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalOutputMetrics = raw.outputMetrics - - def build: OutputMetricDistributions = new OutputMetricDistributions( - bytesWritten = submetricQuantiles(_.bytesWritten), - recordsWritten = submetricQuantiles(_.recordsWritten) - ) - }.build - - val shuffleReadMetrics: ShuffleReadMetricDistributions = - new MetricHelper[InternalShuffleReadMetrics, ShuffleReadMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleReadMetrics = - raw.shuffleReadMetrics - - def build: ShuffleReadMetricDistributions = new ShuffleReadMetricDistributions( - readBytes = submetricQuantiles(_.totalBytesRead), - readRecords = submetricQuantiles(_.recordsRead), - remoteBytesRead = submetricQuantiles(_.remoteBytesRead), - remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), - remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), - localBlocksFetched = submetricQuantiles(_.localBlocksFetched), - totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched), - fetchWaitTime = submetricQuantiles(_.fetchWaitTime) - ) - }.build - - val shuffleWriteMetrics: ShuffleWriteMetricDistributions = - new MetricHelper[InternalShuffleWriteMetrics, ShuffleWriteMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleWriteMetrics = - raw.shuffleWriteMetrics - def build: ShuffleWriteMetricDistributions = new ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.bytesWritten), - writeRecords = submetricQuantiles(_.recordsWritten), - writeTime = submetricQuantiles(_.writeTime) - ) - }.build - - new TaskMetricDistributions( - quantiles = quantiles, - executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), - executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), - executorRunTime = metricQuantiles(_.executorRunTime), - executorCpuTime = metricQuantiles(_.executorCpuTime), - resultSize = metricQuantiles(_.resultSize), - jvmGcTime = metricQuantiles(_.jvmGCTime), - resultSerializationTime = metricQuantiles(_.resultSerializationTime), - memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), - diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), - inputMetrics = inputMetrics, - outputMetrics = outputMetrics, - shuffleReadMetrics = shuffleReadMetrics, - shuffleWriteMetrics = shuffleWriteMetrics - ) - } - - def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = { - new AccumulableInfo( - acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull) - } - - def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { - new TaskMetrics( - executorDeserializeTime = internal.executorDeserializeTime, - executorDeserializeCpuTime = internal.executorDeserializeCpuTime, - executorRunTime = internal.executorRunTime, - 
executorCpuTime = internal.executorCpuTime, - resultSize = internal.resultSize, - jvmGcTime = internal.jvmGCTime, - resultSerializationTime = internal.resultSerializationTime, - memoryBytesSpilled = internal.memoryBytesSpilled, - diskBytesSpilled = internal.diskBytesSpilled, - inputMetrics = convertInputMetrics(internal.inputMetrics), - outputMetrics = convertOutputMetrics(internal.outputMetrics), - shuffleReadMetrics = convertShuffleReadMetrics(internal.shuffleReadMetrics), - shuffleWriteMetrics = convertShuffleWriteMetrics(internal.shuffleWriteMetrics) - ) - } - - def convertInputMetrics(internal: InternalInputMetrics): InputMetrics = { - new InputMetrics( - bytesRead = internal.bytesRead, - recordsRead = internal.recordsRead - ) - } - - def convertOutputMetrics(internal: InternalOutputMetrics): OutputMetrics = { - new OutputMetrics( - bytesWritten = internal.bytesWritten, - recordsWritten = internal.recordsWritten - ) - } - - def convertShuffleReadMetrics(internal: InternalShuffleReadMetrics): ShuffleReadMetrics = { - new ShuffleReadMetrics( - remoteBlocksFetched = internal.remoteBlocksFetched, - localBlocksFetched = internal.localBlocksFetched, - fetchWaitTime = internal.fetchWaitTime, - remoteBytesRead = internal.remoteBytesRead, - remoteBytesReadToDisk = internal.remoteBytesReadToDisk, - localBytesRead = internal.localBytesRead, - recordsRead = internal.recordsRead - ) - } - - def convertShuffleWriteMetrics(internal: InternalShuffleWriteMetrics): ShuffleWriteMetrics = { - new ShuffleWriteMetrics( - bytesWritten = internal.bytesWritten, - writeTime = internal.writeTime, - recordsWritten = internal.recordsWritten - ) - } -} - -/** - * Helper for getting distributions from nested metric types. - */ -private[v1] abstract class MetricHelper[I, O]( - rawMetrics: Seq[InternalTaskMetrics], - quantiles: Array[Double]) { - - def getSubmetrics(raw: InternalTaskMetrics): I - - def build: O - - val data: Seq[I] = rawMetrics.map(getSubmetrics) - - /** applies the given function to all input metrics, and returns the quantiles */ - def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { - Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) - } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala index 653150385c732..3ee884e084c12 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala @@ -16,25 +16,22 @@ */ package org.apache.spark.status.api.v1 +import java.util.NoSuchElementException import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType -import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.JobUIData @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class OneJobResource(ui: SparkUI) { @GET def oneJob(@PathParam("jobId") jobId: Int): JobData = { - val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] = - AllJobsResource.getStatusToJobs(ui) - val jobOpt = statusToJobs.flatMap(_._2).find { jobInfo => jobInfo.jobId == jobId} - jobOpt.map { job => - AllJobsResource.convertJobData(job, ui.jobProgressListener, false) - }.getOrElse { - throw new NotFoundException("unknown job: " + jobId) + try { + ui.store.job(jobId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException("unknown job: " + jobId) } } diff --git 
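As an aside, the OneJobResource change just above shows the pattern all of the slimmed-down v1 resources follow: read straight from the KVStore-backed AppStatusStore and translate its NoSuchElementException into a JAX-RS 404. A sketch of that translation as a shared helper; the notFound name is invented here and is not part of the patch:

    package org.apache.spark.status.api.v1

    import java.util.NoSuchElementException

    // Hypothetical helper illustrating the shared error-mapping convention.
    private[v1] object StoreLookupSketch {
      def notFound[T](message: String)(read: => T): T =
        try {
          read
        } catch {
          case _: NoSuchElementException => throw new NotFoundException(message)
        }
    }

    // e.g. inside a resource method:
    //   StoreLookupSketch.notFound(s"unknown job: $jobId") { ui.store.job(jobId) }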
a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index f15073bccced2..20dd73e916613 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -24,7 +24,6 @@ import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.ui.jobs.UIData.StageUIData @Produces(Array(MediaType.APPLICATION_JSON)) @@ -32,13 +31,14 @@ private[v1] class OneStageResource(ui: SparkUI) { @GET @Path("") - def stageData(@PathParam("stageId") stageId: Int): Seq[StageData] = { - withStage(stageId) { stageAttempts => - stageAttempts.map { stage => - stage.ui.lastUpdateTime = ui.lastUpdateTime - AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, - includeDetails = true) - } + def stageData( + @PathParam("stageId") stageId: Int, + @QueryParam("details") @DefaultValue("true") details: Boolean): Seq[StageData] = { + val ret = ui.store.stageData(stageId, details = details) + if (ret.nonEmpty) { + ret + } else { + throw new NotFoundException(s"unknown stage: $stageId") } } @@ -46,11 +46,13 @@ private[v1] class OneStageResource(ui: SparkUI) { @Path("/{stageAttemptId: \\d+}") def oneAttemptData( @PathParam("stageId") stageId: Int, - @PathParam("stageAttemptId") stageAttemptId: Int): StageData = { - withStageAttempt(stageId, stageAttemptId) { stage => - stage.ui.lastUpdateTime = ui.lastUpdateTime - AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, - includeDetails = true) + @PathParam("stageAttemptId") stageAttemptId: Int, + @QueryParam("details") @DefaultValue("true") details: Boolean): StageData = { + try { + ui.store.stageAttempt(stageId, stageAttemptId, details = details) + } catch { + case _: NoSuchElementException => + throw new NotFoundException(s"unknown attempt $stageAttemptId for stage $stageId.") } } @@ -61,17 +63,16 @@ private[v1] class OneStageResource(ui: SparkUI) { @PathParam("stageAttemptId") stageAttemptId: Int, @DefaultValue("0.05,0.25,0.5,0.75,0.95") @QueryParam("quantiles") quantileString: String) : TaskMetricDistributions = { - withStageAttempt(stageId, stageAttemptId) { stage => - val quantiles = quantileString.split(",").map { s => - try { - s.toDouble - } catch { - case nfe: NumberFormatException => - throw new BadParameterException("quantiles", "double", s) - } + val quantiles = quantileString.split(",").map { s => + try { + s.toDouble + } catch { + case nfe: NumberFormatException => + throw new BadParameterException("quantiles", "double", s) } - AllStagesResource.taskMetricDistributions(stage.ui.taskData.values, quantiles) } + + ui.store.taskSummary(stageId, stageAttemptId, quantiles) } @GET @@ -82,72 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { @DefaultValue("0") @QueryParam("offset") offset: Int, @DefaultValue("20") @QueryParam("length") length: Int, @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = { - withStageAttempt(stageId, stageAttemptId) { stage => - val tasks = stage.ui.taskData.values - .map{ AllStagesResource.convertTaskData(_, ui.lastUpdateTime)}.toIndexedSeq - .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) - } - } - - private case class StageStatusInfoUi(status: 
StageStatus, info: StageInfo, ui: StageUIData) - - private def withStage[T](stageId: Int)(f: Seq[StageStatusInfoUi] => T): T = { - val stageAttempts = findStageStatusUIData(ui.jobProgressListener, stageId) - if (stageAttempts.isEmpty) { - throw new NotFoundException("unknown stage: " + stageId) - } else { - f(stageAttempts) - } + ui.store.taskList(stageId, stageAttemptId, offset, length, sortBy) } - private def findStageStatusUIData( - listener: JobProgressListener, - stageId: Int): Seq[StageStatusInfoUi] = { - listener.synchronized { - def getStatusInfoUi(status: StageStatus, infos: Seq[StageInfo]): Seq[StageStatusInfoUi] = { - infos.filter { _.stageId == stageId }.map { info => - val ui = listener.stageIdToData.getOrElse((info.stageId, info.attemptId), - // this is an internal error -- we should always have uiData - throw new SparkException( - s"no stage ui data found for stage: ${info.stageId}:${info.attemptId}") - ) - StageStatusInfoUi(status, info, ui) - } - } - getStatusInfoUi(ACTIVE, listener.activeStages.values.toSeq) ++ - getStatusInfoUi(COMPLETE, listener.completedStages) ++ - getStatusInfoUi(FAILED, listener.failedStages) ++ - getStatusInfoUi(PENDING, listener.pendingStages.values.toSeq) - } - } - - private def withStageAttempt[T]( - stageId: Int, - stageAttemptId: Int) - (f: StageStatusInfoUi => T): T = { - withStage(stageId) { attempts => - val oneAttempt = attempts.find { stage => stage.info.attemptId == stageAttemptId } - oneAttempt match { - case Some(stage) => - f(stage) - case None => - val stageAttempts = attempts.map { _.info.attemptId } - throw new NotFoundException(s"unknown attempt for stage $stageId. " + - s"Found attempts: ${stageAttempts.mkString("[", ",", "]")}") - } - } - } -} - -object OneStageResource { - def ordering(taskSorting: TaskSorting): Ordering[TaskData] = { - val extractor: (TaskData => Long) = td => - taskSorting match { - case ID => td.taskId - case INCREASING_RUNTIME => td.taskMetrics.map{_.executorRunTime}.getOrElse(-1L) - case DECREASING_RUNTIME => -td.taskMetrics.map{_.executorRunTime}.getOrElse(-1L) - } - Ordering.by(extractor) - } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index b338b1f3fd073..14280099f6422 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -58,10 +58,15 @@ class ExecutorStageSummary private[spark]( val taskTime : Long, val failedTasks : Int, val succeededTasks : Int, + val killedTasks : Int, val inputBytes : Long, + val inputRecords : Long, val outputBytes : Long, + val outputRecords : Long, val shuffleRead : Long, + val shuffleReadRecords : Long, val shuffleWrite : Long, + val shuffleWriteRecords : Long, val memoryBytesSpilled : Long, val diskBytesSpilled : Long) @@ -111,10 +116,13 @@ class JobData private[spark]( val numCompletedTasks: Int, val numSkippedTasks: Int, val numFailedTasks: Int, + val numKilledTasks: Int, + val numCompletedIndices: Int, val numActiveStages: Int, val numCompletedStages: Int, val numSkippedStages: Int, - val numFailedStages: Int) + val numFailedStages: Int, + val killedTasksSummary: Map[String, Int]) class RDDStorageInfo private[spark]( val id: Int, @@ -152,15 +160,19 @@ class StageData private[spark]( val status: StageStatus, val stageId: Int, val attemptId: Int, + val numTasks: Int, val numActiveTasks: Int, val numCompleteTasks: Int, val numFailedTasks: Int, + val numKilledTasks: Int, + val 
numCompletedIndices: Int, val executorRunTime: Long, val executorCpuTime: Long, val submissionTime: Option[Date], val firstTaskLaunchedTime: Option[Date], val completionTime: Option[Date], + val failureReason: Option[String], val inputBytes: Long, val inputRecords: Long, @@ -174,18 +186,22 @@ class StageData private[spark]( val diskBytesSpilled: Long, val name: String, + val description: Option[String], val details: String, val schedulingPool: String, + val rddIds: Seq[Int], val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], - val executorSummary: Option[Map[String, ExecutorStageSummary]]) + val executorSummary: Option[Map[String, ExecutorStageSummary]], + val killedTasksSummary: Map[String, Int]) class TaskData private[spark]( val taskId: Long, val index: Int, val attempt: Int, val launchTime: Date, + val resultFetchStart: Option[Date], @JsonDeserialize(contentAs = classOf[JLong]) val duration: Option[Long], val executorId: String, @@ -207,6 +223,7 @@ class TaskMetrics private[spark]( val resultSerializationTime: Long, val memoryBytesSpilled: Long, val diskBytesSpilled: Long, + val peakExecutionMemory: Long, val inputMetrics: InputMetrics, val outputMetrics: OutputMetrics, val shuffleReadMetrics: ShuffleReadMetrics, diff --git a/core/src/main/scala/org/apache/spark/status/config.scala b/core/src/main/scala/org/apache/spark/status/config.scala index 49144fc883e69..7af9dff977a86 100644 --- a/core/src/main/scala/org/apache/spark/status/config.scala +++ b/core/src/main/scala/org/apache/spark/status/config.scala @@ -27,4 +27,8 @@ private[spark] object config { .timeConf(TimeUnit.NANOSECONDS) .createWithDefaultString("100ms") + val MAX_RETAINED_ROOT_NODES = ConfigBuilder("spark.ui.dagGraph.retainedRootRDDs") + .intConf + .createWithDefault(Int.MaxValue) + } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index f44b7935bfaa3..c1ea87542d6cc 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ +import org.apache.spark.ui.scope._ import org.apache.spark.util.kvstore.KVIndex private[spark] case class AppStatusStoreMetadata(version: Long) @@ -106,6 +107,11 @@ private[spark] class TaskDataWrapper( Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) } + @JsonIgnore @KVIndex("startTime") + def startTime: Array[AnyRef] = { + Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) + } + } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { @@ -147,3 +153,37 @@ private[spark] class StreamBlockData( def key: Array[String] = Array(name, executorId) } + +private[spark] class RDDOperationClusterWrapper( + val id: String, + val name: String, + val childNodes: Seq[RDDOperationNode], + val childClusters: Seq[RDDOperationClusterWrapper]) { + + def toRDDOperationCluster(): RDDOperationCluster = { + val cluster = new RDDOperationCluster(id, name) + childNodes.foreach(cluster.attachChildNode) + childClusters.foreach { child => + cluster.attachChildCluster(child.toRDDOperationCluster()) + } + cluster + } + +} + +private[spark] class RDDOperationGraphWrapper( + @KVIndexParam val stageId: Int, + val edges: Seq[RDDOperationEdge], + val outgoingEdges: Seq[RDDOperationEdge], + val incomingEdges: 
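As an aside, the wrapper types in storeTypes.scala show the KVStore idiom this whole change leans on: a @KVIndexParam field is the natural key, and extra @JsonIgnore @KVIndex("...") methods declare secondary indices that AppStatusStore then walks with view(...).index(...).first(...).last(...). A self-contained sketch with an invented NoteWrapper type, assuming the same KVUtils/kvstore APIs used above:

    package org.apache.spark.status

    import scala.collection.JavaConverters._

    import com.fasterxml.jackson.annotation.JsonIgnore

    import org.apache.spark.status.KVUtils._
    import org.apache.spark.util.kvstore.{InMemoryStore, KVIndex}

    // Invented example type; TaskDataWrapper and PoolData above follow the same shape.
    private[spark] class NoteWrapper(
        @KVIndexParam val id: Int,   // natural key
        val owner: String) {

      @JsonIgnore @KVIndex("owner")
      def ownerIndex: String = owner  // secondary index
    }

    object NoteWrapperSketch {
      def main(args: Array[String]): Unit = {
        val store = new InMemoryStore()
        store.write(new NoteWrapper(1, "alice"))
        store.write(new NoteWrapper(2, "bob"))

        val one = store.read(classOf[NoteWrapper], 1)         // lookup by natural key
        val alices = store.view(classOf[NoteWrapper]).index("owner")
          .first("alice").last("alice").asScala.toSeq         // range scan on the index
        println((one.owner, alices.size))
      }
    }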
Seq[RDDOperationEdge], + val rootCluster: RDDOperationClusterWrapper) { + + def toRDDOperationGraph(): RDDOperationGraph = { + new RDDOperationGraph(edges, outgoingEdges, incomingEdges, rootCluster.toRDDOperationCluster()) + } + +} + +private[spark] class PoolData( + @KVIndexParam val name: String, + val stageIds: Set[Int]) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index e93ade001c607..35da3c3bfd1a2 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,11 +17,11 @@ package org.apache.spark.ui -import java.util.{Date, ServiceLoader} +import java.util.{Date, List => JList, ServiceLoader} import scala.collection.JavaConverters._ -import org.apache.spark.{SecurityManager, SparkConf, SparkContext} +import org.apache.spark.{JobExecutionStatus, SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore @@ -29,8 +29,7 @@ import org.apache.spark.status.api.v1._ import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.EnvironmentTab import org.apache.spark.ui.exec.ExecutorsTab -import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} -import org.apache.spark.ui.scope.RDDOperationGraphListener +import org.apache.spark.ui.jobs.{JobsTab, StagesTab} import org.apache.spark.ui.storage.StorageTab import org.apache.spark.util.Utils @@ -42,11 +41,8 @@ private[spark] class SparkUI private ( val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, - val jobProgressListener: JobProgressListener, - val operationGraphListener: RDDOperationGraphListener, var appName: String, val basePath: String, - val lastUpdateTime: Option[Long] = None, val startTime: Long, val appSparkVersion: String) extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), @@ -61,8 +57,8 @@ private[spark] class SparkUI private ( private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. 
*/ - def initialize() { - val jobsTab = new JobsTab(this) + def initialize(): Unit = { + val jobsTab = new JobsTab(this, store) attachTab(jobsTab) val stagesTab = new StagesTab(this, store) attachTab(stagesTab) @@ -72,6 +68,7 @@ private[spark] class SparkUI private ( attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) + // These should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( "/jobs/job/kill", "/jobs/", jobsTab.handleKillRequest, httpMethods = Set("GET", "POST"))) @@ -79,6 +76,7 @@ private[spark] class SparkUI private ( "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } + initialize() def getSparkUser: String = { @@ -170,25 +168,13 @@ private[spark] object SparkUI { sc: Option[SparkContext], store: AppStatusStore, conf: SparkConf, - addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, basePath: String, startTime: Long, - lastUpdateTime: Option[Long] = None, appSparkVersion: String = org.apache.spark.SPARK_VERSION): SparkUI = { - val jobProgressListener = sc.map(_.jobProgressListener).getOrElse { - val listener = new JobProgressListener(conf) - addListenerFn(listener) - listener - } - val operationGraphListener = new RDDOperationGraphListener(conf) - - addListenerFn(operationGraphListener) - - new SparkUI(store, sc, conf, securityManager, jobProgressListener, operationGraphListener, - appName, basePath, lastUpdateTime, startTime, appSparkVersion) + new SparkUI(store, sc, conf, securityManager, appName, basePath, startTime, appSparkVersion) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index a647a1173a8cb..b60d39b21b4bf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -22,20 +22,20 @@ import java.util.Date import javax.servlet.http.HttpServletRequest import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.ListBuffer import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ -import org.apache.spark.status.api.v1.ExecutorSummary +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1 import org.apache.spark.ui._ -import org.apache.spark.ui.jobs.UIData.{JobUIData, StageUIData} import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished jobs */ -private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { +private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("") { private val JOBS_LEGEND =
    [SVG legend markup for JOBS_LEGEND and EXECUTORS_LEGEND elided; only the final label, "Removed", survives]
    .toString.filter(_ != '\n') - private def getLastStageNameAndDescription(job: JobUIData): (String, String) = { - val lastStageInfo = Option(job.stageIds) - .filter(_.nonEmpty) - .flatMap { ids => parent.jobProgresslistener.stageIdToInfo.get(ids.max)} - val lastStageData = lastStageInfo.flatMap { s => - parent.jobProgresslistener.stageIdToData.get((s.stageId, s.attemptId)) - } - val name = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val description = lastStageData.flatMap(_.description).getOrElse("") - (name, description) - } - - private def makeJobEvent(jobUIDatas: Seq[JobUIData]): Seq[String] = { - jobUIDatas.filter { jobUIData => - jobUIData.status != JobExecutionStatus.UNKNOWN && jobUIData.submissionTime.isDefined - }.map { jobUIData => - val jobId = jobUIData.jobId - val status = jobUIData.status - val (jobName, jobDescription) = getLastStageNameAndDescription(jobUIData) + private def makeJobEvent(jobs: Seq[v1.JobData]): Seq[String] = { + jobs.filter { job => + job.status != JobExecutionStatus.UNKNOWN && job.submissionTime.isDefined + }.map { job => + val jobId = job.jobId + val status = job.status val displayJobDescription = - if (jobDescription.isEmpty) { - jobName + if (job.description.isEmpty) { + job.name } else { - UIUtils.makeDescription(jobDescription, "", plainText = true).text + UIUtils.makeDescription(job.description.get, "", plainText = true).text } - val submissionTime = jobUIData.submissionTime.get - val completionTimeOpt = jobUIData.completionTime - val completionTime = completionTimeOpt.getOrElse(System.currentTimeMillis()) + val submissionTime = job.submissionTime.get.getTime() + val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { case JobExecutionStatus.SUCCEEDED => "succeeded" case JobExecutionStatus.FAILED => "failed" @@ -124,7 +110,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } } - private def makeExecutorEvent(executors: Seq[ExecutorSummary]): + private def makeExecutorEvent(executors: Seq[v1.ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() executors.foreach { e => @@ -169,8 +155,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } private def makeTimeline( - jobs: Seq[JobUIData], - executors: Seq[ExecutorSummary], + jobs: Seq[v1.JobData], + executors: Seq[v1.ExecutorSummary], startTime: Long): Seq[Node] = { val jobEventJsonAsStrSeq = makeJobEvent(jobs) @@ -217,7 +203,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { request: HttpServletRequest, tableHeaderId: String, jobTag: String, - jobs: Seq[JobUIData], + jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { // stripXSS is called to remove suspicious characters used in XSS attacks val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) @@ -258,14 +244,13 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { try { new JobPagedTable( + store, jobs, tableHeaderId, jobTag, UIUtils.prependBaseUri(parent.basePath), "jobs", // subPath parameterOtherTable, - parent.jobProgresslistener.stageIdToInfo, - parent.jobProgresslistener.stageIdToData, killEnabled, currentTime, jobIdTitle, @@ -285,106 +270,117 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } def render(request: HttpServletRequest): Seq[Node] = { - val listener = parent.jobProgresslistener - listener.synchronized { - val startTime = listener.startTime - val endTime = 
listener.endTime - val activeJobs = listener.activeJobs.values.toSeq - val completedJobs = listener.completedJobs.reverse - val failedJobs = listener.failedJobs.reverse - - val activeJobsTable = - jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) - val completedJobsTable = - jobsTable(request, "completed", "completedJob", completedJobs, killEnabled = false) - val failedJobsTable = - jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false) - - val shouldShowActiveJobs = activeJobs.nonEmpty - val shouldShowCompletedJobs = completedJobs.nonEmpty - val shouldShowFailedJobs = failedJobs.nonEmpty - - val completedJobNumStr = if (completedJobs.size == listener.numCompletedJobs) { - s"${completedJobs.size}" - } else { - s"${listener.numCompletedJobs}, only showing ${completedJobs.size}" + val appInfo = store.applicationInfo() + val startTime = appInfo.attempts.head.startTime.getTime() + val endTime = appInfo.attempts.head.endTime.getTime() + + val activeJobs = new ListBuffer[v1.JobData]() + val completedJobs = new ListBuffer[v1.JobData]() + val failedJobs = new ListBuffer[v1.JobData]() + + store.jobsList(null).foreach { job => + job.status match { + case JobExecutionStatus.SUCCEEDED => + completedJobs += job + case JobExecutionStatus.FAILED => + failedJobs += job + case _ => + activeJobs += job } + } - val summary: NodeSeq = -
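As an aside, the page above buckets jobs by walking store.jobsList(null); the same call also accepts an explicit status filter, which is what the added `statuses != null && !statuses.isEmpty()` guard in AppStatusStore.jobsList protects. A usage sketch, again assuming code under the org.apache.spark namespace:

    package org.apache.spark.status

    import java.util.Arrays

    import org.apache.spark.JobExecutionStatus
    import org.apache.spark.status.api.v1

    object JobsListSketch {
      // Filter at the store level rather than in page code; passing null (as the page
      // above does) returns every job, newest first because of the reverse()d view.
      def runningJobs(store: AppStatusStore): Seq[v1.JobData] =
        store.jobsList(Arrays.asList(JobExecutionStatus.RUNNING))
    }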
-      }
-
-      val summary: NodeSeq =
-        <div>
-          <ul class="unstyled">
-            <li><strong>User:</strong> {parent.getSparkUser}</li>
-            <li>
-              <strong>Total Uptime:</strong>
-              {
-                if (endTime < 0 && parent.sc.isDefined) {
-                  UIUtils.formatDuration(System.currentTimeMillis() - startTime)
-                } else if (endTime > 0) {
-                  UIUtils.formatDuration(endTime - startTime)
-                }
-              }
-            </li>
-            <li><strong>Scheduling Mode:</strong> {listener.schedulingMode.map(_.toString).getOrElse("Unknown")}</li>
-            { if (shouldShowActiveJobs) { <li><strong>Active Jobs:</strong> {activeJobs.size}</li> } }
-            { if (shouldShowCompletedJobs) { <li><strong>Completed Jobs:</strong> {completedJobNumStr}</li> } }
-            { if (shouldShowFailedJobs) { <li><strong>Failed Jobs:</strong> {listener.numFailedJobs}</li> } }
-          </ul>
-        </div>
+    }
+
+    val activeJobsTable =
+      jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled)
+    val completedJobsTable =
+      jobsTable(request, "completed", "completedJob", completedJobs, killEnabled = false)
+    val failedJobsTable =
+      jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false)
+
+    val shouldShowActiveJobs = activeJobs.nonEmpty
+    val shouldShowCompletedJobs = completedJobs.nonEmpty
+    val shouldShowFailedJobs = failedJobs.nonEmpty
+
+    val completedJobNumStr = s"${completedJobs.size}"
+    val schedulingMode = store.environmentInfo().sparkProperties.toMap
+      .get("spark.scheduler.mode")
+      .map { mode => SchedulingMode.withName(mode).toString }
+      .getOrElse("Unknown")
+
+    val summary: NodeSeq =
+      <div>
+        <ul class="unstyled">
+          <li><strong>User:</strong> {parent.getSparkUser}</li>
+          <li>
+            <strong>Total Uptime:</strong>
+            {
+              if (endTime < 0 && parent.sc.isDefined) {
+                UIUtils.formatDuration(System.currentTimeMillis() - startTime)
+              } else if (endTime > 0) {
+                UIUtils.formatDuration(endTime - startTime)
+              }
+            }
+          </li>
+          <li><strong>Scheduling Mode:</strong> {schedulingMode}</li>
+          { if (shouldShowActiveJobs) { <li><strong>Active Jobs:</strong> {activeJobs.size}</li> } }
+          { if (shouldShowCompletedJobs) { <li><strong>Completed Jobs:</strong> {completedJobNumStr}</li> } }
+          { if (shouldShowFailedJobs) { <li><strong>Failed Jobs:</strong> {failedJobs.size}</li> } }
+        </ul>
+      </div>
    - var content = summary - content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, - parent.parent.store.executorList(false), startTime) + var content = summary + content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, + store.executorList(false), startTime) - if (shouldShowActiveJobs) { - content ++=

    Active Jobs ({activeJobs.size})

    ++ - activeJobsTable - } - if (shouldShowCompletedJobs) { - content ++=

    Completed Jobs ({completedJobNumStr})

    ++ - completedJobsTable - } - if (shouldShowFailedJobs) { - content ++=

    Failed Jobs ({failedJobs.size})

    ++ - failedJobsTable - } + if (shouldShowActiveJobs) { + content ++=

    Active Jobs ({activeJobs.size})

    ++ + activeJobsTable + } + if (shouldShowCompletedJobs) { + content ++=

    Completed Jobs ({completedJobNumStr})

    ++ + completedJobsTable + } + if (shouldShowFailedJobs) { + content ++=

    Failed Jobs ({failedJobs.size})

    ++ + failedJobsTable + } - val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + - " Click on a job to see information about the stages of tasks inside it." + val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + + " Click on a job to see information about the stages of tasks inside it." - UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) - } + UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) } + } private[ui] class JobTableRowData( - val jobData: JobUIData, + val jobData: v1.JobData, val lastStageName: String, val lastStageDescription: String, val duration: Long, @@ -395,9 +391,8 @@ private[ui] class JobTableRowData( val detailUrl: String) private[ui] class JobDataSource( - jobs: Seq[JobUIData], - stageIdToInfo: HashMap[Int, StageInfo], - stageIdToData: HashMap[(Int, Int), StageUIData], + store: AppStatusStore, + jobs: Seq[v1.JobData], basePath: String, currentTime: Long, pageSize: Int, @@ -418,40 +413,28 @@ private[ui] class JobDataSource( r } - private def getLastStageNameAndDescription(job: JobUIData): (String, String) = { - val lastStageInfo = Option(job.stageIds) - .filter(_.nonEmpty) - .flatMap { ids => stageIdToInfo.get(ids.max)} - val lastStageData = lastStageInfo.flatMap { s => - stageIdToData.get((s.stageId, s.attemptId)) - } - val name = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val description = lastStageData.flatMap(_.description).getOrElse("") - (name, description) - } - - private def jobRow(jobData: JobUIData): JobTableRowData = { - val (lastStageName, lastStageDescription) = getLastStageNameAndDescription(jobData) + private def jobRow(jobData: v1.JobData): JobTableRowData = { val duration: Option[Long] = { jobData.submissionTime.map { start => - val end = jobData.completionTime.getOrElse(System.currentTimeMillis()) - end - start + val end = jobData.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) + end - start.getTime() } } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) + val jobDescription = UIUtils.makeDescription(jobData.description.getOrElse(""), + basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) - new JobTableRowData ( + new JobTableRowData( jobData, - lastStageName, - lastStageDescription, + jobData.name, + jobData.description.getOrElse(jobData.name), duration.getOrElse(-1), formattedDuration, - submissionTime.getOrElse(-1), + submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, jobDescription, detailUrl @@ -479,15 +462,15 @@ private[ui] class JobDataSource( } } + private[ui] class JobPagedTable( - data: Seq[JobUIData], + store: AppStatusStore, + data: Seq[v1.JobData], tableHeaderId: String, jobTag: String, basePath: String, subPath: String, parameterOtherTable: Iterable[String], - stageIdToInfo: HashMap[Int, StageInfo], - stageIdToData: HashMap[(Int, Int), StageUIData], killEnabled: Boolean, currentTime: Long, jobIdTitle: String, @@ -510,9 +493,8 @@ private[ui] class JobPagedTable( override def pageNumberFormField: String = jobTag + ".page" override val dataSource = new JobDataSource( + store, data, - stageIdToInfo, - stageIdToData, basePath, 
currentTime, pageSize, @@ -624,15 +606,15 @@ private[ui] class JobPagedTable( {jobTableRow.formattedDuration} - {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages} + {job.numCompletedStages}/{job.stageIds.size - job.numSkippedStages} {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"} {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} {UIUtils.makeProgressBar(started = job.numActiveTasks, - completed = job.completedIndices.size, + completed = job.numCompletedIndices, failed = job.numFailedTasks, skipped = job.numSkippedTasks, - reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} + reasonToNumKilled = job.killedTasksSummary, total = job.numTasks - job.numSkippedTasks)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index dc5b03c5269a9..e4cf99e7b9e04 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -22,120 +22,121 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, NodeSeq} import org.apache.spark.scheduler.Schedulable +import org.apache.spark.status.PoolData +import org.apache.spark.status.api.v1._ import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc - private val listener = parent.progressListener private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val activeStages = listener.activeStages.values.toSeq - val pendingStages = listener.pendingStages.values.toSeq - val completedStages = listener.completedStages.reverse - val numCompletedStages = listener.numCompletedStages - val failedStages = listener.failedStages.reverse - val numFailedStages = listener.numFailedStages - val subPath = "stages" + val allStages = parent.store.stageList(null) - val activeStagesTable = - new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, - killEnabled = parent.killEnabled, isFailedStage = false) - val pendingStagesTable = - new StageTableBase(request, pendingStages, "pending", "pendingStage", parent.basePath, - subPath, parent.progressListener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val completedStagesTable = - new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, - subPath, parent.progressListener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val failedStagesTable = - new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, subPath, - parent.progressListener, parent.isFairScheduler, - killEnabled = false, isFailedStage = true) + val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) + val pendingStages = allStages.filter(_.status == StageStatus.PENDING) + val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) + val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - // For now, pool information is only accessible in live UIs - val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) - val poolTable = new PoolTable(pools, parent) + val numCompletedStages = 
completedStages.size + val numFailedStages = failedStages.size + val subPath = "stages" - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = pendingStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val activeStagesTable = + new StageTableBase(parent.store, request, activeStages, "active", "activeStage", + parent.basePath, subPath, parent.isFairScheduler, parent.killEnabled, false) + val pendingStagesTable = + new StageTableBase(parent.store, request, pendingStages, "pending", "pendingStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) + val completedStagesTable = + new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) + val failedStagesTable = + new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", + parent.basePath, subPath, parent.isFairScheduler, false, true) - val completedStageNumStr = if (numCompletedStages == completedStages.size) { - s"$numCompletedStages" - } else { - s"$numCompletedStages, only showing ${completedStages.size}" - } + // For now, pool information is only accessible in live UIs + val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]).map { pool => + val uiPool = parent.store.asOption(parent.store.pool(pool.name)).getOrElse( + new PoolData(pool.name, Set())) + pool -> uiPool + }.toMap + val poolTable = new PoolTable(pools, parent) + + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = pendingStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + + val completedStageNumStr = if (numCompletedStages == completedStages.size) { + s"$numCompletedStages" + } else { + s"$numCompletedStages, only showing ${completedStages.size}" + } - val summary: NodeSeq = -
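AllStagesPage above follows the same pattern: one stageList(null) snapshot sliced by StageStatus, instead of four listener collections kept in sync. A hedged sketch of the slicing and of the "only showing" label, with a stand-in status type rather than org.apache.spark.status.api.v1.StageStatus:

    object AllStagesSketch {
      sealed trait StageStatus                                   // stand-in for v1.StageStatus
      case object Active   extends StageStatus
      case object Pending  extends StageStatus
      case object Complete extends StageStatus
      case object Failed   extends StageStatus

      case class StageInfo(stageId: Int, status: StageStatus)    // stand-in for v1.StageData

      // One snapshot, sliced into the four tables the page shows.
      def slice(all: Seq[StageInfo]): Map[StageStatus, Seq[StageInfo]] =
        Map(
          Active   -> all.filter(_.status == Active),
          Pending  -> all.filter(_.status == Pending),
          Complete -> all.filter(_.status == Complete),
          Failed   -> all.filter(_.status == Failed).reverse)    // failed stages newest first

      // Header count: note when the retained list is shorter than the true total.
      def completedCountLabel(numCompleted: Int, shown: Int): String =
        if (numCompleted == shown) s"$numCompleted" else s"$numCompleted, only showing $shown"
    }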
    -
      - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } + val summary: NodeSeq = +
      +
        + { + if (shouldShowActiveStages) { +
      • + Active Stages: + {activeStages.size} +
      • } - { - if (shouldShowPendingStages) { -
      • - Pending Stages: - {pendingStages.size} -
      • - } + } + { + if (shouldShowPendingStages) { +
      • + Pending Stages: + {pendingStages.size} +
      • } - { - if (shouldShowCompletedStages) { -
      • - Completed Stages: - {completedStageNumStr} -
      • - } + } + { + if (shouldShowCompletedStages) { +
      • + Completed Stages: + {completedStageNumStr} +
      • } - { - if (shouldShowFailedStages) { -
      • - Failed Stages: - {numFailedStages} -
      • - } + } + { + if (shouldShowFailedStages) { +
      • + Failed Stages: + {numFailedStages} +
      • } -
      -
      - - var content = summary ++ - { - if (sc.isDefined && isFairScheduler) { -

      Fair Scheduler Pools ({pools.size})

      ++ poolTable.toNodeSeq - } else { - Seq.empty[Node] } +
    +
    + + var content = summary ++ + { + if (sc.isDefined && isFairScheduler) { +

    Fair Scheduler Pools ({pools.size})

    ++ poolTable.toNodeSeq + } else { + Seq.empty[Node] } - if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq - } - if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingStages.size})

    ++ - pendingStagesTable.toNodeSeq - } - if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStageNumStr})

    ++ - completedStagesTable.toNodeSeq - } - if (shouldShowFailedStages) { - content ++=

    Failed Stages ({numFailedStages})

    ++ - failedStagesTable.toNodeSeq } - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + if (shouldShowActiveStages) { + content ++=

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq } + if (shouldShowPendingStages) { + content ++=

    Pending Stages ({pendingStages.size})

    ++ + pendingStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

    Completed Stages ({completedStageNumStr})

    ++ + completedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

    Failed Stages ({numFailedStages})

    ++ + failedStagesTable.toNodeSeq + } + UIUtils.headerSparkPage("Stages for All Jobs", content, parent) } } - diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 07a41d195a191..41d42b52430a5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -17,44 +17,19 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable import scala.xml.{Node, Unparsed} import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.StageData import org.apache.spark.ui.{ToolTips, UIUtils} -import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable( - stageId: Int, - stageAttemptId: Int, - parent: StagesTab, - store: AppStatusStore) { - private val listener = parent.progressListener +private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { - def toNodeSeq: Seq[Node] = { - listener.synchronized { - executorTable() - } - } - - /** Special table which merges two header cells. */ - private def executorTable[T](): Seq[Node] = { - val stageData = listener.stageIdToData.get((stageId, stageAttemptId)) - var hasInput = false - var hasOutput = false - var hasShuffleWrite = false - var hasShuffleRead = false - var hasBytesSpilled = false - stageData.foreach { data => - hasInput = data.hasInput - hasOutput = data.hasOutput - hasShuffleRead = data.hasShuffleRead - hasShuffleWrite = data.hasShuffleWrite - hasBytesSpilled = data.hasBytesSpilled - } + import ApiHelper._ + def toNodeSeq: Seq[Node] = { @@ -64,29 +39,29 @@ private[ui] class ExecutorTable( - {if (hasInput) { + {if (hasInput(stage)) { }} - {if (hasOutput) { + {if (hasOutput(stage)) { }} - {if (hasShuffleRead) { + {if (hasShuffleRead(stage)) { }} - {if (hasShuffleWrite) { + {if (hasShuffleWrite(stage)) { }} - {if (hasBytesSpilled) { + {if (hasBytesSpilled(stage)) { }} @@ -97,7 +72,7 @@ private[ui] class ExecutorTable( - {createExecutorTable()} + {createExecutorTable(stage)}
    Executor ID Failed Tasks Killed Tasks Succeeded Tasks Input Size / Records Output Size / Records Shuffle Read Size / Records Shuffle Write Size / Records Shuffle Spill (Memory) Shuffle Spill (Disk)
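The optional columns in the executor table above are switched on by the ApiHelper predicates (hasInput, hasOutput, hasShuffleRead, hasShuffleWrite, hasBytesSpilled) evaluated against stage-level totals, replacing the mutable per-stage flags the listener version computed. A standalone sketch of that idea, with a cut-down record standing in for v1.StageData:

    object ColumnToggleSketch {
      // Stand-in carrying only the totals the predicates look at.
      case class StageTotals(
          inputBytes: Long,
          outputBytes: Long,
          shuffleReadBytes: Long,
          shuffleWriteBytes: Long,
          memoryBytesSpilled: Long,
          diskBytesSpilled: Long)

      def hasInput(s: StageTotals): Boolean = s.inputBytes > 0
      def hasOutput(s: StageTotals): Boolean = s.outputBytes > 0
      def hasShuffleRead(s: StageTotals): Boolean = s.shuffleReadBytes > 0
      def hasShuffleWrite(s: StageTotals): Boolean = s.shuffleWriteBytes > 0
      def hasBytesSpilled(s: StageTotals): Boolean =
        s.memoryBytesSpilled > 0 || s.diskBytesSpilled > 0

      // Columns are emitted only when the predicate holds, so empty metrics never render.
      def columns(s: StageTotals): Seq[String] =
        Seq("Executor ID", "Task Time", "Total Tasks") ++
          (if (hasInput(s)) Seq("Input Size / Records") else Nil) ++
          (if (hasOutput(s)) Seq("Output Size / Records") else Nil) ++
          (if (hasShuffleRead(s)) Seq("Shuffle Read Size / Records") else Nil) ++
          (if (hasShuffleWrite(s)) Seq("Shuffle Write Size / Records") else Nil) ++
          (if (hasBytesSpilled(s)) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil)
    }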
    } - private def createExecutorTable() : Seq[Node] = { - // Make an executor-id -> address map - val executorIdToAddress = mutable.HashMap[String, String]() - listener.blockManagerIds.foreach { blockManagerId => - val address = blockManagerId.hostPort - val executorId = blockManagerId.executorId - executorIdToAddress.put(executorId, address) - } - - listener.stageIdToData.get((stageId, stageAttemptId)) match { - case Some(stageData: StageUIData) => - stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) => - - -
    {k}
    -
    - { - store.executorSummary(k).map(_.executorLogs).getOrElse(Map.empty).map { - case (logName, logUrl) => - } - } -
    - - {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum} - {v.failedTasks} - {v.reasonToNumKilled.values.sum} - {v.succeededTasks} - {if (stageData.hasInput) { - - {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} - - }} - {if (stageData.hasOutput) { - - {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} - - }} - {if (stageData.hasShuffleRead) { - - {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} - - }} - {if (stageData.hasShuffleWrite) { - - {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} - - }} - {if (stageData.hasBytesSpilled) { - - {Utils.bytesToString(v.memoryBytesSpilled)} - - - {Utils.bytesToString(v.diskBytesSpilled)} - - }} - {v.isBlacklisted} - - } - case None => - Seq.empty[Node] + private def createExecutorTable(stage: StageData) : Seq[Node] = { + stage.executorSummary.getOrElse(Map.empty).toSeq.sortBy(_._1).map { case (k, v) => + val executor = store.asOption(store.executorSummary(k)) + + +
    {k}
    +
    + { + executor.map(_.executorLogs).getOrElse(Map.empty).map { + case (logName, logUrl) => + } + } +
    + + {executor.map { e => e.hostPort }.getOrElse("CANNOT FIND ADDRESS")} + {UIUtils.formatDuration(v.taskTime)} + {v.failedTasks + v.succeededTasks + v.killedTasks} + {v.failedTasks} + {v.killedTasks} + {v.succeededTasks} + {if (hasInput(stage)) { + + {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} + + }} + {if (hasOutput(stage)) { + + {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} + + }} + {if (hasShuffleRead(stage)) { + + {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} + + }} + {if (hasShuffleWrite(stage)) { + + {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} + + }} + {if (hasBytesSpilled(stage)) { + + {Utils.bytesToString(v.memoryBytesSpilled)} + + + {Utils.bytesToString(v.diskBytesSpilled)} + + }} + {executor.map(_.isBlacklisted).getOrElse(false)} + } } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 7ed01646f3621..740f12e7d13d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import java.util.{Date, Locale} +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, ListBuffer} @@ -27,11 +27,12 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler._ -import org.apache.spark.status.api.v1.ExecutorSummary -import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1 +import org.apache.spark.ui._ /** Page showing statistics and stage list for a given job */ -private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { +private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("job") { private val STAGES_LEGEND =
    @@ -56,14 +57,15 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { Removed
    .toString.filter(_ != '\n') - private def makeStageEvent(stageInfos: Seq[StageInfo]): Seq[String] = { + private def makeStageEvent(stageInfos: Seq[v1.StageData]): Seq[String] = { stageInfos.map { stage => val stageId = stage.stageId val attemptId = stage.attemptId val name = stage.name - val status = stage.getStatusString - val submissionTime = stage.submissionTime.get - val completionTime = stage.completionTime.getOrElse(System.currentTimeMillis()) + val status = stage.status.toString + val submissionTime = stage.submissionTime.get.getTime() + val completionTime = stage.completionTime.map(_.getTime()) + .getOrElse(System.currentTimeMillis()) // The timeline library treats contents as HTML, so we have to escape them. We need to add // extra layers of escaping in order to embed this in a Javascript string literal. @@ -79,10 +81,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'data-placement="top" data-html="true"' + | 'data-title="${jsEscapedName} (Stage ${stageId}.${attemptId})
    ' + | 'Status: ${status.toUpperCase(Locale.ROOT)}
    ' + - | 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' + + | 'Submitted: ${UIUtils.formatDate(submissionTime)}' + | '${ if (status != "running") { - s"""
    Completed: ${UIUtils.formatDate(new Date(completionTime))}""" + s"""
    Completed: ${UIUtils.formatDate(completionTime)}""" } else { "" } @@ -93,7 +95,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } } - def makeExecutorEvent(executors: Seq[ExecutorSummary]): Seq[String] = { + def makeExecutorEvent(executors: Seq[v1.ExecutorSummary]): Seq[String] = { val events = ListBuffer[String]() executors.foreach { e => val addedEvent = @@ -137,8 +139,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } private def makeTimeline( - stages: Seq[StageInfo], - executors: Seq[ExecutorSummary], + stages: Seq[v1.StageData], + executors: Seq[v1.ExecutorSummary], appStartTime: Long): Seq[Node] = { val stageEventJsonAsStrSeq = makeStageEvent(stages) @@ -182,173 +184,181 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { } def render(request: HttpServletRequest): Seq[Node] = { - val listener = parent.jobProgresslistener + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - listener.synchronized { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterId = UIUtils.stripXSS(request.getParameter("id")) - require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - - val jobId = parameterId.toInt - val jobDataOption = listener.jobIdToData.get(jobId) - if (jobDataOption.isEmpty) { - val content = -
    -

    No information to display for job {jobId}

    -
    - return UIUtils.headerSparkPage( - s"Details for Job $jobId", content, parent) - } - val jobData = jobDataOption.get - val isComplete = jobData.status != JobExecutionStatus.RUNNING - val stages = jobData.stageIds.map { stageId => - // This could be empty if the JobProgressListener hasn't received information about the - // stage or if the stage information has been garbage collected - listener.stageIdToInfo.getOrElse(stageId, - new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown")) + val jobId = parameterId.toInt + val jobData = store.asOption(store.job(jobId)).getOrElse { + val content = +
    +

    No information to display for job {jobId}

    +
    + return UIUtils.headerSparkPage( + s"Details for Job $jobId", content, parent) + } + val isComplete = jobData.status != JobExecutionStatus.RUNNING + val stages = jobData.stageIds.map { stageId => + // This could be empty if the listener hasn't received information about the + // stage or if the stage information has been garbage collected + store.stageData(stageId).lastOption.getOrElse { + new v1.StageData( + v1.StageStatus.PENDING, + stageId, + 0, 0, 0, 0, 0, 0, 0, + 0L, 0L, None, None, None, None, + 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, + "Unknown", + None, + "Unknown", + null, + Nil, + Nil, + None, + None, + Map()) } + } - val activeStages = Buffer[StageInfo]() - val completedStages = Buffer[StageInfo]() - // If the job is completed, then any pending stages are displayed as "skipped": - val pendingOrSkippedStages = Buffer[StageInfo]() - val failedStages = Buffer[StageInfo]() - for (stage <- stages) { - if (stage.submissionTime.isEmpty) { - pendingOrSkippedStages += stage - } else if (stage.completionTime.isDefined) { - if (stage.failureReason.isDefined) { - failedStages += stage - } else { - completedStages += stage - } + val activeStages = Buffer[v1.StageData]() + val completedStages = Buffer[v1.StageData]() + // If the job is completed, then any pending stages are displayed as "skipped": + val pendingOrSkippedStages = Buffer[v1.StageData]() + val failedStages = Buffer[v1.StageData]() + for (stage <- stages) { + if (stage.submissionTime.isEmpty) { + pendingOrSkippedStages += stage + } else if (stage.completionTime.isDefined) { + if (stage.status == v1.StageStatus.FAILED) { + failedStages += stage } else { - activeStages += stage + completedStages += stage } + } else { + activeStages += stage } + } - val basePath = "jobs/job" + val basePath = "jobs/job" - val pendingOrSkippedTableId = - if (isComplete) { - "pending" - } else { - "skipped" - } + val pendingOrSkippedTableId = + if (isComplete) { + "pending" + } else { + "skipped" + } - val activeStagesTable = - new StageTableBase(request, activeStages, "active", "activeStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = parent.killEnabled, isFailedStage = false) - val pendingOrSkippedStagesTable = - new StageTableBase(request, pendingOrSkippedStages, pendingOrSkippedTableId, "pendingStage", - parent.basePath, basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val completedStagesTable = - new StageTableBase(request, completedStages, "completed", "completedStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = false, isFailedStage = false) - val failedStagesTable = - new StageTableBase(request, failedStages, "failed", "failedStage", parent.basePath, - basePath, parent.jobProgresslistener, parent.isFairScheduler, - killEnabled = false, isFailedStage = true) + val activeStagesTable = + new StageTableBase(store, request, activeStages, "active", "activeStage", parent.basePath, + basePath, parent.isFairScheduler, + killEnabled = parent.killEnabled, isFailedStage = false) + val pendingOrSkippedStagesTable = + new StageTableBase(store, request, pendingOrSkippedStages, pendingOrSkippedTableId, + "pendingStage", parent.basePath, basePath, parent.isFairScheduler, + killEnabled = false, isFailedStage = false) + val completedStagesTable = + new StageTableBase(store, request, completedStages, "completed", "completedStage", + parent.basePath, basePath, parent.isFairScheduler, + 
killEnabled = false, isFailedStage = false) + val failedStagesTable = + new StageTableBase(store, request, failedStages, "failed", "failedStage", parent.basePath, + basePath, parent.isFairScheduler, + killEnabled = false, isFailedStage = true) - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty - val summary: NodeSeq = -
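The classification loop above sorts a job's stages into four tables using only two timestamps and the status: no submission time means pending (shown as "skipped" once the job has finished), a completion time means finished (failed when the status says FAILED), and anything else is still active. The same decision as a small standalone function; StageLite is an illustrative stand-in for v1.StageData:

    object StageClassifySketch {
      sealed trait Bucket
      case object ActiveStage           extends Bucket
      case object CompletedStage        extends Bucket
      case object PendingOrSkippedStage extends Bucket
      case object FailedStage           extends Bucket

      // Stand-in with just the fields the classification reads.
      case class StageLite(
          submissionTime: Option[Long],
          completionTime: Option[Long],
          failed: Boolean)

      def classify(stage: StageLite): Bucket =
        if (stage.submissionTime.isEmpty) PendingOrSkippedStage
        else if (stage.completionTime.isDefined) {
          if (stage.failed) FailedStage else CompletedStage
        } else ActiveStage
    }

Grouping with stages.groupBy(classify) yields the same four buckets the page accumulates in mutable Buffers.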
    -
      -
    • - Status: - {jobData.status} -
    • - { - if (jobData.jobGroup.isDefined) { -
    • - Job Group: - {jobData.jobGroup.get} -
    • - } - } - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } - } - { - if (shouldShowPendingStages) { -
    • - - Pending Stages: - {pendingOrSkippedStages.size} -
    • - } + val summary: NodeSeq = +
      +
        +
      • + Status: + {jobData.status} +
      • + { + if (jobData.jobGroup.isDefined) { +
      • + Job Group: + {jobData.jobGroup.get} +
      • } - { - if (shouldShowCompletedStages) { -
      • - Completed Stages: - {completedStages.size} -
      • - } + } + { + if (shouldShowActiveStages) { +
      • + Active Stages: + {activeStages.size} +
      • } - { - if (shouldShowSkippedStages) { + } + { + if (shouldShowPendingStages) {
      • - Skipped Stages: - {pendingOrSkippedStages.size} + + Pending Stages: + {pendingOrSkippedStages.size}
      • } + } + { + if (shouldShowCompletedStages) { +
      • + Completed Stages: + {completedStages.size} +
      • } - { - if (shouldShowFailedStages) { -
      • - Failed Stages: - {failedStages.size} -
      • - } + } + { + if (shouldShowSkippedStages) { +
      • + Skipped Stages: + {pendingOrSkippedStages.size} +
      • + } + } + { + if (shouldShowFailedStages) { +
      • + Failed Stages: + {failedStages.size} +
      • } -
      -
      + } +
    +
    - var content = summary - val appStartTime = listener.startTime - val operationGraphListener = parent.operationGraphListener + var content = summary + val appStartTime = store.applicationInfo().attempts.head.startTime.getTime() - content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, - parent.parent.store.executorList(false), appStartTime) + content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, + store.executorList(false), appStartTime) - content ++= UIUtils.showDagVizForJob( - jobId, operationGraphListener.getOperationGraphForJob(jobId)) + content ++= UIUtils.showDagVizForJob( + jobId, store.operationGraphForJob(jobId)) - if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq - } - if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq - } - if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStages.size})

    ++ - completedStagesTable.toNodeSeq - } - if (shouldShowSkippedStages) { - content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq - } - if (shouldShowFailedStages) { - content ++=

    Failed Stages ({failedStages.size})

    ++ - failedStagesTable.toNodeSeq - } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + if (shouldShowActiveStages) { + content ++=

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

    Completed Stages ({completedStages.size})

    ++ + completedStagesTable.toNodeSeq + } + if (shouldShowSkippedStages) { + content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

    Failed Stages ({failedStages.size})

    ++ + failedStagesTable.toNodeSeq } + UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 81ffe04aca49a..99eab1b2a27d8 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -19,35 +19,45 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest +import scala.collection.JavaConverters._ + +import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} +import org.apache.spark.status.AppStatusStore +import org.apache.spark.ui._ /** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobsTab(val parent: SparkUI) extends SparkUITab(parent, "jobs") { +private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) + extends SparkUITab(parent, "jobs") { + val sc = parent.sc val killEnabled = parent.killEnabled - val jobProgresslistener = parent.jobProgressListener - val operationGraphListener = parent.operationGraphListener - def isFairScheduler: Boolean = - jobProgresslistener.schedulingMode == Some(SchedulingMode.FAIR) + def isFairScheduler: Boolean = { + store.environmentInfo().sparkProperties.toMap + .get("spark.scheduler.mode") + .map { mode => mode == SchedulingMode.FAIR } + .getOrElse(false) + } def getSparkUser: String = parent.getSparkUser - attachPage(new AllJobsPage(this)) - attachPage(new JobPage(this)) + attachPage(new AllJobsPage(this, store)) + attachPage(new JobPage(this, store)) def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { // stripXSS is called first to remove suspicious characters used in XSS attacks val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) jobId.foreach { id => - if (jobProgresslistener.activeJobs.contains(id)) { - sc.foreach(_.cancelJob(id)) - // Do a quick pause here to give Spark time to kill the job so it shows up as - // killed after the refresh. Note that this will block the serving thread so the - // time should be limited in duration. - Thread.sleep(100) + store.asOption(store.job(id)).foreach { job => + if (job.status == JobExecutionStatus.RUNNING) { + sc.foreach(_.cancelJob(id)) + // Do a quick pause here to give Spark time to kill the job so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. 
+ Thread.sleep(100) + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 4b8c7b203771d..98fbd7aceaa11 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -21,46 +21,39 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.status.PoolData +import org.apache.spark.status.api.v1._ import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing specific pool details */ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { - private val sc = parent.sc - private val listener = parent.progressListener def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => - UIUtils.decodeURLParameter(poolname) - }.getOrElse { - throw new IllegalArgumentException(s"Missing poolname parameter") - } - - val poolToActiveStages = listener.poolToActiveStages - val activeStages = poolToActiveStages.get(poolName) match { - case Some(s) => s.values.toSeq - case None => Seq.empty[StageInfo] - } - val shouldShowActiveStages = activeStages.nonEmpty - val activeStagesTable = - new StageTableBase(request, activeStages, "", "activeStage", parent.basePath, "stages/pool", - parent.progressListener, parent.isFairScheduler, parent.killEnabled, - isFailedStage = false) - - // For now, pool information is only accessible in live UIs - val pools = sc.map(_.getPoolForName(poolName).getOrElse { - throw new IllegalArgumentException(s"Unknown poolname: $poolName") - }).toSeq - val poolTable = new PoolTable(pools, parent) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => + UIUtils.decodeURLParameter(poolname) + }.getOrElse { + throw new IllegalArgumentException(s"Missing poolname parameter") + } - var content =
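JobsTab.handleKillRequest above now confirms through the store that the job is still RUNNING before asking the SparkContext to cancel it, instead of consulting the listener's activeJobs map. A hedged sketch of that guard; the lookup and cancel parameters stand in for store.job(id) and sc.cancelJob(id):

    import scala.util.Try

    object KillGuardSketch {
      sealed trait JobStatus
      case object Running  extends JobStatus
      case object Finished extends JobStatus

      // lookup may throw when the job id is unknown, mirroring a store miss.
      def killIfRunning(
          jobId: Int,
          lookup: Int => JobStatus,
          cancel: Int => Unit): Boolean = {
        Try(lookup(jobId)).toOption match {
          case Some(Running) =>
            cancel(jobId)
            true               // a short pause after this gives the kill time to be reflected
          case _ =>
            false              // unknown or already-finished jobs are left alone
        }
      }
    }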

    Summary

    ++ poolTable.toNodeSeq - if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq - } + // For now, pool information is only accessible in live UIs + val pool = parent.sc.flatMap(_.getPoolForName(poolName)).getOrElse { + throw new IllegalArgumentException(s"Unknown pool: $poolName") + } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + val uiPool = parent.store.asOption(parent.store.pool(poolName)).getOrElse( + new PoolData(poolName, Set())) + val activeStages = uiPool.stageIds.toSeq.map(parent.store.lastStageAttempt(_)) + val activeStagesTable = + new StageTableBase(parent.store, request, activeStages, "", "activeStage", parent.basePath, + "stages/pool", parent.isFairScheduler, parent.killEnabled, false) + + val poolTable = new PoolTable(Map(pool -> uiPool), parent) + var content =

    Summary

    ++ poolTable.toNodeSeq + if (activeStages.nonEmpty) { + content ++=

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq } + + UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index ea02968733cac..5dfce858dec07 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -19,25 +19,16 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder -import scala.collection.mutable.HashMap import scala.xml.Node -import org.apache.spark.scheduler.{Schedulable, StageInfo} +import org.apache.spark.scheduler.Schedulable +import org.apache.spark.status.PoolData import org.apache.spark.ui.UIUtils /** Table showing list of pools */ -private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { - private val listener = parent.progressListener +private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { def toNodeSeq: Seq[Node] = { - listener.synchronized { - poolTable(poolRow, pools) - } - } - - private def poolTable( - makeRow: (Schedulable, HashMap[String, HashMap[Int, StageInfo]]) => Seq[Node], - rows: Seq[Schedulable]): Seq[Node] = { @@ -48,29 +39,24 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { - {rows.map(r => makeRow(r, listener.poolToActiveStages))} + {pools.map { case (s, p) => poolRow(s, p) }}
    Pool Name SchedulingMode
    } - private def poolRow( - p: Schedulable, - poolToActiveStages: HashMap[String, HashMap[Int, StageInfo]]): Seq[Node] = { - val activeStages = poolToActiveStages.get(p.name) match { - case Some(stages) => stages.size - case None => 0 - } + private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) {p.name} - {p.minShare} - {p.weight} + {s.minShare} + {s.weight} {activeStages} - {p.runningTasks} - {p.schedulingMode} + {s.runningTasks} + {s.schedulingMode} } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 3151b8d554658..5f93f2ffb412f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -21,26 +21,25 @@ import java.net.URLEncoder import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.collection.mutable.HashSet +import scala.collection.mutable.{HashMap, HashSet} import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.TaskLocality import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { + import ApiHelper._ import StagePage._ - private val progressListener = parent.progressListener - private val operationGraphListener = parent.operationGraphListener - private val TIMELINE_LEGEND = {
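PoolPage and AllStagesPage above both wrap their pool lookups in store.asOption(...).getOrElse(new PoolData(name, Set())), so a pool the store has not recorded yet still renders as an empty entry. A standalone sketch of that convention, treating a lookup that may throw as optional; PoolRecord stands in for org.apache.spark.status.PoolData:

    object PoolLookupSketch {
      case class PoolRecord(name: String, stageIds: Set[Int])   // stand-in for PoolData

      // A missing key becomes None instead of an exception, matching how asOption is used here.
      def asOption[T](fn: => T): Option[T] =
        try Some(fn) catch { case _: NoSuchElementException => None }

      def poolOrEmpty(lookup: String => PoolRecord, name: String): PoolRecord =
        asOption(lookup(name)).getOrElse(PoolRecord(name, Set()))
    }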
    @@ -69,555 +68,521 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private def getLocalitySummaryString(stageData: StageUIData): String = { - val localities = stageData.taskData.values.map(_.taskInfo.taskLocality) + private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = { + val localities = taskList.map(_.taskLocality) val localityCounts = localities.groupBy(identity).mapValues(_.size) + val names = Map( + TaskLocality.PROCESS_LOCAL.toString() -> "Process local", + TaskLocality.NODE_LOCAL.toString() -> "Node local", + TaskLocality.RACK_LOCAL.toString() -> "Rack local", + TaskLocality.ANY.toString() -> "Any") val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => - val localityName = locality match { - case TaskLocality.PROCESS_LOCAL => "Process local" - case TaskLocality.NODE_LOCAL => "Node local" - case TaskLocality.RACK_LOCAL => "Rack local" - case TaskLocality.ANY => "Any" - } - s"$localityName: $count" + s"${names(locality)}: $count" } localityNamesAndCounts.sorted.mkString("; ") } def render(request: HttpServletRequest): Seq[Node] = { - progressListener.synchronized { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterId = UIUtils.stripXSS(request.getParameter("id")) - require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - - val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) - require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") - - val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) - val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) - val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) - val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) - val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) - - val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) - val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => - UIUtils.decodeURLParameter(sortColumn) - }.getOrElse("Index") - val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) - val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) - val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) - - val stageId = parameterId.toInt - val stageAttemptId = parameterAttempt.toInt - val stageDataOption = progressListener.stageIdToData.get((stageId, stageAttemptId)) - - val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" - if (stageDataOption.isEmpty) { + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) + require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + + val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) + val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) + val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) + val 
parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) + + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) + val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse("Index") + val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) + val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) + val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) + + val stageId = parameterId.toInt + val stageAttemptId = parameterAttempt.toInt + + val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" + val stageData = parent.store + .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true)) + .getOrElse { val content =

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    return UIUtils.headerSparkPage(stageHeader, content, parent) - - } - if (stageDataOption.get.taskData.isEmpty) { - val content = -
    -

    Summary Metrics

    No tasks have started yet -

    Tasks

    No tasks have started yet -
    - return UIUtils.headerSparkPage(stageHeader, content, parent) } - val stageData = stageDataOption.get - val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = stageData.numCompleteTasks - val totalTasks = stageData.numActiveTasks + - stageData.numCompleteTasks + stageData.numFailedTasks - val totalTasksNumStr = if (totalTasks == tasks.size) { - s"$totalTasks" - } else { - s"$totalTasks, showing ${tasks.size}" - } + val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq + if (tasks.isEmpty) { + val content = +
    +

    Summary Metrics

    No tasks have started yet +

    Tasks

    No tasks have started yet +
    + return UIUtils.headerSparkPage(stageHeader, content, parent) + } - val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables - val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } - val hasAccumulators = externalAccumulables.nonEmpty + val numCompleted = stageData.numCompleteTasks + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks + val totalTasksNumStr = if (totalTasks == tasks.size) { + s"$totalTasks" + } else { + s"$totalTasks, showing ${tasks.size}" + } - val summary = -
    -
      + val externalAccumulables = stageData.accumulatorUpdates + val hasAccumulators = externalAccumulables.size > 0 + + val summary = +
      +
        +
      • + Total Time Across All Tasks: + {UIUtils.formatDuration(stageData.executorRunTime)} +
      • +
      • + Locality Level Summary: + {getLocalitySummaryString(stageData, tasks)} +
      • + {if (hasInput(stageData)) {
      • - Total Time Across All Tasks: - {UIUtils.formatDuration(stageData.executorRunTime)} + Input Size / Records: + {s"${Utils.bytesToString(stageData.inputBytes)} / ${stageData.inputRecords}"}
      • + }} + {if (hasOutput(stageData)) {
      • - Locality Level Summary: - {getLocalitySummaryString(stageData)} + Output: + {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"}
      • - {if (stageData.hasInput) { -
      • - Input Size / Records: - {s"${Utils.bytesToString(stageData.inputBytes)} / ${stageData.inputRecords}"} -
      • - }} - {if (stageData.hasOutput) { -
      • - Output: - {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"} -
      • - }} - {if (stageData.hasShuffleRead) { -
      • - Shuffle Read: - {s"${Utils.bytesToString(stageData.shuffleReadTotalBytes)} / " + - s"${stageData.shuffleReadRecords}"} -
      • - }} - {if (stageData.hasShuffleWrite) { -
      • - Shuffle Write: - {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + - s"${stageData.shuffleWriteRecords}"} -
      • - }} - {if (stageData.hasBytesSpilled) { -
      • - Shuffle Spill (Memory): - {Utils.bytesToString(stageData.memoryBytesSpilled)} -
      • -
      • - Shuffle Spill (Disk): - {Utils.bytesToString(stageData.diskBytesSpilled)} -
      • - }} -
      -
      + }} + {if (hasShuffleRead(stageData)) { +
    • + Shuffle Read: + {s"${Utils.bytesToString(stageData.shuffleReadBytes)} / " + + s"${stageData.shuffleReadRecords}"} +
    • + }} + {if (hasShuffleWrite(stageData)) { +
    • + Shuffle Write: + {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + + s"${stageData.shuffleWriteRecords}"} +
    • + }} + {if (hasBytesSpilled(stageData)) { +
    • + Shuffle Spill (Memory): + {Utils.bytesToString(stageData.memoryBytesSpilled)} +
    • +
    • + Shuffle Spill (Disk): + {Utils.bytesToString(stageData.diskBytesSpilled)} +
    • + }} +
    +
    - val showAdditionalMetrics = -
    - - - Show Additional Metrics - - +
    - val outputSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.outputMetrics.bytesWritten.toDouble - } + val stageGraph = parent.store.asOption(parent.store.operationGraphForStage(stageId)) + val dagViz = UIUtils.showDagVizForStage(stageId, stageGraph) - val outputRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.outputMetrics.recordsWritten.toDouble - } + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") + def accumulableRow(acc: AccumulableInfo): Seq[Node] = { + {acc.name}{acc.value} + } + val accumulableTable = UIUtils.listingTable( + accumulableHeaders, + accumulableRow, + externalAccumulables.toSeq) + + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + if (taskPageSize <= taskPrevPageSize) { + taskPage + } else { + 1 + } + } + val currentTime = System.currentTimeMillis() + val (taskTable, taskTableHTML) = try { + val _taskTable = new TaskPagedTable( + parent.conf, + UIUtils.prependBaseUri(parent.basePath) + + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + tasks, + hasAccumulators, + hasInput(stageData), + hasOutput(stageData), + hasShuffleRead(stageData), + hasShuffleWrite(stageData), + hasBytesSpilled(stageData), + currentTime, + pageSize = taskPageSize, + sortColumn = taskSortColumn, + desc = taskSortDesc, + store = parent.store + ) + (_taskTable, _taskTable.table(page)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + val errorMessage = +
    +

    Error while rendering stage table:

    +
    +              {Utils.exceptionString(e)}
    +            
    +
    + (null, errorMessage) + } - val outputQuantiles = Output Size / Records +: - getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + val jsForScrollingDownToTaskTable = + - val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.fetchWaitTime.toDouble - } - val shuffleReadBlockedQuantiles = - - - Shuffle Read Blocked Time - - +: - getFormattedTimeQuantiles(shuffleReadBlockedTimes) + val taskIdsInPage = if (taskTable == null) Set.empty[Long] + else taskTable.dataSource.slicedTaskIds - val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.totalBytesRead.toDouble - } - val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.recordsRead.toDouble - } - val shuffleReadTotalQuantiles = - - - Shuffle Read Size / Records - - +: - getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) + // Excludes tasks which failed and have incomplete metrics + val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined) - val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.remoteBytesRead.toDouble + val summaryTable: Option[Seq[Node]] = + if (validTasks.size == 0) { + None + } else { + def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = { + Distribution(data).get.getQuantiles() + } + def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { + getDistributionQuantiles(times).map { millis => + {UIUtils.formatDuration(millis.toLong)} } - val shuffleReadRemoteQuantiles = - - - Shuffle Remote Reads - - +: - getFormattedSizeQuantiles(shuffleReadRemoteSizes) + } + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { + getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + } - val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleWriteMetrics.bytesWritten.toDouble - } + val deserializationTimes = validTasks.map { task => + task.taskMetrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + + + Task Deserialization Time + + +: getFormattedTimeQuantiles(deserializationTimes) + + val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble) + val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes) + + val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble) + val gcQuantiles = + + GC Time + + +: getFormattedTimeQuantiles(gcTimes) + + val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble) + val serializationQuantiles = + + + Result Serialization Time + + +: getFormattedTimeQuantiles(serializationTimes) + + val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble) + val gettingResultQuantiles = + + + Getting Result Time + + +: + getFormattedTimeQuantiles(gettingResultTimes) + + val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble) + val peakExecutionMemoryQuantiles = { + + + Peak Execution Memory + + +: getFormattedSizeQuantiles(peakExecutionMemory) + } - val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleWriteMetrics.recordsWritten.toDouble - } + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the 
task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelays = validTasks.map { task => + getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble + } + val schedulerDelayTitle = Scheduler Delay + val schedulerDelayQuantiles = schedulerDelayTitle +: + getFormattedTimeQuantiles(schedulerDelays) + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) + : Seq[Elem] = { + val recordDist = getDistributionQuantiles(records).iterator + getDistributionQuantiles(data).map(d => + {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} + ) + } - val shuffleWriteQuantiles = Shuffle Write Size / Records +: - getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble) + val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble) + val inputQuantiles = Input Size / Records +: + getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) - val memoryBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.memoryBytesSpilled.toDouble - } - val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: - getFormattedSizeQuantiles(memoryBytesSpilledSizes) + val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble) + val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble) + val outputQuantiles = Output Size / Records +: + getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) - val diskBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.diskBytesSpilled.toDouble - } - val diskBytesSpilledQuantiles = Shuffle spill (disk) +: - getFormattedSizeQuantiles(diskBytesSpilledSizes) - - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (stageData.hasInput) {inputQuantiles} else Nil, - if (stageData.hasOutput) {outputQuantiles} else Nil, - if (stageData.hasShuffleRead) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (stageData.hasShuffleWrite) {shuffleWriteQuantiles} else Nil, - if (stageData.hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, - if (stageData.hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). 
- Some(UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false)) + val shuffleReadBlockedTimes = validTasks.map { task => + task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble + } + val shuffleReadBlockedQuantiles = + + + Shuffle Read Blocked Time + + +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + + val shuffleReadTotalSizes = validTasks.map { task => + totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble + } + val shuffleReadTotalRecords = validTasks.map { task => + task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble + } + val shuffleReadTotalQuantiles = + + + Shuffle Read Size / Records + + +: + getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) + + val shuffleReadRemoteSizes = validTasks.map { task => + task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble + } + val shuffleReadRemoteQuantiles = + + + Shuffle Remote Reads + + +: + getFormattedSizeQuantiles(shuffleReadRemoteSizes) + + val shuffleWriteSizes = validTasks.map { task => + task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble } - val executorTable = new ExecutorTable(stageId, stageAttemptId, parent, store) + val shuffleWriteRecords = validTasks.map { task => + task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble + } - val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

<h4>Accumulators</h4>
    ++ accumulableTable } else Seq.empty + val shuffleWriteQuantiles = Shuffle Write Size / Records +: + getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + + val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble) + val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: + getFormattedSizeQuantiles(memoryBytesSpilledSizes) + + val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble) + val diskBytesSpilledQuantiles = Shuffle spill (disk) +: + getFormattedSizeQuantiles(diskBytesSpilledSizes) + + val listings: Seq[Seq[Node]] = Seq( + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + + {peakExecutionMemoryQuantiles} + , + if (hasInput(stageData)) {inputQuantiles} else Nil, + if (hasOutput(stageData)) {outputQuantiles} else Nil, + if (hasShuffleRead(stageData)) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", + "Median", "75th percentile", "Max") + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + Some(UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false)) + } - val aggMetrics = - -

-          Aggregated Metrics by Executor
-        {executorTable.toNodeSeq}
    + val executorTable = new ExecutorTable(stageData, parent.store) + + val maybeAccumulableTable: Seq[Node] = + if (hasAccumulators) {

<h4>Accumulators</h4>
    ++ accumulableTable } else Seq() + + val aggMetrics = + +

+          Aggregated Metrics by Executor
+        {executorTable.toNodeSeq}
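// The hasInput / hasOutput / hasShuffleRead / hasShuffleWrite / hasBytesSpilled guards used
// when assembling `listings` above are plain size checks on the v1.StageData counters; the
// real definitions live in the ApiHelper object added near the end of this file. A condensed,
// self-contained sketch of the same checks (taking raw byte counts instead of StageData):
object MetricGuardsSketch {
  def hasInput(inputBytes: Long): Boolean = inputBytes > 0
  def hasOutput(outputBytes: Long): Boolean = outputBytes > 0
  def hasShuffleRead(shuffleReadBytes: Long): Boolean = shuffleReadBytes > 0
  def hasShuffleWrite(shuffleWriteBytes: Long): Boolean = shuffleWriteBytes > 0
  def hasBytesSpilled(diskBytesSpilled: Long, memoryBytesSpilled: Long): Boolean =
    diskBytesSpilled > 0 || memoryBytesSpilled > 0
}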
    - val content = - summary ++ - dagViz ++ - showAdditionalMetrics ++ - makeTimeline( - // Only show the tasks in the table - stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)), - currentTime) ++ -

-      Summary Metrics for {numCompleted} Completed Tasks ++
-      {summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
-      aggMetrics ++
-      maybeAccumulableTable ++
-      Tasks ({totalTasksNumStr})

    ++ - taskTableHTML ++ jsForScrollingDownToTaskTable - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) - } + val content = + summary ++ + dagViz ++ + showAdditionalMetrics ++ + makeTimeline( + // Only show the tasks in the table + tasks.filter { t => taskIdsInPage.contains(t.taskId) }, + currentTime) ++ +

+      Summary Metrics for {numCompleted} Completed Tasks ++
+      {summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
+      aggMetrics ++
+      maybeAccumulableTable ++
+      Tasks ({totalTasksNumStr})

    ++ + taskTableHTML ++ jsForScrollingDownToTaskTable + UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } - def makeTimeline(tasks: Seq[TaskUIData], currentTime: Long): Seq[Node] = { + def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { val executorsSet = new HashSet[(String, String)] var minLaunchTime = Long.MaxValue var maxFinishTime = Long.MinValue val executorsArrayStr = - tasks.sortBy(-_.taskInfo.launchTime).take(MAX_TIMELINE_TASKS).map { taskUIData => - val taskInfo = taskUIData.taskInfo + tasks.sortBy(-_.launchTime.getTime()).take(MAX_TIMELINE_TASKS).map { taskInfo => val executorId = taskInfo.executorId val host = taskInfo.host executorsSet += ((executorId, host)) - val launchTime = taskInfo.launchTime - val finishTime = if (!taskInfo.running) taskInfo.finishTime else currentTime + val launchTime = taskInfo.launchTime.getTime() + val finishTime = taskInfo.duration.map(taskInfo.launchTime.getTime() + _) + .getOrElse(currentTime) val totalExecutionTime = finishTime - launchTime minLaunchTime = launchTime.min(minLaunchTime) maxFinishTime = finishTime.max(maxFinishTime) def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 - val metricsOpt = taskUIData.metrics + val metricsOpt = taskInfo.taskMetrics val shuffleReadTime = metricsOpt.map(_.shuffleReadMetrics.fetchWaitTime).getOrElse(0L) val shuffleReadTimeProportion = toProportion(shuffleReadTime) @@ -629,14 +594,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime) + val gettingResultTime = getGettingResultTime(taskInfo, currentTime) val gettingResultTimeProportion = toProportion(gettingResultTime) val schedulerDelay = metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) val schedulerDelayProportion = toProportion(schedulerDelay) val executorOverhead = serializationTime + deserializationTime - val executorRunTime = if (taskInfo.running) { + val executorRunTime = if (taskInfo.duration.isDefined) { totalExecutionTime - executorOverhead - gettingResultTime } else { metricsOpt.map(_.executorRunTime).getOrElse( @@ -663,7 +628,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attemptNumber + val attempt = taskInfo.attempt val svgTag = if (totalExecutionTime == 0) { @@ -705,7 +670,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We |Status: ${taskInfo.status}
    |Launch Time: ${UIUtils.formatDate(new Date(launchTime))} |${ - if (!taskInfo.running) { + if (!taskInfo.duration.isDefined) { s"""
    Finish Time: ${UIUtils.formatDate(new Date(finishTime))}""" } else { "" @@ -770,34 +735,40 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } private[ui] object StagePage { - private[ui] def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { - if (info.gettingResult) { - if (info.finished) { - info.finishTime - info.gettingResultTime - } else { - // The task is still fetching the result. - currentTime - info.gettingResultTime - } - } else { - 0L + private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = { + info.resultFetchStart match { + case Some(start) => + info.duration match { + case Some(duration) => + info.launchTime.getTime() + duration - start.getTime() + + case _ => + currentTime - start.getTime() + } + + case _ => + 0L } } private[ui] def getSchedulerDelay( - info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { - if (info.finished) { - val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = metrics.executorDeserializeTime + - metrics.resultSerializationTime - math.max( - 0, - totalExecutionTime - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - } else { - // The task is still running and the metrics like executorRunTime are not available. - 0L + info: TaskData, + metrics: TaskMetrics, + currentTime: Long): Long = { + info.duration match { + case Some(duration) => + val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime + math.max( + 0, + duration - metrics.executorRunTime - executorOverhead - + getGettingResultTime(info, currentTime)) + + case _ => + // The task is still running and the metrics like executorRunTime are not available. + 0L } } + } private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) @@ -826,7 +797,7 @@ private[ui] case class TaskTableRowBytesSpilledData( /** * Contains all data that needs for sorting and generating HTML. Using this one rather than - * TaskUIData to avoid creating duplicate contents during sorting the data. + * TaskData to avoid creating duplicate contents during sorting the data. */ private[ui] class TaskTableRowData( val index: Int, @@ -856,14 +827,13 @@ private[ui] class TaskTableRowData( val logs: Map[String, String]) private[ui] class TaskDataSource( - tasks: Seq[TaskUIData], + tasks: Seq[TaskData], hasAccumulators: Boolean, hasInput: Boolean, hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean, - lastUpdateTime: Option[Long], currentTime: Long, pageSize: Int, sortColumn: String, @@ -871,7 +841,10 @@ private[ui] class TaskDataSource( store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { import StagePage._ - // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table + // Keep an internal cache of executor log maps so that long task lists render faster. 
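// Minimal sketch of the memoization pattern behind the executorIdToLogs field declared just
// below: resolve each executor's log links from the status store at most once, then reuse the
// cached map for every task row from that executor. `fetchLogs` is a hypothetical stand-in for
// the store.executorSummary(id) lookup used by the real executorLogs() helper.
import scala.collection.mutable
class ExecutorLogCacheSketch(fetchLogs: String => Map[String, String]) {
  private val cache = new mutable.HashMap[String, Map[String, String]]()
  def logsFor(executorId: String): Map[String, String] =
    cache.getOrElseUpdate(executorId, fetchLogs(executorId)) // computed once per executor id
}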
+ private val executorIdToLogs = new HashMap[String, Map[String, String]]() + + // Convert TaskData to TaskTableRowData which contains the final contents to show in the table // so that we can avoid creating duplicate contents during sorting the data private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) @@ -887,26 +860,19 @@ private[ui] class TaskDataSource( def slicedTaskIds: Set[Long] = _slicedTaskIds - private def taskRow(taskData: TaskUIData): TaskTableRowData = { - val info = taskData.taskInfo - val metrics = taskData.metrics - val duration = taskData.taskDuration(lastUpdateTime).getOrElse(1L) - val formatDuration = - taskData.taskDuration(lastUpdateTime).map(d => UIUtils.formatDuration(d)).getOrElse("") + private def taskRow(info: TaskData): TaskTableRowData = { + val metrics = info.taskMetrics + val duration = info.duration.getOrElse(1L) + val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = getGettingResultTime(info, currentTime) - val externalAccumulableReadable = info.accumulables - .filterNot(_.internal) - .flatMap { a => - (a.name, a.update) match { - case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update")) - case _ => None - } - } + val externalAccumulableReadable = info.accumulatorUpdates.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}") + } val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) val maybeInput = metrics.map(_.inputMetrics) @@ -928,7 +894,7 @@ private[ui] class TaskDataSource( val shuffleReadBlockedTimeReadable = maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead) val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") @@ -1011,17 +977,16 @@ private[ui] class TaskDataSource( None } - val logs = store.executorSummary(info.executorId).map(_.executorLogs).getOrElse(Map.empty) new TaskTableRowData( info.index, info.taskId, - info.attemptNumber, + info.attempt, info.speculative, info.status, info.taskLocality.toString, info.executorId, info.host, - info.launchTime, + info.launchTime.getTime(), duration, formatDuration, schedulerDelay, @@ -1036,8 +1001,13 @@ private[ui] class TaskDataSource( shuffleRead, shuffleWrite, bytesSpilled, - taskData.errorMessage.getOrElse(""), - logs) + info.errorMessage.getOrElse(""), + executorLogs(info.executorId)) + } + + private def executorLogs(id: String): Map[String, String] = { + executorIdToLogs.getOrElseUpdate(id, + store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) } /** @@ -1148,14 +1118,13 @@ private[ui] class TaskDataSource( private[ui] class TaskPagedTable( conf: SparkConf, basePath: String, - data: Seq[TaskUIData], + data: Seq[TaskData], hasAccumulators: Boolean, hasInput: Boolean, hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, 
hasBytesSpilled: Boolean, - lastUpdateTime: Option[Long], currentTime: Long, pageSize: Int, sortColumn: String, @@ -1181,7 +1150,6 @@ private[ui] class TaskPagedTable( hasShuffleRead, hasShuffleWrite, hasBytesSpilled, - lastUpdateTime, currentTime, pageSize, sortColumn, @@ -1363,3 +1331,23 @@ private[ui] class TaskPagedTable( {errorSummary}{details} } } + +private object ApiHelper { + + def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 + + def hasOutput(stageData: StageData): Boolean = stageData.outputBytes > 0 + + def hasShuffleRead(stageData: StageData): Boolean = stageData.shuffleReadBytes > 0 + + def hasShuffleWrite(stageData: StageData): Boolean = stageData.shuffleWriteBytes > 0 + + def hasBytesSpilled(stageData: StageData): Boolean = { + stageData.diskBytesSpilled > 0 || stageData.memoryBytesSpilled > 0 + } + + def totalBytesRead(metrics: ShuffleReadMetrics): Long = { + metrics.localBytesRead + metrics.remoteBytesRead + } + +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index f0a12a28de069..18a4926f2f6c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -26,19 +26,19 @@ import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1 import org.apache.spark.ui._ -import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils private[ui] class StageTableBase( + store: AppStatusStore, request: HttpServletRequest, - stages: Seq[StageInfo], + stages: Seq[v1.StageData], tableHeaderID: String, stageTag: String, basePath: String, subPath: String, - progressListener: JobProgressListener, isFairScheduler: Boolean, killEnabled: Boolean, isFailedStage: Boolean) { @@ -79,12 +79,12 @@ private[ui] class StageTableBase( val toNodeSeq = try { new StagePagedTable( + store, stages, tableHeaderID, stageTag, basePath, subPath, - progressListener, isFairScheduler, killEnabled, currentTime, @@ -106,13 +106,13 @@ private[ui] class StageTableBase( } private[ui] class StageTableRowData( - val stageInfo: StageInfo, - val stageData: Option[StageUIData], + val stage: v1.StageData, + val option: Option[v1.StageData], val stageId: Int, val attemptId: Int, val schedulingPool: String, val descriptionOption: Option[String], - val submissionTime: Long, + val submissionTime: Date, val formattedSubmissionTime: String, val duration: Long, val formattedDuration: String, @@ -126,19 +126,20 @@ private[ui] class StageTableRowData( val shuffleWriteWithUnit: String) private[ui] class MissingStageTableRowData( - stageInfo: StageInfo, + stageInfo: v1.StageData, stageId: Int, attemptId: Int) extends StageTableRowData( - stageInfo, None, stageId, attemptId, "", None, 0, "", -1, "", 0, "", 0, "", 0, "", 0, "") + stageInfo, None, stageId, attemptId, "", None, new Date(0), "", -1, "", 0, "", 0, "", 0, "", 0, + "") /** Page showing list of all ongoing and recently finished stages */ private[ui] class StagePagedTable( - stages: Seq[StageInfo], + store: AppStatusStore, + stages: Seq[v1.StageData], tableHeaderId: String, stageTag: String, basePath: String, subPath: String, - listener: JobProgressListener, isFairScheduler: Boolean, killEnabled: Boolean, currentTime: Long, @@ -164,8 +165,8 @@ private[ui] class StagePagedTable( parameterOtherTable.mkString("&") override val 
dataSource = new StageDataSource( + store, stages, - listener, currentTime, pageSize, sortColumn, @@ -274,10 +275,10 @@ private[ui] class StagePagedTable( } private def rowContent(data: StageTableRowData): Seq[Node] = { - data.stageData match { + data.option match { case None => missingStageRow(data.stageId) case Some(stageData) => - val info = data.stageInfo + val info = data.stage {if (data.attemptId > 0) { {data.stageId} (retry {data.attemptId}) @@ -301,8 +302,8 @@ private[ui] class StagePagedTable( {data.formattedDuration} {UIUtils.makeProgressBar(started = stageData.numActiveTasks, - completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, - skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)} + completed = stageData.numCompleteTasks, failed = stageData.numFailedTasks, + skipped = 0, reasonToNumKilled = stageData.killedTasksSummary, total = info.numTasks)} {data.inputReadWithUnit} {data.outputWriteWithUnit} @@ -318,7 +319,7 @@ private[ui] class StagePagedTable( } } - private def failureReasonHtml(s: StageInfo): Seq[Node] = { + private def failureReasonHtml(s: v1.StageData): Seq[Node] = { val failureReason = s.failureReason.getOrElse("") val isMultiline = failureReason.indexOf('\n') >= 0 // Display the first line by default @@ -344,7 +345,7 @@ private[ui] class StagePagedTable( {failureReasonSummary}{details} } - private def makeDescription(s: StageInfo, descriptionOption: Option[String]): Seq[Node] = { + private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { val basePathUri = UIUtils.prependBaseUri(basePath) val killLink = if (killEnabled) { @@ -368,8 +369,8 @@ private[ui] class StagePagedTable( val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} - val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) - val details = if (s.details.nonEmpty) { + val cachedRddInfos = store.rddList().filter { rdd => s.rddIds.contains(rdd.id) } + val details = if (s.details != null && s.details.nonEmpty) { +details @@ -404,14 +405,14 @@ private[ui] class StagePagedTable( } private[ui] class StageDataSource( - stages: Seq[StageInfo], - listener: JobProgressListener, + store: AppStatusStore, + stages: Seq[v1.StageData], currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[StageTableRowData](pageSize) { - // Convert StageInfo to StageTableRowData which contains the final contents to show in the table - // so that we can avoid creating duplicate contents during sorting the data + // Convert v1.StageData to StageTableRowData which contains the final contents to show in the + // table so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) private var _slicedStageIds: Set[Int] = _ @@ -424,57 +425,46 @@ private[ui] class StageDataSource( r } - private def stageRow(s: StageInfo): StageTableRowData = { - val stageDataOption = listener.stageIdToData.get((s.stageId, s.attemptId)) + private def stageRow(stageData: v1.StageData): StageTableRowData = { + val description = stageData.description.getOrElse("") - if (stageDataOption.isEmpty) { - return new MissingStageTableRowData(s, s.stageId, s.attemptId) - } - val stageData = stageDataOption.get - - val description = stageData.description - - val formattedSubmissionTime = s.submissionTime match { - case Some(t) => UIUtils.formatDate(new Date(t)) + val formattedSubmissionTime = 
stageData.submissionTime match { + case Some(t) => UIUtils.formatDate(t) case None => "Unknown" } - val finishTime = s.completionTime.getOrElse(currentTime) + val finishTime = stageData.completionTime.map(_.getTime()).getOrElse(currentTime) // The submission time for a stage is misleading because it counts the time // the stage waits to be launched. (SPARK-10930) - val taskLaunchTimes = - stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) - val duration: Option[Long] = - if (taskLaunchTimes.nonEmpty) { - val startTime = taskLaunchTimes.min - if (finishTime > startTime) { - Some(finishTime - startTime) - } else { - Some(currentTime - startTime) - } + val duration = stageData.firstTaskLaunchedTime.map { date => + val time = date.getTime() + if (finishTime > time) { + finishTime - time } else { None + currentTime - time } + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" val outputWrite = stageData.outputBytes val outputWriteWithUnit = if (outputWrite > 0) Utils.bytesToString(outputWrite) else "" - val shuffleRead = stageData.shuffleReadTotalBytes + val shuffleRead = stageData.shuffleReadBytes val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" new StageTableRowData( - s, - stageDataOption, - s.stageId, - s.attemptId, + stageData, + Some(stageData), + stageData.stageId, + stageData.attemptId, stageData.schedulingPool, - description, - s.submissionTime.getOrElse(0), + stageData.description, + stageData.submissionTime.getOrElse(new Date(0)), formattedSubmissionTime, duration.getOrElse(-1), formattedDuration, @@ -496,7 +486,7 @@ private[ui] class StageDataSource( val ordering: Ordering[StageTableRowData] = sortColumn match { case "Stage Id" => Ordering.by(_.stageId) case "Pool Name" => Ordering.by(_.schedulingPool) - case "Description" => Ordering.by(x => (x.descriptionOption, x.stageInfo.name)) + case "Description" => Ordering.by(x => (x.descriptionOption, x.stage.name)) case "Submitted" => Ordering.by(_.submissionTime) case "Duration" => Ordering.by(_.duration) case "Input" => Ordering.by(_.inputRead) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 65446f967ad76..be05a963f0e68 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -21,36 +21,42 @@ import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.StageStatus import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all stages in the given SparkContext. 
*/ -private[ui] class StagesTab(val parent: SparkUI, store: AppStatusStore) +private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) extends SparkUITab(parent, "stages") { val sc = parent.sc val conf = parent.conf val killEnabled = parent.killEnabled - val progressListener = parent.jobProgressListener - val operationGraphListener = parent.operationGraphListener - val lastUpdateTime = parent.lastUpdateTime attachPage(new AllStagesPage(this)) attachPage(new StagePage(this, store)) attachPage(new PoolPage(this)) - def isFairScheduler: Boolean = progressListener.schedulingMode == Some(SchedulingMode.FAIR) + def isFairScheduler: Boolean = { + store.environmentInfo().sparkProperties.toMap + .get("spark.scheduler.mode") + .map { mode => mode == SchedulingMode.FAIR } + .getOrElse(false) + } def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { // stripXSS is called first to remove suspicious characters used in XSS attacks val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) stageId.foreach { id => - if (progressListener.activeStages.contains(id)) { - sc.foreach(_.cancelStage(id, "killed via the Web UI")) - // Do a quick pause here to give Spark time to kill the stage so it shows up as - // killed after the refresh. Note that this will block the serving thread so the - // time should be limited in duration. - Thread.sleep(100) + store.asOption(store.lastStageAttempt(id)).foreach { stage => + val status = stage.status + if (status == StageStatus.ACTIVE || status == StageStatus.PENDING) { + sc.foreach(_.cancelStage(id, "killed via the Web UI")) + // Do a quick pause here to give Spark time to kill the stage so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index bb763248cd7e0..827a8637b9bd2 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -35,20 +35,20 @@ import org.apache.spark.storage.StorageLevel * nodes and children clusters. Additionally, a graph may also have edges that enter or exit * the graph from nodes that belong to adjacent graphs. */ -private[ui] case class RDDOperationGraph( +private[spark] case class RDDOperationGraph( edges: Seq[RDDOperationEdge], outgoingEdges: Seq[RDDOperationEdge], incomingEdges: Seq[RDDOperationEdge], rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) +private[spark] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) /** * A directed edge connecting two nodes in an RDDOperationGraph. * This represents an RDD dependency. */ -private[ui] case class RDDOperationEdge(fromId: Int, toId: Int) +private[spark] case class RDDOperationEdge(fromId: Int, toId: Int) /** * A cluster that groups nodes together in an RDDOperationGraph. @@ -56,7 +56,7 @@ private[ui] case class RDDOperationEdge(fromId: Int, toId: Int) * This represents any grouping of RDDs, including operation scopes (e.g. textFile, flatMap), * stages, jobs, or any higher level construct. 
A cluster may be nested inside of other clusters. */ -private[ui] class RDDOperationCluster(val id: String, private var _name: String) { +private[spark] class RDDOperationCluster(val id: String, private var _name: String) { private val _childNodes = new ListBuffer[RDDOperationNode] private val _childClusters = new ListBuffer[RDDOperationCluster] @@ -92,7 +92,7 @@ private[ui] class RDDOperationCluster(val id: String, private var _name: String) } } -private[ui] object RDDOperationGraph extends Logging { +private[spark] object RDDOperationGraph extends Logging { val STAGE_CLUSTER_PREFIX = "stage_" diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala deleted file mode 100644 index 37a12a8646938..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.scope - -import scala.collection.mutable - -import org.apache.spark.SparkConf -import org.apache.spark.scheduler._ -import org.apache.spark.ui.SparkUI - -/** - * A SparkListener that constructs a DAG of RDD operations. - */ -private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListener { - - // Note: the fate of jobs and stages are tied. This means when we clean up a job, - // we always clean up all of its stages. Similarly, when we clean up a stage, we - // always clean up its job (and, transitively, other stages in the same job). - private[ui] val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]] - private[ui] val jobIdToSkippedStageIds = new mutable.HashMap[Int, Seq[Int]] - private[ui] val stageIdToJobId = new mutable.HashMap[Int, Int] - private[ui] val stageIdToGraph = new mutable.HashMap[Int, RDDOperationGraph] - private[ui] val completedStageIds = new mutable.HashSet[Int] - - // Keep track of the order in which these are inserted so we can remove old ones - private[ui] val jobIds = new mutable.ArrayBuffer[Int] - private[ui] val stageIds = new mutable.ArrayBuffer[Int] - - // How many root nodes to retain in DAG Graph - private[ui] val retainedNodes = - conf.getInt("spark.ui.dagGraph.retainedRootRDDs", Int.MaxValue) - - // How many jobs or stages to retain graph metadata for - private val retainedJobs = - conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS) - private val retainedStages = - conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) - - /** - * Return the graph metadata for all stages in the given job. - * An empty list is returned if one or more of its stages has been cleaned up. 
- */ - def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = synchronized { - val skippedStageIds = jobIdToSkippedStageIds.getOrElse(jobId, Seq.empty) - val graphs = jobIdToStageIds.getOrElse(jobId, Seq.empty) - .flatMap { sid => stageIdToGraph.get(sid) } - // Mark any skipped stages as such - graphs.foreach { g => - val stageId = g.rootCluster.id.replaceAll(RDDOperationGraph.STAGE_CLUSTER_PREFIX, "").toInt - if (skippedStageIds.contains(stageId) && !g.rootCluster.name.contains("skipped")) { - g.rootCluster.setName(g.rootCluster.name + " (skipped)") - } - } - graphs - } - - /** Return the graph metadata for the given stage, or None if no such information exists. */ - def getOperationGraphForStage(stageId: Int): Option[RDDOperationGraph] = synchronized { - stageIdToGraph.get(stageId) - } - - /** On job start, construct a RDDOperationGraph for each stage in the job for display later. */ - override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { - val jobId = jobStart.jobId - val stageInfos = jobStart.stageInfos - - jobIds += jobId - jobIdToStageIds(jobId) = jobStart.stageInfos.map(_.stageId).sorted - - stageInfos.foreach { stageInfo => - val stageId = stageInfo.stageId - stageIds += stageId - stageIdToJobId(stageId) = jobId - stageIdToGraph(stageId) = RDDOperationGraph.makeOperationGraph(stageInfo, retainedNodes) - trimStagesIfNecessary() - } - - trimJobsIfNecessary() - } - - /** Keep track of stages that have completed. */ - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { - val stageId = stageCompleted.stageInfo.stageId - if (stageIdToJobId.contains(stageId)) { - // Note: Only do this if the stage has not already been cleaned up - // Otherwise, we may never clean this stage from `completedStageIds` - completedStageIds += stageCompleted.stageInfo.stageId - } - } - - /** On job end, find all stages in this job that are skipped and mark them as such. */ - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { - val jobId = jobEnd.jobId - jobIdToStageIds.get(jobId).foreach { stageIds => - val skippedStageIds = stageIds.filter { sid => !completedStageIds.contains(sid) } - // Note: Only do this if the job has not already been cleaned up - // Otherwise, we may never clean this job from `jobIdToSkippedStageIds` - jobIdToSkippedStageIds(jobId) = skippedStageIds - } - } - - /** Clean metadata for old stages if we have exceeded the number to retain. */ - private def trimStagesIfNecessary(): Unit = { - if (stageIds.size >= retainedStages) { - val toRemove = math.max(retainedStages / 10, 1) - stageIds.take(toRemove).foreach { id => cleanStage(id) } - stageIds.trimStart(toRemove) - } - } - - /** Clean metadata for old jobs if we have exceeded the number to retain. */ - private def trimJobsIfNecessary(): Unit = { - if (jobIds.size >= retainedJobs) { - val toRemove = math.max(retainedJobs / 10, 1) - jobIds.take(toRemove).foreach { id => cleanJob(id) } - jobIds.trimStart(toRemove) - } - } - - /** Clean metadata for the given stage, its job, and all other stages that belong to the job. */ - private[ui] def cleanStage(stageId: Int): Unit = { - completedStageIds.remove(stageId) - stageIdToGraph.remove(stageId) - stageIdToJobId.remove(stageId).foreach { jobId => cleanJob(jobId) } - } - - /** Clean metadata for the given job and all stages that belong to it. 
*/ - private[ui] def cleanJob(jobId: Int): Unit = { - jobIdToSkippedStageIds.remove(jobId) - jobIdToStageIds.remove(jobId).foreach { stageIds => - stageIds.foreach { stageId => cleanStage(stageId) } - } - } - -} diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 25c4fff77e0ad..37b7d7269059f 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 3, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 162, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", @@ -23,14 +26,19 @@ "name" : "count at :17", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line19.$read$$iwC$$iwC$$iwC$$iwC.(:17)\n$line19.$read$$iwC$$iwC$$iwC.(:22)\n$line19.$read$$iwC$$iwC.(:24)\n$line19.$read$$iwC.(:26)\n$line19.$read.(:28)\n$line19.$read$.(:32)\n$line19.$read$.()\n$line19.$eval$.(:7)\n$line19.$eval$.()\n$line19.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 6, 5 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -49,14 +57,19 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : 
"COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 4338, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", @@ -75,5 +88,7 @@ "name" : "count at :15", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index b86ba1e65de12..2fd55666fa018 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -2,14 +2,18 @@ "status" : "FAILED", "stageId" : 2, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 7, "numFailedTasks" : 1, + "numKilledTasks" : 0, + "numCompletedIndices" : 7, "executorRunTime" : 278, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:06.296GMT", "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", "completionTime" : "2015-02-03T16:43:06.347GMT", + "failureReason" : "Job aborted due to stage failure: Task 3 in stage 2.0 failed 1 times, most recent failure: Lost task 3.0 in stage 2.0 (TID 19, localhost): java.lang.RuntimeException: got a 3, failing\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:18)\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:17)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:328)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1311)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:56)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:196)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)\n\tat java.lang.Thread.run(Thread.java:745)\n\nDriver stacktrace:", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -23,5 +27,7 @@ "name" : 
"count at :20", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 3, 2 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json index c108fa61a4318..2f275c7bfe2f4 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json @@ -8,8 +8,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json index c108fa61a4318..2f275c7bfe2f4 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json @@ -8,8 +8,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json index 3d7407004d262..71bf8706307c8 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json @@ -8,10 +8,13 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } }, { "jobId" : 1, "name" : "count at :20", @@ -22,10 +25,13 @@ "numCompletedTasks" : 15, "numSkippedTasks" : 0, "numFailedTasks" : 1, + "numKilledTasks" : 0, + 
"numCompletedIndices" : 15, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 1 + "numFailedStages" : 1, + "killedTasksSummary" : { } }, { "jobId" : 0, "name" : "count at :15", @@ -36,8 +42,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json index 10c7e1c0b36fd..1eae5f3d5beb3 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json @@ -8,8 +8,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 6fb40f6f1713b..31093a661663b 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -23,14 +26,15 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], "accumulatorUpdates" : [ ], "tasks" : { - "8" : { - "taskId" : 8, - "index" : 0, + "10" : { + "taskId" : 10, + "index" : 2, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.829GMT", - "duration" : 435, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 456, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -38,15 +42,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 435, + 
"executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -66,17 +71,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 94000, + "writeTime" : 76000, "recordsWritten" : 0 } } }, - "9" : { - "taskId" : 9, - "index" : 1, + "14" : { + "taskId" : 14, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 436, + "launchTime" : "2015-02-03T16:43:05.832GMT", + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -84,15 +89,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 436, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -112,17 +118,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 98000, + "writeTime" : 88000, "recordsWritten" : 0 } } }, - "10" : { - "taskId" : 10, - "index" : 2, + "9" : { + "taskId" : 9, + "index" : 1, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -130,15 +136,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 436, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -158,17 +165,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 76000, + "writeTime" : 98000, "recordsWritten" : 0 } } }, - "11" : { - "taskId" : 11, - "index" : 3, + "13" : { + "taskId" : 13, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 452, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -182,9 +189,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -203,8 +211,8 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1647, - "writeTime" : 83000, + "bytesWritten" : 1648, + "writeTime" : 73000, "recordsWritten" : 0 } } @@ -214,7 +222,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -231,6 +239,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -255,12 +264,12 @@ } } }, - "13" : { - "taskId" : 13, - "index" : 5, + "11" : { + "taskId" : 
11, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -274,9 +283,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -295,18 +305,18 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 73000, + "bytesWritten" : 1647, + "writeTime" : 83000, "recordsWritten" : 0 } } }, - "14" : { - "taskId" : 14, - "index" : 6, + "8" : { + "taskId" : 8, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.832GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.829GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -314,15 +324,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 435, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -342,7 +353,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 88000, + "writeTime" : 94000, "recordsWritten" : 0 } } @@ -352,7 +363,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.833GMT", - "duration" : 435, + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -369,6 +380,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -399,12 +411,18 @@ "taskTime" : 3624, "failedTasks" : 0, "succeededTasks" : 8, + "killedTasks" : 0, "inputBytes" : 28000128, + "inputRecords" : 0, "outputBytes" : 0, + "outputRecords" : 0, "shuffleRead" : 0, + "shuffleReadRecords" : 0, "shuffleWrite" : 13180, + "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0 } - } + }, + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index f5a89a2107646..601d70695b17c 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -23,14 +26,15 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native 
Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], "accumulatorUpdates" : [ ], "tasks" : { - "8" : { - "taskId" : 8, - "index" : 0, + "10" : { + "taskId" : 10, + "index" : 2, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.829GMT", - "duration" : 435, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 456, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -38,15 +42,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 435, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -66,17 +71,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 94000, + "writeTime" : 76000, "recordsWritten" : 0 } } }, - "9" : { - "taskId" : 9, - "index" : 1, + "14" : { + "taskId" : 14, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 436, + "launchTime" : "2015-02-03T16:43:05.832GMT", + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -84,15 +89,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 436, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -112,17 +118,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 98000, + "writeTime" : 88000, "recordsWritten" : 0 } } }, - "10" : { - "taskId" : 10, - "index" : 2, + "9" : { + "taskId" : 9, + "index" : 1, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -130,15 +136,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 436, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -158,17 +165,17 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 76000, + "writeTime" : 98000, "recordsWritten" : 0 } } }, - "11" : { - "taskId" : 11, - "index" : 3, + "13" : { + "taskId" : 
13, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 452, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -182,9 +189,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -203,8 +211,8 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1647, - "writeTime" : 83000, + "bytesWritten" : 1648, + "writeTime" : 73000, "recordsWritten" : 0 } } @@ -214,7 +222,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -231,6 +239,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -255,12 +264,12 @@ } } }, - "13" : { - "taskId" : 13, - "index" : 5, + "11" : { + "taskId" : 11, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -274,9 +283,10 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -295,18 +305,18 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 73000, + "bytesWritten" : 1647, + "writeTime" : 83000, "recordsWritten" : 0 } } }, - "14" : { - "taskId" : 14, - "index" : 6, + "8" : { + "taskId" : 8, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.832GMT", - "duration" : 434, + "launchTime" : "2015-02-03T16:43:05.829GMT", + "duration" : 454, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -314,15 +324,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 435, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -342,7 +353,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 88000, + "writeTime" : 94000, "recordsWritten" : 0 } } @@ -352,7 +363,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.833GMT", - "duration" : 435, + "duration" : 450, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -369,6 +380,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 3500016, "recordsRead" : 0 @@ -399,12 +411,18 @@ "taskTime" : 3624, "failedTasks" : 0, "succeededTasks" : 8, + "killedTasks" : 0, "inputBytes" : 28000128, + "inputRecords" : 0, "outputBytes" : 0, + "outputRecords" : 0, "shuffleRead" : 0, + "shuffleReadRecords" : 0, "shuffleWrite" : 13180, + "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 
0, "diskBytesSpilled" : 0 } - } + }, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 6509df1508b30..1e6fb40d60284 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 3, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 162, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:07.191GMT", @@ -23,14 +26,51 @@ "name" : "count at :17", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line19.$read$$iwC$$iwC$$iwC$$iwC.(:17)\n$line19.$read$$iwC$$iwC$$iwC.(:22)\n$line19.$read$$iwC$$iwC.(:24)\n$line19.$read$$iwC.(:26)\n$line19.$read.(:28)\n$line19.$read$.(:32)\n$line19.$read$.()\n$line19.$eval$.(:7)\n$line19.$eval$.()\n$line19.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 6, 5 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } +}, { + "status" : "FAILED", + "stageId" : 2, + "attemptId" : 0, + "numTasks" : 8, + "numActiveTasks" : 0, + "numCompleteTasks" : 7, + "numFailedTasks" : 1, + "numKilledTasks" : 0, + "numCompletedIndices" : 7, + "executorRunTime" : 278, + "executorCpuTime" : 0, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", + "failureReason" : "Job aborted due to stage failure: Task 3 in stage 2.0 failed 1 times, most recent failure: Lost task 3.0 in stage 2.0 (TID 19, localhost): java.lang.RuntimeException: got a 3, failing\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:18)\n\tat $line11.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(:17)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:328)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1311)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:910)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$4.apply(SparkContext.scala:1314)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:61)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:56)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:196)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)\n\tat java.lang.Thread.run(Thread.java:745)\n\nDriver 
stacktrace:", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "count at :20", + "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", + "schedulingPool" : "default", + "rddIds" : [ 3, 2 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : "COMPLETE", "stageId" : 1, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 3476, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:05.829GMT", @@ -49,14 +89,19 @@ "name" : "map at :14", "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:271)\n$line10.$read$$iwC$$iwC$$iwC$$iwC.(:14)\n$line10.$read$$iwC$$iwC$$iwC.(:19)\n$line10.$read$$iwC$$iwC.(:21)\n$line10.$read$$iwC.(:23)\n$line10.$read.(:25)\n$line10.$read$.(:29)\n$line10.$read$.()\n$line10.$eval$.(:7)\n$line10.$eval$.()\n$line10.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } }, { "status" : "COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 4338, "executorCpuTime" : 0, "submissionTime" : "2015-02-03T16:43:04.228GMT", @@ -75,31 +120,7 @@ "name" : "count at :15", "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native 
Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", - "accumulatorUpdates" : [ ] -}, { - "status" : "FAILED", - "stageId" : 2, - "attemptId" : 0, - "numActiveTasks" : 0, - "numCompleteTasks" : 7, - "numFailedTasks" : 1, - "executorRunTime" : 278, - "executorCpuTime" : 0, - "submissionTime" : "2015-02-03T16:43:06.296GMT", - "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", - "completionTime" : "2015-02-03T16:43:06.347GMT", - "inputBytes" : 0, - "inputRecords" : 0, - "outputBytes" : 0, - "outputRecords" : 0, - "shuffleReadBytes" : 0, - "shuffleReadRecords" : 0, - "shuffleWriteBytes" : 0, - "shuffleWriteRecords" : 0, - "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0, - "name" : "count at :20", - "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", - "schedulingPool" : "default", - "accumulatorUpdates" : [ ] + "rddIds" : [ 0 ], + "accumulatorUpdates" : [ ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 8496863a93469..e6284ccf9b73d 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 120, "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", @@ -23,9 +26,11 @@ "name" : "foreach at :15", "details" : "org.apache.spark.rdd.RDD.foreach(RDD.scala:765)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native 
Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:483)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 0 ], "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", "value" : "5050" - } ] + } ], + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index 9b401b414f8d4..a15ee23523365 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -3,7 +3,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.494GMT", - "duration" : 349, + "duration" : 435, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 49294, "recordsRead" : 10000 @@ -48,7 +49,7 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.502GMT", - "duration" : 350, + "duration" : 421, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -65,6 +66,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -93,7 +95,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", - "duration" : 348, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -110,6 +112,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -138,7 +141,7 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -155,6 +158,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -183,7 +187,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -200,6 +204,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -228,7 +233,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 350, + "duration" : 414, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -245,6 +250,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -273,7 +279,7 @@ "index" : 6, "attempt" : 0, 
"launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 351, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -290,6 +296,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -318,7 +325,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.506GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", - "duration" : 80, + "duration" : 88, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 9, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.915GMT", - "duration" : 84, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60489, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 10, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.916GMT", - "duration" : 73, + "duration" : 99, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 11, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.918GMT", - "duration" : 75, + "duration" : 89, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -543,7 +555,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -560,6 +572,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -588,7 +601,7 @@ "index" : 13, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "duration" : 138, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -633,7 +647,7 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -650,6 +664,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -678,7 +693,7 @@ "index" : 15, "attempt" : 0, "launchTime" : 
"2015-05-06T13:03:06.928GMT", - "duration" : 76, + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -695,6 +710,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -723,7 +739,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -740,6 +756,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -768,7 +785,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -858,7 +877,7 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 2ebee66a6d7c2..f9182b1658334 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -3,7 +3,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.515GMT", - "duration" : 15, + "duration" : 61, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -25,6 +25,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -53,7 +54,7 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.521GMT", - "duration" : 15, + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -75,6 +76,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -103,7 +105,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -125,6 +127,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ 
-153,7 +156,7 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -175,6 +178,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -203,7 +207,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -225,6 +229,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -253,7 +258,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -275,6 +280,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -303,7 +309,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -325,6 +331,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -353,7 +360,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.524GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -375,6 +382,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 965a31a4104c3..76dd2f710b90f 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -3,7 +3,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.515GMT", - "duration" : 15, + "duration" : 61, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -25,6 +25,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -53,7 +54,7 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.521GMT", - "duration" : 15, + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -75,6 +76,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -103,7 +105,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", - "duration" : 15, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -125,6 +127,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -153,7 +156,7 
@@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", - "duration" : 15, + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -175,6 +178,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -203,7 +207,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -225,6 +229,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -253,7 +258,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.523GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -275,6 +280,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -303,7 +309,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.523GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -325,6 +331,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -353,7 +360,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.524GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -375,6 +382,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 31132e156937c..6bdc10465d89e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -3,7 +3,7 @@ "index" : 10, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.916GMT", - "duration" : 73, + "duration" : 99, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -48,7 +49,7 @@ "index" : 11, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.918GMT", - "duration" : 75, + "duration" : 89, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -65,6 +66,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -93,7 +95,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -110,6 +112,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -138,7 +141,7 @@ "index" : 13, 
"attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "duration" : 138, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -155,6 +158,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -183,7 +187,7 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -200,6 +204,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -228,7 +233,7 @@ "index" : 15, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.928GMT", - "duration" : 76, + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -245,6 +250,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -273,7 +279,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -290,6 +296,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -318,7 +325,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 20, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.014GMT", - "duration" : 83, + "duration" : 90, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", - "duration" : 88, + "duration" : 96, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -543,7 +555,7 @@ "index" : 22, "attempt" : 0, 
"launchTime" : "2015-05-06T13:03:07.018GMT", - "duration" : 93, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -560,6 +572,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -588,7 +601,7 @@ "index" : 23, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.031GMT", - "duration" : 65, + "duration" : 84, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -633,7 +647,7 @@ "index" : 24, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.098GMT", - "duration" : 43, + "duration" : 52, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -650,6 +664,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -678,7 +693,7 @@ "index" : 25, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.103GMT", - "duration" : 49, + "duration" : 61, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -695,6 +710,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -723,7 +739,7 @@ "index" : 26, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.105GMT", - "duration" : 38, + "duration" : 52, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -740,6 +756,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -768,7 +785,7 @@ "index" : 27, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.110GMT", - "duration" : 32, + "duration" : 41, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 28, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.113GMT", - "duration" : 29, + "duration" : 49, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -858,7 +877,7 @@ "index" : 29, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.114GMT", - "duration" : 39, + "duration" : 52, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -903,7 +923,7 @@ "index" : 30, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.118GMT", - "duration" : 34, + "duration" : 62, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -920,6 +940,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -948,7 +969,7 @@ "index" : 31, "attempt" : 0, "launchTime" : 
"2015-05-06T13:03:07.127GMT", - "duration" : 24, + "duration" : 74, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -965,6 +986,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -993,7 +1015,7 @@ "index" : 32, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.148GMT", - "duration" : 17, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1010,6 +1032,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1038,7 +1061,7 @@ "index" : 33, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.149GMT", - "duration" : 43, + "duration" : 58, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1055,6 +1078,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1083,7 +1107,7 @@ "index" : 34, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.156GMT", - "duration" : 27, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1100,6 +1124,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1128,7 +1153,7 @@ "index" : 35, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.161GMT", - "duration" : 35, + "duration" : 50, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1145,6 +1170,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1173,7 +1199,7 @@ "index" : 36, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.164GMT", - "duration" : 29, + "duration" : 40, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1190,6 +1216,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1218,7 +1245,7 @@ "index" : 37, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.165GMT", - "duration" : 32, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1235,6 +1262,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1263,7 +1291,7 @@ "index" : 38, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.166GMT", - "duration" : 31, + "duration" : 47, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1280,6 +1308,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1308,7 +1337,7 @@ "index" : 39, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.180GMT", - "duration" : 17, + "duration" : 32, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1325,6 +1354,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1353,7 +1383,7 @@ "index" : 40, "attempt" : 0, 
"launchTime" : "2015-05-06T13:03:07.197GMT", - "duration" : 14, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1370,6 +1400,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1398,7 +1429,7 @@ "index" : 41, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.200GMT", - "duration" : 16, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1415,6 +1446,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1443,7 +1475,7 @@ "index" : 42, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.203GMT", - "duration" : 17, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1460,6 +1492,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1488,7 +1521,7 @@ "index" : 43, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.204GMT", - "duration" : 16, + "duration" : 39, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1505,6 +1538,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1533,7 +1567,7 @@ "index" : 44, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.205GMT", - "duration" : 18, + "duration" : 37, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1550,6 +1584,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1578,7 +1613,7 @@ "index" : 45, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.206GMT", - "duration" : 19, + "duration" : 37, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1595,6 +1630,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1623,7 +1659,7 @@ "index" : 46, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.210GMT", - "duration" : 31, + "duration" : 43, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1640,6 +1676,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1668,7 +1705,7 @@ "index" : 47, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.212GMT", - "duration" : 18, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1685,6 +1722,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1713,7 +1751,7 @@ "index" : 48, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.220GMT", - "duration" : 24, + "duration" : 30, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1730,6 +1768,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1758,7 +1797,7 @@ "index" : 
49, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.223GMT", - "duration" : 23, + "duration" : 34, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1775,6 +1814,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1803,7 +1843,7 @@ "index" : 50, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.240GMT", - "duration" : 18, + "duration" : 26, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1820,6 +1860,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1848,7 +1889,7 @@ "index" : 51, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.242GMT", - "duration" : 17, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1865,6 +1906,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1893,7 +1935,7 @@ "index" : 52, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.243GMT", - "duration" : 18, + "duration" : 28, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1910,6 +1952,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1938,7 +1981,7 @@ "index" : 53, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.244GMT", - "duration" : 18, + "duration" : 29, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -1955,6 +1998,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -1983,7 +2027,7 @@ "index" : 54, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.244GMT", - "duration" : 18, + "duration" : 59, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2000,6 +2044,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2028,7 +2073,7 @@ "index" : 55, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.246GMT", - "duration" : 21, + "duration" : 30, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2045,6 +2090,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2073,7 +2119,7 @@ "index" : 56, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.249GMT", - "duration" : 20, + "duration" : 31, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2090,6 +2136,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2118,7 +2165,7 @@ "index" : 57, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.257GMT", - "duration" : 16, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2135,6 +2182,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2163,7 
+2211,7 @@ "index" : 58, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.263GMT", - "duration" : 16, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2180,6 +2228,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -2208,7 +2257,7 @@ "index" : 59, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.265GMT", - "duration" : 17, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -2225,6 +2274,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 6af1cfbeb8f7e..bc1cd49909d31 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -3,7 +3,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 351, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -44,11 +45,11 @@ } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 414, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -56,15 +57,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -84,16 +86,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 421, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -101,15 +103,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 31, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -129,16 +132,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 
0, - "launchTime" : "2015-05-06T13:03:06.494GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.506GMT", + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -146,17 +149,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -174,16 +178,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -197,9 +201,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -219,16 +224,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 3, + "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -242,9 +247,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -264,16 +270,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.494GMT", + "duration" : 435, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -281,17 +287,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -309,7 +316,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 3842811, "recordsWritten" : 10 } } @@ -318,7 +325,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", - "duration" : 348, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { 
"bytesRead" : 60488, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", - "duration" : 93, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", - "duration" : 88, + "duration" : 96, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -539,11 +551,11 @@ } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -551,17 +563,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -579,7 +592,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } @@ -588,7 +601,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -629,11 +643,11 @@ } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:06.915GMT", + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -641,17 +655,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 9, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, 
- "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -669,16 +684,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95788, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:07.014GMT", + "duration" : 90, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -686,15 +701,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -714,16 +730,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -731,15 +747,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -759,7 +776,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } @@ -768,7 +785,7 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", - "duration" : 80, + "duration" : 88, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -854,11 +873,11 @@ } } }, { - "taskId" : 13, - "index" : 13, + "taskId" : 15, + "index" : 15, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "launchTime" : "2015-05-06T13:03:06.928GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -866,7 +885,7 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, "executorCpuTime" : 0, @@ -875,6 
+894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -894,7 +914,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95004, + "writeTime" : 602780, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 6af1cfbeb8f7e..bc1cd49909d31 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -3,7 +3,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 351, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -44,11 +45,11 @@ } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 414, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -56,15 +57,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -84,16 +86,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", - "duration" : 350, + "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 421, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -101,15 +103,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 31, "executorDeserializeCpuTime" : 0, "executorRunTime" : 350, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -129,16 +132,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.494GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.506GMT", + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -146,17 +149,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, 
"executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -174,16 +178,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -197,9 +201,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -219,16 +224,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 3, + "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", - "duration" : 349, + "duration" : 423, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -242,9 +247,10 @@ "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -264,16 +270,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", - "duration" : 349, + "launchTime" : "2015-05-06T13:03:06.494GMT", + "duration" : 435, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -281,17 +287,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorDeserializeCpuTime" : 0, "executorRunTime" : 349, "executorCpuTime" : 0, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -309,7 +316,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 3842811, "recordsWritten" : 10 } } @@ -318,7 +325,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", - "duration" : 348, + "duration" : 419, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", - "duration" : 93, + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + 
"peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", - "duration" : 92, + "duration" : 105, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", - "duration" : 91, + "duration" : 123, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", - "duration" : 88, + "duration" : 96, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -539,11 +551,11 @@ } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -551,17 +563,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -579,7 +592,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } @@ -588,7 +601,7 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", - "duration" : 84, + "duration" : 98, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -629,11 +643,11 @@ } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", - "duration" : 84, + "launchTime" : "2015-05-06T13:03:06.915GMT", + "duration" : 101, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -641,17 +655,18 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 9, "executorDeserializeCpuTime" : 0, "executorRunTime" : 84, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -669,16 +684,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 
95788, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:07.014GMT", + "duration" : 90, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -686,15 +701,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -714,16 +730,16 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", - "duration" : 83, + "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 94, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -731,15 +747,16 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorDeserializeCpuTime" : 0, "executorRunTime" : 83, "executorCpuTime" : 0, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -759,7 +776,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } @@ -768,7 +785,7 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", - "duration" : 80, + "duration" : 88, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 60488, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", - "duration" : 77, + "duration" : 93, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -854,11 +873,11 @@ } } }, { - "taskId" : 13, - "index" : 13, + "taskId" : 15, + "index" : 15, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.924GMT", - "duration" : 76, + "launchTime" : "2015-05-06T13:03:06.928GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -866,7 +885,7 @@ "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 3, "executorDeserializeCpuTime" : 0, "executorRunTime" : 76, "executorCpuTime" : 0, @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -894,7 +914,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95004, + "writeTime" : 602780, "recordsWritten" : 10 } } diff --git 
a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index c26daf4b8d7bd..09857cb401acd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -3,7 +3,7 @@ "index" : 40, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.197GMT", - "duration" : 14, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -20,6 +20,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -48,7 +49,7 @@ "index" : 41, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.200GMT", - "duration" : 16, + "duration" : 24, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -65,6 +66,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -93,7 +95,7 @@ "index" : 43, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.204GMT", - "duration" : 16, + "duration" : 39, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -110,6 +112,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -138,7 +141,7 @@ "index" : 57, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.257GMT", - "duration" : 16, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -155,6 +158,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -183,7 +187,7 @@ "index" : 58, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.263GMT", - "duration" : 16, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -200,6 +204,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -228,7 +233,7 @@ "index" : 68, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.306GMT", - "duration" : 16, + "duration" : 22, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -245,6 +250,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -273,7 +279,7 @@ "index" : 86, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.374GMT", - "duration" : 16, + "duration" : 28, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -290,6 +296,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -318,7 +325,7 @@ "index" : 32, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.148GMT", - "duration" : 17, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -335,6 +342,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, 
+ "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -363,7 +371,7 @@ "index" : 39, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.180GMT", - "duration" : 17, + "duration" : 32, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -380,6 +388,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -408,7 +417,7 @@ "index" : 42, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.203GMT", - "duration" : 17, + "duration" : 42, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -425,6 +434,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -453,7 +463,7 @@ "index" : 51, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.242GMT", - "duration" : 17, + "duration" : 21, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -470,6 +480,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -498,7 +509,7 @@ "index" : 59, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.265GMT", - "duration" : 17, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -515,6 +526,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -543,7 +555,7 @@ "index" : 63, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.276GMT", - "duration" : 17, + "duration" : 40, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -560,6 +572,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -588,7 +601,7 @@ "index" : 87, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.374GMT", - "duration" : 17, + "duration" : 36, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -605,6 +618,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -633,7 +647,7 @@ "index" : 90, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.385GMT", - "duration" : 17, + "duration" : 23, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -650,6 +664,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -678,7 +693,7 @@ "index" : 99, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.426GMT", - "duration" : 17, + "duration" : 22, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -695,6 +710,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70565, "recordsRead" : 10000 @@ -723,7 +739,7 @@ "index" : 44, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.205GMT", - "duration" : 18, + "duration" : 37, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -740,6 +756,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + 
"peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -768,7 +785,7 @@ "index" : 47, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.212GMT", - "duration" : 18, + "duration" : 33, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -785,6 +802,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -813,7 +831,7 @@ "index" : 50, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.240GMT", - "duration" : 18, + "duration" : 26, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -830,6 +848,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 @@ -858,7 +877,7 @@ "index" : 52, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.243GMT", - "duration" : 18, + "duration" : 28, "executorId" : "driver", "host" : "localhost", "status" : "SUCCESS", @@ -875,6 +894,7 @@ "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 70564, "recordsRead" : 10000 diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 44b5f66efe339..9cdcef0746185 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -2,9 +2,12 @@ "status" : "COMPLETE", "stageId" : 0, "attemptId" : 0, + "numTasks" : 8, "numActiveTasks" : 0, "numCompleteTasks" : 8, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "executorRunTime" : 120, "executorCpuTime" : 0, "submissionTime" : "2015-03-16T19:25:36.103GMT", @@ -23,6 +26,7 @@ "name" : "foreach at :15", "details" : "org.apache.spark.rdd.RDD.foreach(RDD.scala:765)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:483)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", + "rddIds" : [ 0 ], "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", @@ -34,7 +38,7 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.515GMT", - "duration" : 15, + "duration" : 61, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -56,6 +60,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -85,7 +90,7 @@ "index" : 1, "attempt" : 0, 
"launchTime" : "2015-03-16T19:25:36.521GMT", - "duration" : 15, + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -107,6 +112,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -136,7 +142,7 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -158,6 +164,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -187,7 +194,7 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -209,6 +216,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -238,7 +246,7 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -260,6 +268,7 @@ "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -289,7 +298,7 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -311,6 +320,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -340,7 +350,7 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -362,6 +372,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -391,7 +402,7 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.524GMT", - "duration" : 15, + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -413,6 +424,7 @@ "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, "inputMetrics" : { "bytesRead" : 0, "recordsRead" : 0 @@ -443,12 +455,18 @@ "taskTime" : 418, "failedTasks" : 0, "succeededTasks" : 8, + "killedTasks" : 0, "inputBytes" : 0, + "inputRecords" : 0, "outputBytes" : 0, + "outputRecords" : 0, "shuffleRead" : 0, + "shuffleReadRecords" : 0, "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0 } - } + }, + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json index 3d7407004d262..71bf8706307c8 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json @@ -8,10 +8,13 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + 
"numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } }, { "jobId" : 1, "name" : "count at :20", @@ -22,10 +25,13 @@ "numCompletedTasks" : 15, "numSkippedTasks" : 0, "numFailedTasks" : 1, + "numKilledTasks" : 0, + "numCompletedIndices" : 15, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 1 + "numFailedStages" : 1, + "killedTasksSummary" : { } }, { "jobId" : 0, "name" : "count at :15", @@ -36,8 +42,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json index 6a9bafd6b2191..b1ddd760c9714 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json @@ -8,10 +8,13 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } }, { "jobId" : 0, "name" : "count at :15", @@ -22,8 +25,11 @@ "numCompletedTasks" : 8, "numSkippedTasks" : 0, "numFailedTasks" : 0, + "numKilledTasks" : 0, + "numCompletedIndices" : 8, "numActiveStages" : 0, "numCompletedStages" : 1, "numSkippedStages" : 0, - "numFailedStages" : 0 + "numFailedStages" : 0, + "killedTasksSummary" : { } } ] diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 6a1abceaeb63c..d22a19e8af74a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -23,6 +23,7 @@ import java.util.zip.ZipInputStream import javax.servlet._ import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -44,8 +45,8 @@ import org.scalatest.selenium.WebBrowser import org.apache.spark._ import org.apache.spark.deploy.history.config._ +import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.JobUIData import org.apache.spark.util.{ResetSystemProperties, Utils} /** @@ -262,7 +263,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val badStageAttemptId = getContentAndCode("applications/local-1422981780767/stages/1/1") badStageAttemptId._1 should be (HttpServletResponse.SC_NOT_FOUND) - badStageAttemptId._3 should be (Some("unknown attempt for stage 1. 
Found attempts: [0]")) + badStageAttemptId._3 should be (Some("unknown attempt 1 for stage 1.")) val badStageId2 = getContentAndCode("applications/local-1422981780767/stages/flimflam") badStageId2._1 should be (HttpServletResponse.SC_NOT_FOUND) @@ -496,12 +497,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } - def completedJobs(): Seq[JobUIData] = { - getAppUI.jobProgressListener.completedJobs + def completedJobs(): Seq[JobData] = { + getAppUI.store.jobsList(List(JobExecutionStatus.SUCCEEDED).asJava) } - def activeJobs(): Seq[JobUIData] = { - getAppUI.jobProgressListener.activeJobs.values.toSeq + def activeJobs(): Seq[JobData] = { + getAppUI.store.jobsList(List(JobExecutionStatus.RUNNING).asJava) } activeJobs() should have size 0 diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index ba082bc93dd42..88fe6bd70a14e 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -180,7 +180,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[StageDataWrapper](key(stages.head)) { stage => assert(stage.info.status === v1.StageStatus.ACTIVE) assert(stage.info.submissionTime === Some(new Date(stages.head.submissionTime.get))) - assert(stage.info.schedulingPool === "schedPool") + assert(stage.info.numTasks === stages.head.numTasks) } // Start tasks from stage 1 @@ -265,12 +265,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { "taskType", TaskResultLost, s1Tasks.head, null)) time += 1 - val reattempt = { - val orig = s1Tasks.head - // Task reattempts have a different ID, but the same index as the original. - new TaskInfo(nextTaskId(), orig.index, orig.attemptNumber + 1, time, orig.executorId, - s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) - } + val reattempt = newAttempt(s1Tasks.head, nextTaskId()) listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, reattempt)) @@ -288,7 +283,6 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](s1Tasks.head.taskId) { task => assert(task.info.status === s1Tasks.head.status) - assert(task.info.duration === Some(s1Tasks.head.duration)) assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) } @@ -297,8 +291,64 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(task.info.attempt === reattempt.attemptNumber) } + // Kill one task, restart it. + time += 1 + val killed = s1Tasks.drop(1).head + killed.finishTime = time + killed.failed = true + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", TaskKilled("killed"), killed, null)) + + check[JobDataWrapper](1) { job => + assert(job.info.numKilledTasks === 1) + assert(job.info.killedTasksSummary === Map("killed" -> 1)) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numKilledTasks === 1) + assert(stage.info.killedTasksSummary === Map("killed" -> 1)) + } + + check[TaskDataWrapper](killed.taskId) { task => + assert(task.info.index === killed.index) + assert(task.info.errorMessage === Some("killed")) + } + + // Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill. 
+ time += 1 + val denied = newAttempt(killed, nextTaskId()) + val denyReason = TaskCommitDenied(1, 1, 1) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + denied)) + + time += 1 + denied.finishTime = time + denied.failed = true + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", denyReason, denied, null)) + + check[JobDataWrapper](1) { job => + assert(job.info.numKilledTasks === 2) + assert(job.info.killedTasksSummary === Map("killed" -> 1, denyReason.toErrorString -> 1)) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numKilledTasks === 2) + assert(stage.info.killedTasksSummary === Map("killed" -> 1, denyReason.toErrorString -> 1)) + } + + check[TaskDataWrapper](denied.taskId) { task => + assert(task.info.index === killed.index) + assert(task.info.errorMessage === Some(denyReason.toErrorString)) + } + + // Start a new attempt. + val reattempt2 = newAttempt(denied, nextTaskId()) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + reattempt2)) + // Succeed all tasks in stage 1. - val pending = s1Tasks.drop(1) ++ Seq(reattempt) + val pending = s1Tasks.drop(2) ++ Seq(reattempt, reattempt2) val s1Metrics = TaskMetrics.empty s1Metrics.setExecutorCpuTime(2L) @@ -313,12 +363,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[JobDataWrapper](1) { job => assert(job.info.numFailedTasks === 1) + assert(job.info.numKilledTasks === 2) assert(job.info.numActiveTasks === 0) assert(job.info.numCompletedTasks === pending.size) } check[StageDataWrapper](key(stages.head)) { stage => assert(stage.info.numFailedTasks === 1) + assert(stage.info.numKilledTasks === 2) assert(stage.info.numActiveTasks === 0) assert(stage.info.numCompleteTasks === pending.size) } @@ -328,10 +380,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(wrapper.info.errorMessage === None) assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) + assert(wrapper.info.duration === Some(task.duration)) } } - assert(store.count(classOf[TaskDataWrapper]) === pending.size + 1) + assert(store.count(classOf[TaskDataWrapper]) === pending.size + 3) // End stage 1. time += 1 @@ -404,6 +457,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.numFailedTasks === s2Tasks.size) assert(stage.info.numActiveTasks === 0) assert(stage.info.numCompleteTasks === 0) + assert(stage.info.failureReason === stages.last.failureReason) } // - Re-submit stage 2, all tasks, and succeed them and the stage. @@ -804,6 +858,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { fn(value) } + private def newAttempt(orig: TaskInfo, nextId: Long): TaskInfo = { + // Task reattempts have a different ID, but the same index as the original. 
+ new TaskInfo(nextId, orig.index, orig.attemptNumber + 1, time, orig.executorId, + s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) + } + private case class RddBlock( rddId: Int, partId: Int, diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala deleted file mode 100644 index 82bd7c4ff6604..0000000000000 --- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1 - -import java.util.Date - -import scala.collection.mutable.LinkedHashMap - -import org.apache.spark.SparkFunSuite -import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} -import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} - -class AllStagesResourceSuite extends SparkFunSuite { - - def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { - val tasks = new LinkedHashMap[Long, TaskUIData] - taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => - tasks(idx.toLong) = TaskUIData( - new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false)) - } - - val stageUiData = new StageUIData() - stageUiData.taskData = tasks - val status = StageStatus.ACTIVE - val stageInfo = new StageInfo( - 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc") - val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) - - stageData.firstTaskLaunchedTime - } - - test("firstTaskLaunchedTime when there are no tasks") { - val result = getFirstTaskLaunchTime(Seq()) - assert(result == None) - } - - test("firstTaskLaunchedTime when there are tasks but none launched") { - val result = getFirstTaskLaunchTime(Seq(-100L, -200L, -300L)) - assert(result == None) - } - - test("firstTaskLaunchedTime when there are tasks and some launched") { - val result = getFirstTaskLaunchTime(Seq(-100L, 1449255596000L, 1449255597000L)) - assert(result == Some(new Date(1449255596000L))) - } - -} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 1c51c148ae61b..46932a02f1a1b 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -29,8 +29,8 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore -import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} -import org.apache.spark.ui.scope.RDDOperationGraphListener +import org.apache.spark.ui.jobs.{StagePage, 
StagesTab} +import org.apache.spark.util.Utils class StagePageSuite extends SparkFunSuite with LocalSparkContext { @@ -55,38 +55,40 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * This also runs a dummy stage to populate the page with useful content. */ private def renderStagePage(conf: SparkConf): Seq[Node] = { - val store = mock(classOf[AppStatusStore]) - when(store.executorSummary(anyString())).thenReturn(None) + val bus = new ReplayListenerBus() + val store = AppStatusStore.createLiveStore(conf, l => bus.addListener(l)) - val jobListener = new JobProgressListener(conf) - val graphListener = new RDDOperationGraphListener(conf) - val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) - val request = mock(classOf[HttpServletRequest]) - when(tab.conf).thenReturn(conf) - when(tab.progressListener).thenReturn(jobListener) - when(tab.operationGraphListener).thenReturn(graphListener) - when(tab.appName).thenReturn("testing") - when(tab.headerTabs).thenReturn(Seq.empty) - when(request.getParameter("id")).thenReturn("0") - when(request.getParameter("attempt")).thenReturn("0") - val page = new StagePage(tab, store) + try { + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + when(tab.store).thenReturn(store) - // Simulate a stage in job progress listener - val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") - // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness - (1 to 2).foreach { - taskId => - val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) - jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) - jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis()) - val taskMetrics = TaskMetrics.empty - taskMetrics.incPeakExecutionMemory(peakExecutionMemory) - jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab, store) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness + (1 to 2).foreach { + taskId => + val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, + false) + bus.postToAll(SparkListenerStageSubmitted(stageInfo)) + bus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis()) + val taskMetrics = TaskMetrics.empty + taskMetrics.incPeakExecutionMemory(peakExecutionMemory) + bus.postToAll(SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) + } + bus.postToAll(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } finally { + store.close() } - jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) - page.render(request) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 267c8dc1bd750..6a6c37873e1c2 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -524,7 +524,7 @@ class 
UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } - test("stage & job retention") { + ignore("stage & job retention") { val conf = new SparkConf() .setMaster("local") .setAppName("test") @@ -670,34 +670,36 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) rdd.count() - val stage0 = Source.fromURL(sc.ui.get.webUrl + - "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString - assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + - "label="Stage 0";\n subgraph ")) - assert(stage0.contains("{\n label="parallelize";\n " + - "0 [label="ParallelCollectionRDD [0]")) - assert(stage0.contains("{\n label="map";\n " + - "1 [label="MapPartitionsRDD [1]")) - assert(stage0.contains("{\n label="groupBy";\n " + - "2 [label="MapPartitionsRDD [2]")) - - val stage1 = Source.fromURL(sc.ui.get.webUrl + - "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString - assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + - "label="Stage 1";\n subgraph ")) - assert(stage1.contains("{\n label="groupBy";\n " + - "3 [label="ShuffledRDD [3]")) - assert(stage1.contains("{\n label="map";\n " + - "4 [label="MapPartitionsRDD [4]")) - assert(stage1.contains("{\n label="groupBy";\n " + - "5 [label="MapPartitionsRDD [5]")) - - val stage2 = Source.fromURL(sc.ui.get.webUrl + - "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString - assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + - "label="Stage 2";\n subgraph ")) - assert(stage2.contains("{\n label="groupBy";\n " + - "6 [label="ShuffledRDD [6]")) + eventually(timeout(5 seconds), interval(100 milliseconds)) { + val stage0 = Source.fromURL(sc.ui.get.webUrl + + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString + assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + + "label="Stage 0";\n subgraph ")) + assert(stage0.contains("{\n label="parallelize";\n " + + "0 [label="ParallelCollectionRDD [0]")) + assert(stage0.contains("{\n label="map";\n " + + "1 [label="MapPartitionsRDD [1]")) + assert(stage0.contains("{\n label="groupBy";\n " + + "2 [label="MapPartitionsRDD [2]")) + + val stage1 = Source.fromURL(sc.ui.get.webUrl + + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString + assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + + "label="Stage 1";\n subgraph ")) + assert(stage1.contains("{\n label="groupBy";\n " + + "3 [label="ShuffledRDD [3]")) + assert(stage1.contains("{\n label="map";\n " + + "4 [label="MapPartitionsRDD [4]")) + assert(stage1.contains("{\n label="groupBy";\n " + + "5 [label="MapPartitionsRDD [5]")) + + val stage2 = Source.fromURL(sc.ui.get.webUrl + + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString + assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + + "label="Stage 2";\n subgraph ")) + assert(stage2.contains("{\n label="groupBy";\n " + + "6 [label="ShuffledRDD [6]")) + } } } diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala deleted file mode 100644 index 3fb78da0c7476..0000000000000 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.scope - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.scheduler._ - -/** - * Tests that this listener populates and cleans up its data structures properly. - */ -class RDDOperationGraphListenerSuite extends SparkFunSuite { - private var jobIdCounter = 0 - private var stageIdCounter = 0 - private val maxRetainedJobs = 10 - private val maxRetainedStages = 10 - private val conf = new SparkConf() - .set("spark.ui.retainedJobs", maxRetainedJobs.toString) - .set("spark.ui.retainedStages", maxRetainedStages.toString) - - test("run normal jobs") { - val startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - assert(listener.jobIdToStageIds.isEmpty) - assert(listener.jobIdToSkippedStageIds.isEmpty) - assert(listener.stageIdToJobId.isEmpty) - assert(listener.stageIdToGraph.isEmpty) - assert(listener.completedStageIds.isEmpty) - assert(listener.jobIds.isEmpty) - assert(listener.stageIds.isEmpty) - - // Run a few jobs, but not enough for clean up yet - (1 to 3).foreach { numStages => startJob(numStages, listener) } // start 3 jobs and 6 stages - (0 to 5).foreach { i => endStage(startingStageId + i, listener) } // finish all 6 stages - (0 to 2).foreach { i => endJob(startingJobId + i, listener) } // finish all 3 jobs - - assert(listener.jobIdToStageIds.size === 3) - assert(listener.jobIdToStageIds(startingJobId).size === 1) - assert(listener.jobIdToStageIds(startingJobId + 1).size === 2) - assert(listener.jobIdToStageIds(startingJobId + 2).size === 3) - assert(listener.jobIdToSkippedStageIds.size === 3) - assert(listener.jobIdToSkippedStageIds.values.forall(_.isEmpty)) // no skipped stages - assert(listener.stageIdToJobId.size === 6) - assert(listener.stageIdToJobId(startingStageId) === startingJobId) - assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2) - assert(listener.stageIdToGraph.size === 6) - assert(listener.completedStageIds.size === 6) - assert(listener.jobIds.size === 3) - assert(listener.stageIds.size === 6) - } - - test("run jobs with skipped stages") { - val startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - - // Run a few jobs, but not enough for clean up yet - // Leave some stages unfinished so that they are marked as skipped - (1 to 3).foreach { numStages => startJob(numStages, listener) } // start 3 jobs and 6 stages - (4 to 5).foreach { i => endStage(startingStageId + i, 
listener) } // finish only last 2 stages - (0 to 2).foreach { i => endJob(startingJobId + i, listener) } // finish all 3 jobs - - assert(listener.jobIdToSkippedStageIds.size === 3) - assert(listener.jobIdToSkippedStageIds(startingJobId).size === 1) - assert(listener.jobIdToSkippedStageIds(startingJobId + 1).size === 2) - assert(listener.jobIdToSkippedStageIds(startingJobId + 2).size === 1) // 2 stages not skipped - assert(listener.completedStageIds.size === 2) - - // The rest should be the same as before - assert(listener.jobIdToStageIds.size === 3) - assert(listener.jobIdToStageIds(startingJobId).size === 1) - assert(listener.jobIdToStageIds(startingJobId + 1).size === 2) - assert(listener.jobIdToStageIds(startingJobId + 2).size === 3) - assert(listener.stageIdToJobId.size === 6) - assert(listener.stageIdToJobId(startingStageId) === startingJobId) - assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1) - assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2) - assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2) - assert(listener.stageIdToGraph.size === 6) - assert(listener.jobIds.size === 3) - assert(listener.stageIds.size === 6) - } - - test("clean up metadata") { - val startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - - // Run many jobs and stages to trigger clean up - (1 to 10000).foreach { i => - // Note: this must be less than `maxRetainedStages` - val numStages = i % (maxRetainedStages - 2) + 1 - val startingStageIdForJob = stageIdCounter - val jobId = startJob(numStages, listener) - // End some, but not all, stages that belong to this job - // This is to ensure that we have both completed and skipped stages - (startingStageIdForJob until stageIdCounter) - .filter { i => i % 2 == 0 } - .foreach { i => endStage(i, listener) } - // End all jobs - endJob(jobId, listener) - } - - // Ensure we never exceed the max retained thresholds - assert(listener.jobIdToStageIds.size <= maxRetainedJobs) - assert(listener.jobIdToSkippedStageIds.size <= maxRetainedJobs) - assert(listener.stageIdToJobId.size <= maxRetainedStages) - assert(listener.stageIdToGraph.size <= maxRetainedStages) - assert(listener.completedStageIds.size <= maxRetainedStages) - assert(listener.jobIds.size <= maxRetainedJobs) - assert(listener.stageIds.size <= maxRetainedStages) - - // Also ensure we're actually populating these data structures - // Otherwise the previous group of asserts will be meaningless - assert(listener.jobIdToStageIds.nonEmpty) - assert(listener.jobIdToSkippedStageIds.nonEmpty) - assert(listener.stageIdToJobId.nonEmpty) - assert(listener.stageIdToGraph.nonEmpty) - assert(listener.completedStageIds.nonEmpty) - assert(listener.jobIds.nonEmpty) - assert(listener.stageIds.nonEmpty) - - // Ensure we clean up old jobs and stages, not arbitrary ones - assert(!listener.jobIdToStageIds.contains(startingJobId)) - assert(!listener.jobIdToSkippedStageIds.contains(startingJobId)) - assert(!listener.stageIdToJobId.contains(startingStageId)) - assert(!listener.stageIdToGraph.contains(startingStageId)) - assert(!listener.completedStageIds.contains(startingStageId)) - assert(!listener.stageIds.contains(startingStageId)) - assert(!listener.jobIds.contains(startingJobId)) - } - - test("fate sharing between jobs and stages") { - val 
startingJobId = jobIdCounter - val startingStageId = stageIdCounter - val listener = new RDDOperationGraphListener(conf) - - // Run 3 jobs and 8 stages, finishing all 3 jobs but only 2 stages - startJob(5, listener) - startJob(1, listener) - startJob(2, listener) - (0 until 8).foreach { i => startStage(i + startingStageId, listener) } - endStage(startingStageId + 3, listener) - endStage(startingStageId + 4, listener) - (0 until 3).foreach { i => endJob(i + startingJobId, listener) } - - // First, assert the old stuff - assert(listener.jobIdToStageIds.size === 3) - assert(listener.jobIdToSkippedStageIds.size === 3) - assert(listener.stageIdToJobId.size === 8) - assert(listener.stageIdToGraph.size === 8) - assert(listener.completedStageIds.size === 2) - - // Cleaning the third job should clean all of its stages - listener.cleanJob(startingJobId + 2) - assert(listener.jobIdToStageIds.size === 2) - assert(listener.jobIdToSkippedStageIds.size === 2) - assert(listener.stageIdToJobId.size === 6) - assert(listener.stageIdToGraph.size === 6) - assert(listener.completedStageIds.size === 2) - - // Cleaning one of the stages in the first job should clean that job and all of its stages - // Note that we still keep around the last stage because it belongs to a different job - listener.cleanStage(startingStageId) - assert(listener.jobIdToStageIds.size === 1) - assert(listener.jobIdToSkippedStageIds.size === 1) - assert(listener.stageIdToJobId.size === 1) - assert(listener.stageIdToGraph.size === 1) - assert(listener.completedStageIds.size === 0) - } - - /** Start a job with the specified number of stages. */ - private def startJob(numStages: Int, listener: RDDOperationGraphListener): Int = { - assert(numStages > 0, "I will not run a job with 0 stages for you.") - val stageInfos = (0 until numStages).map { _ => - val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d") - stageIdCounter += 1 - stageInfo - } - val jobId = jobIdCounter - listener.onJobStart(new SparkListenerJobStart(jobId, 0, stageInfos)) - // Also start all stages that belong to this job - stageInfos.map(_.stageId).foreach { sid => startStage(sid, listener) } - jobIdCounter += 1 - jobId - } - - /** Start the stage specified by the given ID. */ - private def startStage(stageId: Int, listener: RDDOperationGraphListener): Unit = { - val stageInfo = new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d") - listener.onStageSubmitted(new SparkListenerStageSubmitted(stageInfo)) - } - - /** Finish the stage specified by the given ID. */ - private def endStage(stageId: Int, listener: RDDOperationGraphListener): Unit = { - val stageInfo = new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d") - listener.onStageCompleted(new SparkListenerStageCompleted(stageInfo)) - } - - /** Finish the job specified by the given ID. 
*/ - private def endJob(jobId: Int, listener: RDDOperationGraphListener): Unit = { - listener.onJobEnd(new SparkListenerJobEnd(jobId, 0, JobSucceeded)) - } - -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e6f136c7c8b0a..7f18b40f9d960 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -43,6 +43,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.storage.StorageListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorStageSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), From 0ffa7c488fa8156e2a1aa282e60b7c36b86d8af8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 14 Nov 2017 15:28:22 -0600 Subject: [PATCH 243/712] [SPARK-20652][SQL] Store SQL UI data in the new app status store. This change replaces the SQLListener with a new implementation that saves the data to the same store used by the SparkContext's status store. For that, the types used by the old SQLListener had to be updated a bit so that they're more serialization-friendly. The interface for getting data from the store was abstracted into a new class, SQLAppStatusStore (following the convention used in core). Another change is the way that the SQL UI hooks into the core UI or the SHS. The old "SparkHistoryListenerFactory" was replaced with a new "AppStatusPlugin" that more explicitly differentiates between the two use cases: processing events, and showing the UI. Both live apps and the SHS use this new API (previously, it was restricted to the SHS). Note on the above: this causes a slight change of behavior for live apps; the SQL tab will only show up after the first execution is started. The metrics-gathering code was reworked a bit so that the types used are less memory-hungry and more serialization-friendly. This reduces memory usage when using in-memory stores, and reduces load times when using disk stores. Tested with existing and added unit tests. Note one unit test was disabled because it depends on SPARK-20653, which isn't in yet. Author: Marcelo Vanzin Closes #19681 from vanzin/SPARK-20652.
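For readers skimming the diff below: the new extension point is the AppStatusPlugin trait added in core, discovered through META-INF/services. A minimal sketch of a plugin under that interface might look like the following (illustrative only; ExampleStatusPlugin and its no-op listener are not part of this patch):

package org.apache.spark.status

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.kvstore.KVStore

// Illustrative sketch: AppStatusPlugin is private[spark], so implementations live under
// org.apache.spark.* and are registered via a
// META-INF/services/org.apache.spark.status.AppStatusPlugin entry, as the SQL module does below.
class ExampleStatusPlugin extends AppStatusPlugin {

  override def setupListeners(
      conf: SparkConf,
      store: KVStore,
      addListenerFn: SparkListener => Unit,
      live: Boolean): Unit = {
    // Register a listener that writes whatever it collects into `store`; `live` tells the
    // plugin whether this is a running application or a History Server replay.
    addListenerFn(new SparkListener {})
  }

  override def setupUI(ui: SparkUI): Unit = {
    // A real plugin would attach its tab here, ideally only once there is data to show.
  }
}

The hunks below show where Spark invokes these hooks: SparkContext and FsHistoryProvider call setupUI, while AppStatusStore.createLiveStore and FsHistoryProvider call setupListeners.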
--- .../scala/org/apache/spark/SparkContext.scala | 15 +- .../deploy/history/FsHistoryProvider.scala | 12 +- .../spark/scheduler/SparkListener.scala | 12 - .../apache/spark/status/AppStatusPlugin.scala | 71 +++ .../apache/spark/status/AppStatusStore.scala | 8 +- ...park.scheduler.SparkHistoryListenerFactory | 1 - .../org.apache.spark.status.AppStatusPlugin | 1 + .../org/apache/spark/sql/SparkSession.scala | 5 - .../sql/execution/ui/AllExecutionsPage.scala | 86 ++-- .../sql/execution/ui/ExecutionPage.scala | 60 ++- .../execution/ui/SQLAppStatusListener.scala | 366 ++++++++++++++++ .../sql/execution/ui/SQLAppStatusStore.scala | 179 ++++++++ .../spark/sql/execution/ui/SQLListener.scala | 403 +----------------- .../spark/sql/execution/ui/SQLTab.scala | 2 +- .../spark/sql/internal/SharedState.scala | 19 - .../execution/metric/SQLMetricsSuite.scala | 18 +- .../metric/SQLMetricsTestUtils.scala | 30 +- .../sql/execution/ui/SQLListenerSuite.scala | 340 ++++++++------- .../spark/sql/test/SharedSparkSession.scala | 1 - 19 files changed, 920 insertions(+), 709 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala delete mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1d325e651b1d9..23fd54f59268a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,7 +54,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend -import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.{AppStatusPlugin, AppStatusStore} import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -246,6 +246,8 @@ class SparkContext(config: SparkConf) extends Logging { */ def isStopped: Boolean = stopped.get() + private[spark] def statusStore: AppStatusStore = _statusStore + // An asynchronous listener bus for Spark events private[spark] def listenerBus: LiveListenerBus = _listenerBus @@ -455,9 +457,14 @@ class SparkContext(config: SparkConf) extends Logging { // For tests, do not enable the UI None } - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - _ui.foreach(_.bind()) + _ui.foreach { ui => + // Load any plugins that might want to modify the UI. 
+ AppStatusPlugin.loadPlugins().foreach(_.setupUI(ui)) + + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + ui.bind() + } _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a6dc53321d650..25f82b55f2003 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -41,7 +41,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ -import org.apache.spark.status.{AppStatusListener, AppStatusStore, AppStatusStoreMetadata, KVUtils} +import org.apache.spark.status._ import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI @@ -319,6 +319,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val _listener = new AppStatusListener(kvstore, conf, false, lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) replayBus.addListener(_listener) + AppStatusPlugin.loadPlugins().foreach { plugin => + plugin.setupListeners(conf, kvstore, l => replayBus.addListener(l), false) + } Some(_listener) } else { None @@ -333,11 +336,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } try { - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, loadedUI.ui) - listeners.foreach(replayBus.addListener) + AppStatusPlugin.loadPlugins().foreach { plugin => + plugin.setupUI(loadedUI.ui) } val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index b76e560669d59..3b677ca9657db 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -167,18 +167,6 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent @DeveloperApi case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent -/** - * Interface for creating history listeners defined in other modules like SQL, which are used to - * rebuild the history UI. - */ -private[spark] trait SparkHistoryListenerFactory { - /** - * Create listeners used to rebuild the history UI. - */ - def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] -} - - /** * Interface for listening to events from the Spark scheduler. Most applications should probably * extend SparkListener or SparkFirehoseListener directly, rather than implementing this class. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala b/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala new file mode 100644 index 0000000000000..69ca02ec76293 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.ServiceLoader + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler.SparkListener +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore.KVStore + +/** + * An interface that defines plugins for collecting and storing application state. + * + * The plugin implementations are invoked for both live and replayed applications. For live + * applications, it's recommended that plugins defer creation of UI tabs until there's actual + * data to be shown. + */ +private[spark] trait AppStatusPlugin { + + /** + * Install listeners to collect data about the running application and populate the given + * store. + * + * @param conf The Spark configuration. + * @param store The KVStore where to keep application data. + * @param addListenerFn Function to register listeners with a bus. + * @param live Whether this is a live application (or an application being replayed by the + * HistoryServer). + */ + def setupListeners( + conf: SparkConf, + store: KVStore, + addListenerFn: SparkListener => Unit, + live: Boolean): Unit + + /** + * Install any needed extensions (tabs, pages, etc) to a Spark UI. The plugin can detect whether + * the app is live or replayed by looking at the UI's SparkContext field `sc`. + * + * @param ui The Spark UI instance for the application. + */ + def setupUI(ui: SparkUI): Unit + +} + +private[spark] object AppStatusPlugin { + + def loadPlugins(): Iterable[AppStatusPlugin] = { + ServiceLoader.load(classOf[AppStatusPlugin], Utils.getContextOrSparkClassLoader).asScala + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 9b42f55605755..d0615e5dd0223 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** * A wrapper around a KVStore that provides methods for accessing the API data stored within. 
*/ -private[spark] class AppStatusStore(store: KVStore) { +private[spark] class AppStatusStore(val store: KVStore) { def applicationInfo(): v1.ApplicationInfo = { store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info @@ -338,9 +338,11 @@ private[spark] object AppStatusStore { */ def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { val store = new InMemoryStore() - val stateStore = new AppStatusStore(store) addListenerFn(new AppStatusListener(store, conf, true)) - stateStore + AppStatusPlugin.loadPlugins().foreach { p => + p.setupListeners(conf, store, addListenerFn, true) + } + new AppStatusStore(store) } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory deleted file mode 100644 index 507100be90967..0000000000000 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin new file mode 100644 index 0000000000000..ac6d7f6962f85 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLAppStatusPlugin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2821f5ee7feee..272eb844226d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation @@ -957,7 +956,6 @@ object SparkSession { sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { defaultSession.set(null) - sqlListener.set(null) } }) } @@ -1026,9 +1024,6 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) - /** A global SQL listener used for the SQL UI. 
*/ - private[sql] val sqlListener = new AtomicReference[SQLListener]() - //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index f9c69864a3361..7019d98e1619f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -24,34 +24,54 @@ import scala.xml.{Node, NodeSeq} import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.JobExecutionStatus import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with Logging { - private val listener = parent.listener + private val sqlStore = parent.sqlStore override def render(request: HttpServletRequest): Seq[Node] = { val currentTime = System.currentTimeMillis() - val content = listener.synchronized { + val running = new mutable.ArrayBuffer[SQLExecutionUIData]() + val completed = new mutable.ArrayBuffer[SQLExecutionUIData]() + val failed = new mutable.ArrayBuffer[SQLExecutionUIData]() + + sqlStore.executionsList().foreach { e => + val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } + val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED } + if (isRunning) { + running += e + } else if (isFailed) { + failed += e + } else { + completed += e + } + } + + val content = { val _content = mutable.ListBuffer[Node]() - if (listener.getRunningExecutions.nonEmpty) { + + if (running.nonEmpty) { _content ++= new RunningExecutionTable( - parent, s"Running Queries (${listener.getRunningExecutions.size})", currentTime, - listener.getRunningExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + parent, s"Running Queries (${running.size})", currentTime, + running.sortBy(_.submissionTime).reverse).toNodeSeq } - if (listener.getCompletedExecutions.nonEmpty) { + + if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( - parent, s"Completed Queries (${listener.getCompletedExecutions.size})", currentTime, - listener.getCompletedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + parent, s"Completed Queries (${completed.size})", currentTime, + completed.sortBy(_.submissionTime).reverse).toNodeSeq } - if (listener.getFailedExecutions.nonEmpty) { + + if (failed.nonEmpty) { _content ++= new FailedExecutionTable( - parent, s"Failed Queries (${listener.getFailedExecutions.size})", currentTime, - listener.getFailedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + parent, s"Failed Queries (${failed.size})", currentTime, + failed.sortBy(_.submissionTime).reverse).toNodeSeq } _content } @@ -65,26 +85,26 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
      { - if (listener.getRunningExecutions.nonEmpty) { + if (running.nonEmpty) {
    • Running Queries: - {listener.getRunningExecutions.size} + {running.size}
    • } } { - if (listener.getCompletedExecutions.nonEmpty) { + if (completed.nonEmpty) {
    • Completed Queries: - {listener.getCompletedExecutions.size} + {completed.size}
    • } } { - if (listener.getFailedExecutions.nonEmpty) { + if (failed.nonEmpty) {
    • Failed Queries: - {listener.getFailedExecutions.size} + {failed.size}
    • } } @@ -114,23 +134,19 @@ private[ui] abstract class ExecutionTable( protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime - val duration = executionUIData.completionTime.getOrElse(currentTime) - submissionTime + val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - + submissionTime - val runningJobs = executionUIData.runningJobs.map { jobId => - - [{jobId.toString}] - - } - val succeededJobs = executionUIData.succeededJobs.sorted.map { jobId => - - [{jobId.toString}] - - } - val failedJobs = executionUIData.failedJobs.sorted.map { jobId => - - [{jobId.toString}] - + def jobLinks(status: JobExecutionStatus): Seq[Node] = { + executionUIData.jobs.flatMap { case (jobId, jobStatus) => + if (jobStatus == status) { + [{jobId.toString}] + } else { + None + } + }.toSeq } + {executionUIData.executionId.toString} @@ -146,17 +162,17 @@ private[ui] abstract class ExecutionTable( {if (showRunningJobs) { - {runningJobs} + {jobLinks(JobExecutionStatus.RUNNING)} }} {if (showSucceededJobs) { - {succeededJobs} + {jobLinks(JobExecutionStatus.SUCCEEDED)} }} {if (showFailedJobs) { - {failedJobs} + {jobLinks(JobExecutionStatus.FAILED)} }} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 460fc946c3e6f..f29e135ac357f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -21,24 +21,42 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node +import org.apache.spark.JobExecutionStatus import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { - private val listener = parent.listener + private val sqlStore = parent.sqlStore - override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized { + override def render(request: HttpServletRequest): Seq[Node] = { // stripXSS is called first to remove suspicious characters used in XSS attacks val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id")) require(parameterExecutionId != null && parameterExecutionId.nonEmpty, "Missing execution id parameter") val executionId = parameterExecutionId.toLong - val content = listener.getExecution(executionId).map { executionUIData => + val content = sqlStore.execution(executionId).map { executionUIData => val currentTime = System.currentTimeMillis() - val duration = - executionUIData.completionTime.getOrElse(currentTime) - executionUIData.submissionTime + val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - + executionUIData.submissionTime + + def jobLinks(status: JobExecutionStatus, label: String): Seq[Node] = { + val jobs = executionUIData.jobs.flatMap { case (jobId, jobStatus) => + if (jobStatus == status) Some(jobId) else None + } + if (jobs.nonEmpty) { +
    • + {label} + {jobs.toSeq.sorted.map { jobId => + {jobId.toString}  + }} +
    • + } else { + Nil + } + } + val summary =
      @@ -49,37 +67,17 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    • Duration: {UIUtils.formatDuration(duration)}
    • - {if (executionUIData.runningJobs.nonEmpty) { -
    • - Running Jobs: - {executionUIData.runningJobs.sorted.map { jobId => - {jobId.toString}  - }} -
    • - }} - {if (executionUIData.succeededJobs.nonEmpty) { -
    • - Succeeded Jobs: - {executionUIData.succeededJobs.sorted.map { jobId => - {jobId.toString}  - }} -
    • - }} - {if (executionUIData.failedJobs.nonEmpty) { -
    • - Failed Jobs: - {executionUIData.failedJobs.sorted.map { jobId => - {jobId.toString}  - }} -
    • - }} + {jobLinks(JobExecutionStatus.RUNNING, "Running Jobs:")} + {jobLinks(JobExecutionStatus.SUCCEEDED, "Succeeded Jobs:")} + {jobLinks(JobExecutionStatus.FAILED, "Failed Jobs:")}
    - val metrics = listener.getExecutionMetrics(executionId) + val metrics = sqlStore.executionMetrics(executionId) + val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, executionUIData.physicalPlanGraph) ++ + planVisualization(metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
    No information to display for Plan {executionId}
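The two pages above no longer touch a listener directly; everything is read from the KVStore-backed SQLAppStatusStore added further down in this patch. A rough sketch of that read path, mirroring what the updated SQLMetricsSuite does (note that SQLAppStatusStore and SparkContext.statusStore are package-private, so this only compiles from Spark's own packages or test code):

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.ui.SQLAppStatusStore

// Sketch: build a read-only view over the application's KVStore and walk the stored
// SQL executions, the same data AllExecutionsPage and ExecutionPage render.
def dumpSqlExecutions(sc: SparkContext): Unit = {
  val sqlStore = new SQLAppStatusStore(sc.statusStore.store) // no listener: store-only view
  sqlStore.executionsList().foreach { exec =>
    val metrics = sqlStore.executionMetrics(exec.executionId)
    println(s"execution ${exec.executionId}: ${exec.jobs.size} job(s), " +
      s"${metrics.size} metric value(s)")
  }
}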
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala new file mode 100644 index 0000000000000..43cec4807ae4d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.ui + +import java.util.Date +import java.util.concurrent.ConcurrentHashMap +import java.util.function.Function + +import scala.collection.JavaConverters._ + +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler._ +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.status.LiveEntity +import org.apache.spark.status.config._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.kvstore.KVStore + +private[sql] class SQLAppStatusListener( + conf: SparkConf, + kvstore: KVStore, + live: Boolean, + ui: Option[SparkUI] = None) + extends SparkListener with Logging { + + // How often to flush intermediate state of a live execution to the store. When replaying logs, + // never flush (only do the very last write). + private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + + // Live tracked data is needed by the SQL status store to calculate metrics for in-flight + // executions; that means arbitrary threads may be querying these maps, so they need to be + // thread-safe. + private val liveExecutions = new ConcurrentHashMap[Long, LiveExecutionData]() + private val stageMetrics = new ConcurrentHashMap[Int, LiveStageMetrics]() + + private var uiInitialized = false + + override def onJobStart(event: SparkListenerJobStart): Unit = { + val executionIdString = event.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionIdString == null) { + // This is not a job created by SQL + return + } + + val executionId = executionIdString.toLong + val jobId = event.jobId + val exec = getOrCreateExecution(executionId) + + // Record the accumulator IDs for the stages of this job, so that the code that keeps + // track of the metrics knows which accumulators to look at. 
+ val accumIds = exec.metrics.map(_.accumulatorId).sorted.toList + event.stageIds.foreach { id => + stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds.toArray, new ConcurrentHashMap())) + } + + exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) + exec.stages = event.stageIds.toSet + update(exec) + } + + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + if (!isSQLStage(event.stageInfo.stageId)) { + return + } + + // Reset the metrics tracking object for the new attempt. + Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => + metrics.taskMetrics.clear() + metrics.attemptId = event.stageInfo.attemptId + } + } + + override def onJobEnd(event: SparkListenerJobEnd): Unit = { + liveExecutions.values().asScala.foreach { exec => + if (exec.jobs.contains(event.jobId)) { + val result = event.jobResult match { + case JobSucceeded => JobExecutionStatus.SUCCEEDED + case _ => JobExecutionStatus.FAILED + } + exec.jobs = exec.jobs + (event.jobId -> result) + exec.endEvents += 1 + update(exec) + } + } + } + + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) => + updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false) + } + } + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + if (!isSQLStage(event.stageId)) { + return + } + + val info = event.taskInfo + // SPARK-20342. If processing events from a live application, use the task metrics info to + // work around a race in the DAGScheduler. The metrics info does not contain accumulator info + // when reading event logs in the SHS, so we have to rely on the accumulator in that case. + val accums = if (live && event.taskMetrics != null) { + event.taskMetrics.externalAccums.flatMap { a => + // This call may fail if the accumulator is gc'ed, so account for that. + try { + Some(a.toInfo(Some(a.value), None)) + } catch { + case _: IllegalAccessError => None + } + } + } else { + info.accumulables + } + updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums, + info.successful) + } + + def liveExecutionMetrics(executionId: Long): Option[Map[Long, String]] = { + Option(liveExecutions.get(executionId)).map { exec => + if (exec.metricsValues != null) { + exec.metricsValues + } else { + aggregateMetrics(exec) + } + } + } + + private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = { + val metricIds = exec.metrics.map(_.accumulatorId).sorted + val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap + val metrics = exec.stages.toSeq + .flatMap { stageId => Option(stageMetrics.get(stageId)) } + .flatMap(_.taskMetrics.values().asScala) + .flatMap { metrics => metrics.ids.zip(metrics.values) } + + val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq) + .filter { case (id, _) => metricIds.contains(id) } + .groupBy(_._1) + .map { case (id, values) => + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) + } + + // Check the execution again for whether the aggregated metrics data has been calculated. + // This can happen if the UI is requesting this data, and the onExecutionEnd handler is + // running at the same time. The metrics calculated for the UI can be inaccurate in that + // case, since the onExecutionEnd handler will clean up tracked stage metrics. 
+ if (exec.metricsValues != null) { + exec.metricsValues + } else { + aggregatedMetrics + } + } + + private def updateStageMetrics( + stageId: Int, + attemptId: Int, + taskId: Long, + accumUpdates: Seq[AccumulableInfo], + succeeded: Boolean): Unit = { + Option(stageMetrics.get(stageId)).foreach { metrics => + if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) { + return + } + + val oldTaskMetrics = metrics.taskMetrics.get(taskId) + if (oldTaskMetrics != null && oldTaskMetrics.succeeded) { + return + } + + val updates = accumUpdates + .filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) } + .sortBy(_.id) + + if (updates.isEmpty) { + return + } + + val ids = new Array[Long](updates.size) + val values = new Array[Long](updates.size) + updates.zipWithIndex.foreach { case (acc, idx) => + ids(idx) = acc.id + // In a live application, accumulators have Long values, but when reading from event + // logs, they have String values. For now, assume all accumulators are Long and convert + // accordingly. + values(idx) = acc.update.get match { + case s: String => s.toLong + case l: Long => l + case o => throw new IllegalArgumentException(s"Unexpected: $o") + } + } + + // TODO: storing metrics by task ID can cause metrics for the same task index to be + // counted multiple times, for example due to speculation or re-attempts. + metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded)) + } + } + + private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = { + // Install the SQL tab in a live app if it hasn't been initialized yet. + if (!uiInitialized) { + ui.foreach { _ui => + new SQLTab(new SQLAppStatusStore(kvstore, Some(this)), _ui) + } + uiInitialized = true + } + + val SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) = event + + def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = { + nodes.map { + case cluster: SparkPlanGraphCluster => + val storedCluster = new SparkPlanGraphClusterWrapper( + cluster.id, + cluster.name, + cluster.desc, + toStoredNodes(cluster.nodes), + cluster.metrics) + new SparkPlanGraphNodeWrapper(null, storedCluster) + + case node => + new SparkPlanGraphNodeWrapper(node, null) + } + } + + val planGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = planGraph.allNodes.flatMap { node => + node.metrics.map { metric => (metric.accumulatorId, metric) } + }.toMap.values.toList + + val graphToStore = new SparkPlanGraphWrapper( + executionId, + toStoredNodes(planGraph.nodes), + planGraph.edges) + kvstore.write(graphToStore) + + val exec = getOrCreateExecution(executionId) + exec.description = description + exec.details = details + exec.physicalPlanDescription = physicalPlanDescription + exec.metrics = sqlPlanMetrics + exec.submissionTime = time + update(exec) + } + + private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = { + val SparkListenerSQLExecutionEnd(executionId, time) = event + Option(liveExecutions.get(executionId)).foreach { exec => + exec.metricsValues = aggregateMetrics(exec) + exec.completionTime = Some(new Date(time)) + exec.endEvents += 1 + update(exec) + + // Remove stale LiveStageMetrics objects for stages that are not active anymore. 
+ val activeStages = liveExecutions.values().asScala.flatMap { other => + if (other != exec) other.stages else Nil + }.toSet + stageMetrics.keySet().asScala + .filter(!activeStages.contains(_)) + .foreach(stageMetrics.remove) + } + } + + private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { + val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event + Option(liveExecutions.get(executionId)).foreach { exec => + exec.driverAccumUpdates = accumUpdates.toMap + update(exec) + } + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case e: SparkListenerSQLExecutionStart => onExecutionStart(e) + case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e) + case e: SparkListenerDriverAccumUpdates => onDriverAccumUpdates(e) + case _ => // Ignore + } + + private def getOrCreateExecution(executionId: Long): LiveExecutionData = { + liveExecutions.computeIfAbsent(executionId, + new Function[Long, LiveExecutionData]() { + override def apply(key: Long): LiveExecutionData = new LiveExecutionData(executionId) + }) + } + + private def update(exec: LiveExecutionData): Unit = { + val now = System.nanoTime() + if (exec.endEvents >= exec.jobs.size + 1) { + exec.write(kvstore, now) + liveExecutions.remove(exec.executionId) + } else if (liveUpdatePeriodNs >= 0) { + if (now - exec.lastWriteTime > liveUpdatePeriodNs) { + exec.write(kvstore, now) + } + } + } + + private def isSQLStage(stageId: Int): Boolean = { + liveExecutions.values().asScala.exists { exec => + exec.stages.contains(stageId) + } + } + +} + +private class LiveExecutionData(val executionId: Long) extends LiveEntity { + + var description: String = null + var details: String = null + var physicalPlanDescription: String = null + var metrics = Seq[SQLPlanMetric]() + var submissionTime = -1L + var completionTime: Option[Date] = None + + var jobs = Map[Int, JobExecutionStatus]() + var stages = Set[Int]() + var driverAccumUpdates = Map[Long, Long]() + + @volatile var metricsValues: Map[Long, String] = null + + // Just in case job end and execution end arrive out of order, keep track of how many + // end events arrived so that the listener can stop tracking the execution. + var endEvents = 0 + + override protected def doUpdate(): Any = { + new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + metrics, + submissionTime, + completionTime, + jobs, + stages, + metricsValues) + } + +} + +private class LiveStageMetrics( + val stageId: Int, + var attemptId: Int, + val accumulatorIds: Array[Long], + val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics]) + +private[sql] class LiveTaskMetrics( + val ids: Array[Long], + val values: Array[Long], + val succeeded: Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala new file mode 100644 index 0000000000000..586d3ae411c74 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.lang.{Long => JLong} +import java.util.Date + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize + +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.scheduler.SparkListener +import org.apache.spark.status.AppStatusPlugin +import org.apache.spark.status.KVUtils.KVIndexParam +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore.KVStore + +/** + * Provides a view of a KVStore with methods that make it easy to query SQL-specific state. There's + * no state kept in this class, so it's ok to have multiple instances of it in an application. + */ +private[sql] class SQLAppStatusStore( + store: KVStore, + listener: Option[SQLAppStatusListener] = None) { + + def executionsList(): Seq[SQLExecutionUIData] = { + store.view(classOf[SQLExecutionUIData]).asScala.toSeq + } + + def execution(executionId: Long): Option[SQLExecutionUIData] = { + try { + Some(store.read(classOf[SQLExecutionUIData], executionId)) + } catch { + case _: NoSuchElementException => None + } + } + + def executionsCount(): Long = { + store.count(classOf[SQLExecutionUIData]) + } + + def executionMetrics(executionId: Long): Map[Long, String] = { + def metricsFromStore(): Option[Map[Long, String]] = { + val exec = store.read(classOf[SQLExecutionUIData], executionId) + Option(exec.metricValues) + } + + metricsFromStore() + .orElse(listener.flatMap(_.liveExecutionMetrics(executionId))) + // Try a second time in case the execution finished while this method is trying to + // get the metrics. + .orElse(metricsFromStore()) + .getOrElse(Map()) + } + + def planGraph(executionId: Long): SparkPlanGraph = { + store.read(classOf[SparkPlanGraphWrapper], executionId).toSparkPlanGraph() + } + +} + +/** + * An AppStatusPlugin for handling the SQL UI and listeners. + */ +private[sql] class SQLAppStatusPlugin extends AppStatusPlugin { + + override def setupListeners( + conf: SparkConf, + store: KVStore, + addListenerFn: SparkListener => Unit, + live: Boolean): Unit = { + // For live applications, the listener is installed in [[setupUI]]. This also avoids adding + // the listener when the UI is disabled. Force installation during testing, though. + if (!live || Utils.isTesting) { + val listener = new SQLAppStatusListener(conf, store, live, None) + addListenerFn(listener) + } + } + + override def setupUI(ui: SparkUI): Unit = { + ui.sc match { + case Some(sc) => + // If this is a live application, then install a listener that will enable the SQL + // tab as soon as there's a SQL event posted to the bus. + val listener = new SQLAppStatusListener(sc.conf, ui.store.store, true, Some(ui)) + sc.listenerBus.addToStatusQueue(listener) + + case _ => + // For a replayed application, only add the tab if the store already contains SQL data. 
+ val sqlStore = new SQLAppStatusStore(ui.store.store) + if (sqlStore.executionsCount() > 0) { + new SQLTab(sqlStore, ui) + } + } + } + +} + +private[sql] class SQLExecutionUIData( + @KVIndexParam val executionId: Long, + val description: String, + val details: String, + val physicalPlanDescription: String, + val metrics: Seq[SQLPlanMetric], + val submissionTime: Long, + val completionTime: Option[Date], + @JsonDeserialize(keyAs = classOf[Integer]) + val jobs: Map[Int, JobExecutionStatus], + @JsonDeserialize(contentAs = classOf[Integer]) + val stages: Set[Int], + /** + * This field is only populated after the execution is finished; it will be null while the + * execution is still running. During execution, aggregate metrics need to be retrieved + * from the SQL listener instance. + */ + @JsonDeserialize(keyAs = classOf[JLong]) + val metricValues: Map[Long, String] + ) + +private[sql] class SparkPlanGraphWrapper( + @KVIndexParam val executionId: Long, + val nodes: Seq[SparkPlanGraphNodeWrapper], + val edges: Seq[SparkPlanGraphEdge]) { + + def toSparkPlanGraph(): SparkPlanGraph = { + SparkPlanGraph(nodes.map(_.toSparkPlanGraphNode()), edges) + } + +} + +private[sql] class SparkPlanGraphClusterWrapper( + val id: Long, + val name: String, + val desc: String, + val nodes: Seq[SparkPlanGraphNodeWrapper], + val metrics: Seq[SQLPlanMetric]) { + + def toSparkPlanGraphCluster(): SparkPlanGraphCluster = { + new SparkPlanGraphCluster(id, name, desc, + new ArrayBuffer() ++ nodes.map(_.toSparkPlanGraphNode()), + metrics) + } + +} + +/** Only one of the values should be set. */ +private[sql] class SparkPlanGraphNodeWrapper( + val node: SparkPlanGraphNode, + val cluster: SparkPlanGraphClusterWrapper) { + + def toSparkPlanGraphNode(): SparkPlanGraphNode = { + assert(node == null ^ cluster == null, "One and only one of node or cluster must be set.") + if (node != null) node else cluster.toSparkPlanGraphCluster() + } + +} + +private[sql] case class SQLPlanMetric( + name: String, + accumulatorId: Long, + metricType: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 8c27af374febd..b58b8c6d45e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.ui -import scala.collection.mutable - import com.fasterxml.jackson.databind.JavaType import com.fasterxml.jackson.databind.`type`.TypeFactory import com.fasterxml.jackson.databind.annotation.JsonDeserialize import com.fasterxml.jackson.databind.util.Converter -import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ -import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.metric._ -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.AccumulatorContext @DeveloperApi case class SparkListenerSQLExecutionStart( @@ -89,398 +83,3 @@ private class LongLongTupleConverter extends Converter[(Object, Object), (Long, typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType)) } } - -class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { - - override def createListeners(conf: SparkConf, sparkUI: 
SparkUI): Seq[SparkListener] = { - List(new SQLHistoryListener(conf, sparkUI)) - } -} - -class SQLListener(conf: SparkConf) extends SparkListener with Logging { - - private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000) - - private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]() - - // Old data in the following fields must be removed in "trimExecutionsIfNecessary". - // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data - private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() - - /** - * Maintain the relation between job id and execution id so that we can get the execution id in - * the "onJobEnd" method. - */ - private val _jobIdToExecutionId = mutable.HashMap[Long, Long]() - - private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]() - - private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - - private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - - def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { - _executionIdToData.toMap - } - - def jobIdToExecutionId: Map[Long, Long] = synchronized { - _jobIdToExecutionId.toMap - } - - def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { - _stageIdToStageMetrics.toMap - } - - private def trimExecutionsIfNecessary( - executions: mutable.ListBuffer[SQLExecutionUIData]): Unit = { - if (executions.size > retainedExecutions) { - val toRemove = math.max(retainedExecutions / 10, 1) - executions.take(toRemove).foreach { execution => - for (executionUIData <- _executionIdToData.remove(execution.executionId)) { - for (jobId <- executionUIData.jobs.keys) { - _jobIdToExecutionId.remove(jobId) - } - for (stageId <- executionUIData.stages) { - _stageIdToStageMetrics.remove(stageId) - } - } - } - executions.trimStart(toRemove) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - val executionIdString = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) - if (executionIdString == null) { - // This is not a job created by SQL - return - } - val executionId = executionIdString.toLong - val jobId = jobStart.jobId - val stageIds = jobStart.stageIds - - synchronized { - activeExecutions.get(executionId).foreach { executionUIData => - executionUIData.jobs(jobId) = JobExecutionStatus.RUNNING - executionUIData.stages ++= stageIds - stageIds.foreach(stageId => - _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId = 0)) - _jobIdToExecutionId(jobId) = executionId - } - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { - val jobId = jobEnd.jobId - for (executionId <- _jobIdToExecutionId.get(jobId); - executionUIData <- _executionIdToData.get(executionId)) { - jobEnd.jobResult match { - case JobSucceeded => executionUIData.jobs(jobId) = JobExecutionStatus.SUCCEEDED - case JobFailed(_) => executionUIData.jobs(jobId) = JobExecutionStatus.FAILED - } - if (executionUIData.completionTime.nonEmpty && !executionUIData.hasRunningJobs) { - // We are the last job of this execution, so mark the execution as finished. Note that - // `onExecutionEnd` also does this, but currently that can be called before `onJobEnd` - // since these are called on different threads. 
- markExecutionFinished(executionId) - } - } - } - - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - for ((taskId, stageId, stageAttemptID, accumUpdates) <- executorMetricsUpdate.accumUpdates) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, accumUpdates, finishTask = false) - } - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - val stageId = stageSubmitted.stageInfo.stageId - val stageAttemptId = stageSubmitted.stageInfo.attemptId - // Always override metrics for old stage attempt - if (_stageIdToStageMetrics.contains(stageId)) { - _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) - } else { - // If a stage belongs to some SQL execution, its stageId will be put in "onJobStart". - // Since "_stageIdToStageMetrics" doesn't contain it, it must not belong to any SQL execution. - // So we can ignore it. Otherwise, this may lead to memory leaks (SPARK-11126). - } - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - if (taskEnd.taskMetrics != null) { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskMetrics.externalAccums.map(a => a.toInfo(Some(a.value), None)), - finishTask = true) - } - } - - /** - * Update the accumulator values of a task with the latest metrics for this task. This is called - * every time we receive an executor heartbeat or when a task finishes. - */ - protected def updateTaskAccumulatorValues( - taskId: Long, - stageId: Int, - stageAttemptID: Int, - _accumulatorUpdates: Seq[AccumulableInfo], - finishTask: Boolean): Unit = { - val accumulatorUpdates = - _accumulatorUpdates.filter(_.update.isDefined).map(accum => (accum.id, accum.update.get)) - - _stageIdToStageMetrics.get(stageId) match { - case Some(stageMetrics) => - if (stageAttemptID < stageMetrics.stageAttemptId) { - // A task of an old stage attempt. Because a new stage is submitted, we can ignore it. - } else if (stageAttemptID > stageMetrics.stageAttemptId) { - logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) then " + - s"what we have seen (${stageMetrics.stageAttemptId})") - } else { - // TODO We don't know the attemptId. Currently, what we can do is overriding the - // accumulator updates. However, if there are two same task are running, such as - // speculation, the accumulator updates will be overriding by different task attempts, - // the results will be weird. 
- stageMetrics.taskIdToMetricUpdates.get(taskId) match { - case Some(taskMetrics) => - if (finishTask) { - taskMetrics.finished = true - taskMetrics.accumulatorUpdates = accumulatorUpdates - } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = accumulatorUpdates - } else { - // If a task is finished, we should not override with accumulator updates from - // heartbeat reports - } - case None => - stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - finished = finishTask, accumulatorUpdates) - } - } - case None => - // This execution and its stage have been dropped - } - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) => - val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) - val sqlPlanMetrics = physicalPlanGraph.allNodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - val executionUIData = new SQLExecutionUIData( - executionId, - description, - details, - physicalPlanDescription, - physicalPlanGraph, - sqlPlanMetrics.toMap, - time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. - } - } - } - case SparkListenerDriverAccumUpdates(executionId, accumUpdates) => synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - for ((accId, accValue) <- accumUpdates) { - executionUIData.driverAccumUpdates(accId) = accValue - } - } - } - case _ => // Ignore - } - - private def markExecutionFinished(executionId: Long): Unit = { - activeExecutions.remove(executionId).foreach { executionUIData => - if (executionUIData.isFailed) { - failedExecutions += executionUIData - trimExecutionsIfNecessary(failedExecutions) - } else { - completedExecutions += executionUIData - trimExecutionsIfNecessary(completedExecutions) - } - } - } - - def getRunningExecutions: Seq[SQLExecutionUIData] = synchronized { - activeExecutions.values.toSeq - } - - def getFailedExecutions: Seq[SQLExecutionUIData] = synchronized { - failedExecutions - } - - def getCompletedExecutions: Seq[SQLExecutionUIData] = synchronized { - completedExecutions - } - - def getExecution(executionId: Long): Option[SQLExecutionUIData] = synchronized { - _executionIdToData.get(executionId) - } - - /** - * Get all accumulator updates from all tasks which belong to this execution and merge them. 
- */ - def getExecutionMetrics(executionId: Long): Map[Long, String] = synchronized { - _executionIdToData.get(executionId) match { - case Some(executionUIData) => - val accumulatorUpdates = { - for (stageId <- executionUIData.stages; - stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; - taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; - accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { - (accumulatorUpdate._1, accumulatorUpdate._2) - } - } - - val driverUpdates = executionUIData.driverAccumUpdates.toSeq - val totalUpdates = (accumulatorUpdates ++ driverUpdates).filter { - case (id, _) => executionUIData.accumulatorMetrics.contains(id) - } - mergeAccumulatorUpdates(totalUpdates, accumulatorId => - executionUIData.accumulatorMetrics(accumulatorId).metricType) - case None => - // This execution has been dropped - Map.empty - } - } - - private def mergeAccumulatorUpdates( - accumulatorUpdates: Seq[(Long, Any)], - metricTypeFunc: Long => String): Map[Long, String] = { - accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => - val metricType = metricTypeFunc(accumulatorId) - accumulatorId -> - SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long])) - } - } - -} - - -/** - * A [[SQLListener]] for rendering the SQL UI in the history server. - */ -class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) - extends SQLListener(conf) { - - private var sqlTabAttached = false - - override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = { - // Do nothing; these events are not logged - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.flatMap { a => - // Filter out accumulators that are not SQL metrics - // For now we assume all SQL metrics are Long's that have been JSON serialized as String's - if (a.metadata == Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) { - val newValue = a.update.map(_.toString.toLong).getOrElse(0L) - Some(a.copy(update = Some(newValue))) - } else { - None - } - }, - finishTask = true) - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case _: SparkListenerSQLExecutionStart => - if (!sqlTabAttached) { - new SQLTab(this, sparkUI) - sqlTabAttached = true - } - super.onOtherEvent(event) - case _ => super.onOtherEvent(event) - } -} - -/** - * Represent all necessary data for an execution that will be used in Web UI. - */ -private[ui] class SQLExecutionUIData( - val executionId: Long, - val description: String, - val details: String, - val physicalPlanDescription: String, - val physicalPlanGraph: SparkPlanGraph, - val accumulatorMetrics: Map[Long, SQLPlanMetric], - val submissionTime: Long) { - - var completionTime: Option[Long] = None - - val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty - - val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer() - - val driverAccumUpdates: mutable.HashMap[Long, Long] = mutable.HashMap.empty - - /** - * Return whether there are running jobs in this execution. - */ - def hasRunningJobs: Boolean = jobs.values.exists(_ == JobExecutionStatus.RUNNING) - - /** - * Return whether there are any failed jobs in this execution. 
- */ - def isFailed: Boolean = jobs.values.exists(_ == JobExecutionStatus.FAILED) - - def runningJobs: Seq[Long] = - jobs.filter { case (_, status) => status == JobExecutionStatus.RUNNING }.keys.toSeq - - def succeededJobs: Seq[Long] = - jobs.filter { case (_, status) => status == JobExecutionStatus.SUCCEEDED }.keys.toSeq - - def failedJobs: Seq[Long] = - jobs.filter { case (_, status) => status == JobExecutionStatus.FAILED }.keys.toSeq -} - -/** - * Represent a metric in a SQLPlan. - * - * Because we cannot revert our changes for an "Accumulator", we need to maintain accumulator - * updates for each task. So that if a task is retried, we can simply override the old updates with - * the new updates of the new attempt task. Since we cannot add them to accumulator, we need to use - * "AccumulatorParam" to get the aggregation value. - */ -private[ui] case class SQLPlanMetric( - name: String, - accumulatorId: Long, - metricType: String) - -/** - * Store all accumulatorUpdates for all tasks in a Spark stage. - */ -private[ui] class SQLStageMetrics( - val stageAttemptId: Long, - val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty) - - -// TODO Should add attemptId here when we can get it from SparkListenerExecutorMetricsUpdate -/** - * Store all accumulatorUpdates for a Spark task. - */ -private[ui] class SQLTaskMetrics( - var finished: Boolean, - var accumulatorUpdates: Seq[(Long, Any)]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index d0376af3e31ca..a321a22f10789 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.ui import org.apache.spark.internal.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} -class SQLTab(val listener: SQLListener, sparkUI: SparkUI) +class SQLTab(val sqlStore: SQLAppStatusStore, sparkUI: SparkUI) extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index ad9db308b2627..3e479faed72ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -32,7 +32,6 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.CacheManager -import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -83,11 +82,6 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { */ val cacheManager: CacheManager = new CacheManager - /** - * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. - */ - val listener: SQLListener = createListenerAndUI(sparkContext) - /** * A catalog that interacts with external systems. 
*/ @@ -142,19 +136,6 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { val jarClassLoader = new NonClosableMutableURLClassLoader( org.apache.spark.util.Utils.getContextOrSparkClassLoader) - /** - * Create a SQLListener then add it into SparkContext, and create a SQLTab if there is SparkUI. - */ - private def createListenerAndUI(sc: SparkContext): SQLListener = { - if (SparkSession.sqlListener.get() == null) { - val listener = new SQLListener(sc.conf) - if (SparkSession.sqlListener.compareAndSet(null, listener)) { - sc.listenerBus.addToStatusQueue(listener) - sc.ui.foreach(new SQLTab(listener, _)) - } - } - SparkSession.sqlListener.get() - } } object SharedState extends Logging { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 58a194b8af62b..d588af3e19dde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -32,6 +33,13 @@ import org.apache.spark.util.{AccumulatorContext, JsonProtocol} class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext { import testImplicits._ + private def statusStore: SQLAppStatusStore = { + new SQLAppStatusStore(sparkContext.statusStore.store) + } + + private def currentExecutionIds(): Set[Long] = { + statusStore.executionsList.map(_.executionId).toSet + } /** * Generates a `DataFrame` by filling randomly generated bytes for hash collision. @@ -420,21 +428,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared withTempPath { file => // person creates a temporary view. get the DF before listing previous execution IDs val data = person.select('name) - sparkContext.listenerBus.waitUntilEmpty(10000) - val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + val previousExecutionIds = currentExecutionIds() // Assume the execution plan is // PhysicalRDD(nodeId = 0) data.write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = - spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = currentExecutionIds().diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs + val jobs = statusStore.execution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) + val metricValues = statusStore.executionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. 
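These suites now read execution data through a SQLAppStatusStore built over the shared status store instead of spark.sharedState.listener. A condensed sketch of the pattern, assuming a test SparkSession `spark` and its `sparkContext` are in scope (the triggering query is a placeholder):

    val statusStore = new SQLAppStatusStore(sparkContext.statusStore.store)
    val before = statusStore.executionsList().map(_.executionId).toSet

    spark.range(10).collect()                      // placeholder action that runs a SQL execution
    sparkContext.listenerBus.waitUntilEmpty(10000)

    val newIds = statusStore.executionsList().map(_.executionId).toSet.diff(before)
    val metrics = statusStore.executionMetrics(newIds.head)  // accumulator id -> rendered value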
assert(metricValues.values.toSeq.exists(_ === "2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 3966e98c1ce06..d89c4b14619fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -34,6 +34,14 @@ trait SQLMetricsTestUtils extends SQLTestUtils { import testImplicits._ + private def statusStore: SQLAppStatusStore = { + new SQLAppStatusStore(sparkContext.statusStore.store) + } + + private def currentExecutionIds(): Set[Long] = { + statusStore.executionsList.map(_.executionId).toSet + } + /** * Get execution metrics for the SQL execution and verify metrics values. * @@ -41,24 +49,23 @@ trait SQLMetricsTestUtils extends SQLTestUtils { * @param func the function can produce execution id after running. */ private def verifyWriteDataMetrics(metricsValues: Seq[Int])(func: => Unit): Unit = { - val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + val previousExecutionIds = currentExecutionIds() // Run the given function to trigger query execution. func spark.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = - spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = currentExecutionIds().diff(previousExecutionIds) assert(executionIds.size == 1) val executionId = executionIds.head - val executionData = spark.sharedState.listener.getExecution(executionId).get - val executedNode = executionData.physicalPlanGraph.nodes.head + val executionData = statusStore.execution(executionId).get + val executedNode = statusStore.planGraph(executionId).nodes.head val metricsNames = Seq( "number of written files", "number of dynamic part", "number of output rows") - val metrics = spark.sharedState.listener.getExecutionMetrics(executionId) + val metrics = statusStore.executionMetrics(executionId) metricsNames.zip(metricsValues).foreach { case (metricsName, expected) => val sqlMetric = executedNode.metrics.find(_.name == metricsName) @@ -134,22 +141,21 @@ trait SQLMetricsTestUtils extends SQLTestUtils { expectedNumOfJobs: Int, expectedNodeIds: Set[Long], enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = { - val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet + val previousExecutionIds = currentExecutionIds() withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = - spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = currentExecutionIds().diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs + val jobs = statusStore.execution(executionId).get.jobs // Use "<=" because there is a race 
condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) + val metricValues = statusStore.executionMetrics(executionId) val metrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( df.queryExecution.executedPlan)).allNodes.filter { node => expectedNodeIds.contains(node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 1055f09f5411c..eba8d55daad58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.ui import java.util.Properties +import scala.collection.mutable.ListBuffer + import org.json4s.jackson.JsonMethods._ -import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.LocalSparkContext._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ @@ -36,13 +36,14 @@ import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.ui.SparkUI +import org.apache.spark.status.config._ import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} - +import org.apache.spark.util.kvstore.InMemoryStore class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { import testImplicits._ - import org.apache.spark.AccumulatorSuite.makeInfo + + override protected def sparkConf = super.sparkConf.set(LIVE_ENTITY_UPDATE_PERIOD, 0L) private def createTestDataFrame: DataFrame = { Seq( @@ -68,44 +69,67 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest details = "" ) - private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( - taskId = taskId, - attemptNumber = attemptNumber, - // The following fields are not used in tests - index = 0, - launchTime = 0, - executorId = "", - host = "", - taskLocality = null, - speculative = false - ) + private def createTaskInfo( + taskId: Int, + attemptNumber: Int, + accums: Map[Long, Long] = Map()): TaskInfo = { + val info = new TaskInfo( + taskId = taskId, + attemptNumber = attemptNumber, + // The following fields are not used in tests + index = 0, + launchTime = 0, + executorId = "", + host = "", + taskLocality = null, + speculative = false) + info.markFinished(TaskState.FINISHED, 1L) + info.setAccumulables(createAccumulatorInfos(accums)) + info + } - private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { - val metrics = TaskMetrics.empty - accumulatorUpdates.foreach { case (id, update) => + private def createAccumulatorInfos(accumulatorUpdates: Map[Long, Long]): Seq[AccumulableInfo] = { + accumulatorUpdates.map { case (id, value) => val acc = new LongAccumulator - acc.metadata = AccumulatorMetadata(id, Some(""), true) - acc.add(update) - metrics.registerAccumulator(acc) + acc.metadata = 
AccumulatorMetadata(id, None, false) + acc.toInfo(Some(value), None) + }.toSeq + } + + /** Return the shared SQL store from the active SparkSession. */ + private def statusStore: SQLAppStatusStore = + new SQLAppStatusStore(spark.sparkContext.statusStore.store) + + /** + * Runs a test with a temporary SQLAppStatusStore tied to a listener bus. Events can be sent to + * the listener bus to update the store, and all data will be cleaned up at the end of the test. + */ + private def sqlStoreTest(name: String) + (fn: (SQLAppStatusStore, SparkListenerBus) => Unit): Unit = { + test(name) { + val store = new InMemoryStore() + val bus = new ReplayListenerBus() + val listener = new SQLAppStatusListener(sparkConf, store, true) + bus.addListener(listener) + val sqlStore = new SQLAppStatusStore(store, Some(listener)) + fn(sqlStore, bus) } - metrics } - test("basic") { + sqlStoreTest("basic") { (store, bus) => def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { assert(actual.size == expected.size) - expected.foreach { e => + expected.foreach { case (id, value) => // The values in actual can be SQL metrics meaning that they contain additional formatting // when converted to string. Verify that they start with the expected value. // TODO: this is brittle. There is no requirement that the actual string needs to start // with the accumulator value. - assert(actual.contains(e._1)) - val v = actual.get(e._1).get.trim - assert(v.startsWith(e._2.toString)) + assert(actual.contains(id)) + val v = actual.get(id).get.trim + assert(v.startsWith(value.toString), s"Wrong value for accumulator $id") } } - val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -118,7 +142,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest (id, accumulatorValue) }.toMap - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", @@ -126,9 +150,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - val executionUIData = listener.executionIdToData(0) - - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq( @@ -136,291 +158,270 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest createStageInfo(1, 0) ), createProperties(executionId))) - listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) + bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 0))) - assert(listener.getExecutionMetrics(0).isEmpty) + assert(store.executionMetrics(0).isEmpty) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)) + (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)), + (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Driver accumulator updates don't belong to this execution should be filtered and no // 
exception will be thrown. - listener.onOtherEvent(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + bus.postToAll(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 0, 0, - createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo)) + (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)), + (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates.mapValues(_ * 2))) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) // Retrying a stage should reset the metrics - listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) + bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)) + (0L, 0, 1, createAccumulatorInfos(accumulatorUpdates)), + (1L, 0, 1, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Ignore the task end for the first attempt - listener.onTaskEnd(SparkListenerTaskEnd( + bus.postToAll(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(0, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 100)), + null)) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Finish two tasks - listener.onTaskEnd(SparkListenerTaskEnd( + bus.postToAll(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 1, taskType = "", reason = null, - createTaskInfo(0, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))) - listener.onTaskEnd(SparkListenerTaskEnd( + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 2)), + null)) + bus.postToAll(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 1, taskType = "", reason = null, - createTaskInfo(1, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + null)) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) // Summit a new stage - listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) + bus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 0))) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) - (0L, 1, 0, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)), - (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)) + (0L, 1, 0, createAccumulatorInfos(accumulatorUpdates)), + (1L, 1, 0, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) // Finish two tasks - listener.onTaskEnd(SparkListenerTaskEnd( + bus.postToAll(SparkListenerTaskEnd( stageId = 1, stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(0, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) - listener.onTaskEnd(SparkListenerTaskEnd( + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + null)) + bus.postToAll(SparkListenerTaskEnd( stageId = 1, stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(1, 0), - createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + null)) - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) - assert(executionUIData.runningJobs === Seq(0)) - assert(executionUIData.succeededJobs.isEmpty) - assert(executionUIData.failedJobs.isEmpty) + assertJobs(store.execution(0), running = Seq(0)) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs === Seq(0)) - assert(executionUIData.failedJobs.isEmpty) - - checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + assertJobs(store.execution(0), completed = Seq(0)) + checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) } - test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(spark.sparkContext.conf) + sqlStoreTest("onExecutionEnd happens before onJobEnd(JobSucceeded)") { (store, bus) => val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - val executionUIData = listener.executionIdToData(0) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs === Seq(0)) - assert(executionUIData.failedJobs.isEmpty) + assertJobs(store.execution(0), completed = Seq(0)) } - test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(spark.sparkContext.conf) + sqlStoreTest("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { (store, bus) => val executionId = 0 val df = 
createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 1, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), JobSucceeded )) - val executionUIData = listener.executionIdToData(0) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs.sorted === Seq(0, 1)) - assert(executionUIData.failedJobs.isEmpty) + assertJobs(store.execution(0), completed = Seq(0, 1)) } - test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(spark.sparkContext.conf) + sqlStoreTest("onExecutionEnd happens before onJobEnd(JobFailed)") { (store, bus) => val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + bus.postToAll(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - listener.onJobStart(SparkListenerJobStart( + bus.postToAll(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( + bus.postToAll(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - listener.onJobEnd(SparkListenerJobEnd( + bus.postToAll(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobFailed(new RuntimeException("Oops")) )) - val executionUIData = listener.executionIdToData(0) - assert(executionUIData.runningJobs.isEmpty) - assert(executionUIData.succeededJobs.isEmpty) - assert(executionUIData.failedJobs === Seq(0)) + assertJobs(store.execution(0), failed = Seq(0)) } test("SPARK-11126: no memory leak when running non SQL jobs") { - val previousStageNumber = spark.sharedState.listener.stageIdToStageMetrics.size + val previousStageNumber = statusStore.executionsList().size spark.sparkContext.parallelize(1 to 10).foreach(i => ()) spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should ignore the non SQL stage - assert(spark.sharedState.listener.stageIdToStageMetrics.size == previousStageNumber) + assert(statusStore.executionsList().size == previousStageNumber) spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage - assert(spark.sharedState.listener.stageIdToStageMetrics.size == previousStageNumber + 1) - } - - test("SPARK-13055: history listener only tracks SQL metrics") { - val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI])) - // We need to post other events for the listener to track our 
accumulators. - // These are largely just boilerplate unrelated to what we're trying to test. - val df = createTestDataFrame - val executionStart = SparkListenerSQLExecutionStart( - 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0) - val stageInfo = createStageInfo(0, 0) - val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0)) - val stageSubmitted = SparkListenerStageSubmitted(stageInfo) - // This task has both accumulators that are SQL metrics and accumulators that are not. - // The listener should only track the ones that are actually SQL metrics. - val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella") - val nonSqlMetric = sparkContext.longAccumulator("baseball") - val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.value), None) - val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.value), None) - val taskInfo = createTaskInfo(0, 0) - taskInfo.setAccumulables(List(sqlMetricInfo, nonSqlMetricInfo)) - val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null) - listener.onOtherEvent(executionStart) - listener.onJobStart(jobStart) - listener.onStageSubmitted(stageSubmitted) - // Before SPARK-13055, this throws ClassCastException because the history listener would - // assume that the accumulator value is of type Long, but this may not be true for - // accumulators that are not SQL metrics. - listener.onTaskEnd(taskEnd) - val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics => - stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates) - } - // Listener tracks only SQL metrics, not other accumulators - assert(trackedAccums.size === 1) - assert(trackedAccums.head === ((sqlMetricInfo.id, sqlMetricInfo.update.get))) + assert(statusStore.executionsList().size == previousStageNumber + 1) } test("driver side SQL metrics") { - val listener = new SQLListener(spark.sparkContext.conf) - val expectedAccumValue = 12345 + val oldCount = statusStore.executionsList().size + val expectedAccumValue = 12345L val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) - sqlContext.sparkContext.addSparkListener(listener) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan } + SQLExecution.withNewExecutionId(spark, dummyQueryExecution) { physicalPlan.execute().collect() } - def waitTillExecutionFinished(): Unit = { - while (listener.getCompletedExecutions.isEmpty) { - Thread.sleep(100) + while (statusStore.executionsList().size < oldCount) { + Thread.sleep(100) + } + + // Wait for listener to finish computing the metrics for the execution. 
+ while (statusStore.executionsList().last.metricValues == null) { + Thread.sleep(100) + } + + val execId = statusStore.executionsList().last.executionId + val metrics = statusStore.executionMetrics(execId) + val driverMetric = physicalPlan.metrics("dummy") + val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + + assert(metrics.contains(driverMetric.id)) + assert(metrics(driverMetric.id) === expectedValue) + } + + private def assertJobs( + exec: Option[SQLExecutionUIData], + running: Seq[Int] = Nil, + completed: Seq[Int] = Nil, + failed: Seq[Int] = Nil): Unit = { + + val actualRunning = new ListBuffer[Int]() + val actualCompleted = new ListBuffer[Int]() + val actualFailed = new ListBuffer[Int]() + + exec.get.jobs.foreach { case (jobId, jobStatus) => + jobStatus match { + case JobExecutionStatus.RUNNING => actualRunning += jobId + case JobExecutionStatus.SUCCEEDED => actualCompleted += jobId + case JobExecutionStatus.FAILED => actualFailed += jobId + case _ => fail(s"Unexpected status $jobStatus") } } - waitTillExecutionFinished() - val driverUpdates = listener.getCompletedExecutions.head.driverAccumUpdates - assert(driverUpdates.size == 1) - assert(driverUpdates(physicalPlan.longMetric("dummy").id) == expectedAccumValue) + assert(actualRunning.toSeq.sorted === running) + assert(actualCompleted.toSeq.sorted === completed) + assert(actualFailed.toSeq.sorted === failed) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -490,7 +491,8 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe class SQLListenerMemoryLeakSuite extends SparkFunSuite { - test("no memory leak") { + // TODO: this feature is not yet available in SQLAppStatusStore. + ignore("no memory leak") { quietly { val conf = new SparkConf() .setMaster("local") @@ -498,7 +500,6 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly withSpark(new SparkContext(conf)) { sc => - SparkSession.sqlListener.set(null) val spark = new SparkSession(sc) import spark.implicits._ // Run 100 successful executions and 100 failed executions. @@ -516,12 +517,9 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { } } sc.listenerBus.waitUntilEmpty(10000) - assert(spark.sharedState.listener.getCompletedExecutions.size <= 50) - assert(spark.sharedState.listener.getFailedExecutions.size <= 50) - // 50 for successful executions and 50 for failed executions - assert(spark.sharedState.listener.executionIdToData.size <= 100) - assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100) - assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100) + + val statusStore = new SQLAppStatusStore(sc.statusStore.store) + assert(statusStore.executionsList().size <= 50) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index e0568a3c5c99f..0b4629a51b425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -73,7 +73,6 @@ trait SharedSparkSession * call 'beforeAll'. 
*/ protected def initializeSession(): Unit = { - SparkSession.sqlListener.set(null) if (_spark == null) { _spark = createSparkSession } From eaff295a232217c4424f2885303f9a127dea0422 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Tue, 14 Nov 2017 15:20:03 -0800 Subject: [PATCH 244/712] [SPARK-22519][YARN] Remove unnecessary stagingDirPath null check in ApplicationMaster.cleanupStagingDir() ## What changes were proposed in this pull request? Removed the unnecessary stagingDirPath null check in ApplicationMaster.cleanupStagingDir(). ## How was this patch tested? I verified with the existing test cases. Author: Devaraj K Closes #19749 from devaraj-kavali/SPARK-22519. --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 244d912b9f3aa..ca0aa0ea3bc73 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -608,10 +608,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES) if (!preserveFiles) { stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) - if (stagingDirPath == null) { - logError("Staging directory is null") - return - } logInfo("Deleting staging directory " + stagingDirPath) val fs = stagingDirPath.getFileSystem(yarnConf) fs.delete(stagingDirPath, true) From b009722591d2635698233c84f6e7e6cde7177019 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 14 Nov 2017 17:58:07 -0600 Subject: [PATCH 245/712] [SPARK-22511][BUILD] Update maven central repo address ## What changes were proposed in this pull request? Use repo.maven.apache.org repo address; use latest ASF parent POM version 18 ## How was this patch tested? Existing tests; no functional change Author: Sean Owen Closes #19742 from srowen/SPARK-22511. --- dev/check-license | 2 +- pom.xml | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dev/check-license b/dev/check-license index 8cee09a53e087..b729f34f24059 100755 --- a/dev/check-license +++ b/dev/check-license @@ -20,7 +20,7 @@ acquire_rat_jar () { - URL="https://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" + URL="https://repo.maven.apache.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" JAR="$rat_jar" diff --git a/pom.xml b/pom.xml index 8570338df6878..0297311dd6e64 100644 --- a/pom.xml +++ b/pom.xml @@ -22,7 +22,7 @@ org.apache apache - 14 + 18 org.apache.spark spark-parent_2.11 @@ -111,6 +111,8 @@ UTF-8 UTF-8 1.8 + ${java.version} + ${java.version} 3.3.9 spark 1.7.16 @@ -226,7 +228,7 @@ central Maven Repository - https://repo1.maven.org/maven2 + https://repo.maven.apache.org/maven2 true @@ -238,7 +240,7 @@ central - https://repo1.maven.org/maven2 + https://repo.maven.apache.org/maven2 true From 774398045b7b0cde4afb3f3c1a19ad491cf71ed1 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 14 Nov 2017 16:48:26 -0800 Subject: [PATCH 246/712] [SPARK-21087][ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 
We add a parameter whether to collect the full model list when CrossValidator/TrainValidationSplit training (Default is NOT), avoid the change cause OOM) - Add a method in CrossValidatorModel/TrainValidationSplitModel, allow user to get the model list - CrossValidatorModelWriter add a “option”, allow user to control whether to persist the model list to disk (will persist by default). - Note: when persisting the model list, use indices as the sub-model path ## How was this patch tested? Test cases added. Author: WeichenXu Closes #19208 from WeichenXu123/expose-model-list. --- .../ml/param/shared/SharedParamsCodeGen.scala | 8 +- .../spark/ml/param/shared/sharedParams.scala | 17 +++ .../spark/ml/tuning/CrossValidator.scala | 137 ++++++++++++++++-- .../ml/tuning/TrainValidationSplit.scala | 128 ++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 19 +++ .../spark/ml/tuning/CrossValidatorSuite.scala | 54 ++++++- .../ml/tuning/TrainValidationSplitSuite.scala | 48 +++++- project/MimaExcludes.scala | 6 +- 8 files changed, 388 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 20a1db854e3a6..c54062921fce6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,7 +83,13 @@ private[shared] object SharedParamsCodeGen { "all instance weights as 1.0"), ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), - isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) + isValid = "ParamValidators.gtEq(2)", isExpertParam = true), + ParamDesc[Boolean]("collectSubModels", "If set to false, then only the single best " + + "sub-model will be available after fitting. If set to true, then all sub-models will be " + + "available. Warning: For large models, collecting all sub-models can cause OOMs on the " + + "Spark driver.", + Some("false"), isExpertParam = true) + ) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 0d5fb28ae783c..34aa38ac751fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -468,4 +468,21 @@ trait HasAggregationDepth extends Params { /** @group expertGetParam */ final def getAggregationDepth: Int = $(aggregationDepth) } + +/** + * Trait for shared param collectSubModels (default: false). + */ +private[ml] trait HasCollectSubModels extends Params { + + /** + * Param for whether to collect a list of sub-models trained during tuning. 
+ * @group expertParam + */ + final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning") + + setDefault(collectSubModels, false) + + /** @group expertGetParam */ + final def getCollectSubModels: Boolean = $(collectSubModels) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 7c81cb96e07f2..1682ca91bf832 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tuning -import java.util.{List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -67,7 +67,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { @Since("1.2.0") class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with HasParallelism with MLWritable with Logging { + with CrossValidatorParams with HasParallelism with HasCollectSubModels + with MLWritable with Logging { @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) @@ -101,6 +102,21 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model. + * + * Note: If set this param, when you save the returned model, you can set an option + * "persistSubModels" to be "true" before saving, in order to save these submodels. + * You can check documents of + * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} + * for more information. 
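A hedged usage sketch of the sub-model collection API added in this patch; the estimator, param grid, evaluator and toy training data below are placeholders rather than part of the change, and `spark` is assumed to be an active SparkSession:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val training = spark.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0)),
      (0.0, Vectors.dense(2.0, 1.3)),
      (1.0, Vectors.dense(0.0, 1.2)),
      (0.0, Vectors.dense(0.5, 0.9)),
      (1.0, Vectors.dense(1.8, 1.1))
    )).toDF("label", "features")

    val lr = new LogisticRegression()
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()

    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)
      .setCollectSubModels(true)  // keep every fitted sub-model, not only the best one

    val cvModel = cv.fit(training)
    // subModels(foldIndex)(paramMapIndex): one model per fold and per param map.
    val firstFoldModels = cvModel.subModels(0)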
+ * + * @group expertSetParam + */ + @Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema @@ -117,6 +133,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) instr.logParams(numFolds, seed, parallelism) logTuningParams(instr) + val collectSubModelsParam = $(collectSubModels) + + var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) { + Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))) + } else None + // Compute metrics for each model over each split val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => @@ -125,10 +147,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels.get(splitIndex)(paramIndex) = model + } + model } (executionContext) } @@ -160,7 +186,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.4.0") @@ -244,6 +271,31 @@ class CrossValidatorModel private[ml] ( this(uid, bestModel, avgMetrics.asScala.toArray) } + private var _subModels: Option[Array[Array[Model[_]]]] = None + + private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]]) + : CrossValidatorModel = { + _subModels = subModels + this + } + + /** + * @return submodels represented in two dimension array. The index of outer array is the + * fold index, and the index of inner array corresponds to the ordering of + * estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting. 
+ */ + @Since("2.3.0") + def subModels: Array[Array[Model[_]]] = { + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") + _subModels.get + } + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -260,34 +312,76 @@ class CrossValidatorModel private[ml] ( val copied = new CrossValidatorModel( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - avgMetrics.clone()) + avgMetrics.clone() + ).setSubModels(CrossValidatorModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("1.6.0") - override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) + override def write: CrossValidatorModel.CrossValidatorModelWriter = { + new CrossValidatorModel.CrossValidatorModelWriter(this) + } } @Since("1.6.0") object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) + : Option[Array[Array[Model[_]]]] = { + subModels.map(_.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]]))) + } + @Since("1.6.0") override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader @Since("1.6.0") override def load(path: String): CrossValidatorModel = super.load(path) - private[CrossValidatorModel] - class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + /** + * Writer for CrossValidatorModel. + * @param instance CrossValidatorModel instance used to construct the writer + * + * CrossValidatorModelWriter supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. + */ + @Since("2.3.0") + final class CrossValidatorModelWriter private[tuning] ( + instance: CrossValidatorModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean + import org.json4s.JsonDSL._ - val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. 
" + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") + val subModelsPath = new Path(path, "subModels") + for (splitIndex <- 0 until instance.getNumFolds) { + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } + } } } @@ -301,11 +395,30 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) + val numFolds = (metadata.params \ "numFolds").extract[Int] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) + + val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( + estimatorParamMaps.length)(null)) + for (splitIndex <- 0 until numFolds) { + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + _subModels(splitIndex)(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + } + Some(_subModels) + } else None val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 6e3ad40706803..c73bd18475475 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.tuning -import java.io.IOException -import java.util.{List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -33,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -67,7 +66,8 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { @Since("1.5.0") class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[TrainValidationSplitModel] - with TrainValidationSplitParams with HasParallelism with MLWritable with Logging { + with TrainValidationSplitParams with HasParallelism with HasCollectSubModels + with MLWritable with Logging { @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) @@ -101,6 +101,20 @@ class TrainValidationSplit 
@Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model. + * + * Note: If set this param, when you save the returned model, you can set an option + * "persistSubModels" to be "true" before saving, in order to save these submodels. + * You can check documents of + * {@link org.apache.spark.ml.tuning.TrainValidationSplitModel.TrainValidationSplitModelWriter} + * for more information. + * + * @group expertSetParam + */@Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema @@ -121,12 +135,22 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.cache() validationDataset.cache() + val collectSubModelsParam = $(collectSubModels) + + var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) { + Some(Array.fill[Model[_]](epm.length)(null)) + } else None + // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels.get(paramIndex) = model + } + model } (executionContext) } @@ -158,7 +182,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) + copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.5.0") @@ -238,6 +263,30 @@ class TrainValidationSplitModel private[ml] ( this(uid, bestModel, validationMetrics.asScala.toArray) } + private var _subModels: Option[Array[Model[_]]] = None + + private[tuning] def setSubModels(subModels: Option[Array[Model[_]]]) + : TrainValidationSplitModel = { + _subModels = subModels + this + } + + /** + * @return submodels represented in array. The index of array corresponds to the ordering of + * estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting. 
+ */ + @Since("2.3.0") + def subModels: Array[Model[_]] = { + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") + _subModels.get + } + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -254,34 +303,73 @@ class TrainValidationSplitModel private[ml] ( val copied = new TrainValidationSplitModel ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - validationMetrics.clone()) + validationMetrics.clone() + ).setSubModels(TrainValidationSplitModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("2.0.0") - override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = { + new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + } } @Since("2.0.0") object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { + private[TrainValidationSplitModel] def copySubModels(subModels: Option[Array[Model[_]]]) + : Option[Array[Model[_]]] = { + subModels.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]])) + } + @Since("2.0.0") override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader @Since("2.0.0") override def load(path: String): TrainValidationSplitModel = super.load(path) - private[TrainValidationSplitModel] - class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter { + /** + * Writer for TrainValidationSplitModel. + * @param instance TrainValidationSplitModel instance used to construct the writer + * + * TrainValidationSplitModel supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. + */ + @Since("2.3.0") + final class TrainValidationSplitModelWriter private[tuning] ( + instance: TrainValidationSplitModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean + import org.json4s.JsonDSL._ - val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq + val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~ + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. 
" + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") + val subModelsPath = new Path(path, "subModels") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } } } @@ -298,8 +386,22 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) + + val subModels: Option[Array[Model[_]]] = if (persistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null) + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + _subModels(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + Some(_subModels) + } else None val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7188da3531267..a616907800969 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,6 +18,9 @@ package org.apache.spark.ml.util import java.io.IOException +import java.util.Locale + +import scala.collection.mutable import org.apache.hadoop.fs.Path import org.json4s._ @@ -107,6 +110,22 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Map to store extra options for this writer. + */ + protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() + + /** + * Adds an option to the underlying MLWriter. See the documentation for the specific model's + * writer for possible options. The option name (key) is case-insensitive. + */ + @Since("2.3.0") + def option(key: String, value: String): this.type = { + require(key != null && !key.isEmpty) + optionMap.put(key.toLowerCase(Locale.ROOT), value) + this + } + /** * Overwrites if the output path already exists. 
*/ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 853eeb39bf8df..15dade2627090 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} @@ -27,7 +29,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -161,6 +163,7 @@ class CrossValidatorSuite .setEstimatorParamMaps(paramMaps) .setSeed(42L) .setParallelism(2) + .setCollectSubModels(true) val cv2 = testDefaultReadWrite(cv, testParams = false) @@ -168,6 +171,7 @@ class CrossValidatorSuite assert(cv.getNumFolds === cv2.getNumFolds) assert(cv.getSeed === cv2.getSeed) assert(cv.getParallelism === cv2.getParallelism) + assert(cv.getCollectSubModels === cv2.getCollectSubModels) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -187,6 +191,54 @@ class CrossValidatorSuite .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) } + test("CrossValidator expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val numFolds = 3 + val subPath = new File(tempDir, "testCrossValidatorSubModels") + + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(numFolds) + .setParallelism(1) + .setCollectSubModels(true) + + val cvModel = cv.fit(dataset) + + assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds) + cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + // Test the default value for option "persistSubModel" to be "true" + val savingPathWithSubModels = new File(subPath, "cvModel3").getPath + cvModel.save(savingPathWithSubModels) + val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds) + cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath + cvModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + assert(!cvModel2.hasSubModels) + + for (i <- 0 until numFolds) { + for (j <- 0 until lrParamMaps.length) { + assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === + cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) + } + } + + val savingPathTestingIllegalParam = new File(subPath, 
"cvModel4").getPath + intercept[IllegalArgumentException] { + cvModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } + } + test("read/write: CrossValidator with nested estimator") { val ova = new OneVsRest().setClassifier(new LogisticRegression) val evaluator = new MulticlassClassificationEvaluator() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index f8d9c66be2c40..9024342d9c831 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} @@ -26,7 +28,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -161,12 +163,14 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(paramMaps) .setSeed(42L) .setParallelism(2) + .setCollectSubModels(true) val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.getSeed === tvs2.getSeed) assert(tvs.getParallelism === tvs2.getParallelism) + assert(tvs.getCollectSubModels === tvs2.getCollectSubModels) ValidatorParamsSuiteHelpers .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) @@ -181,6 +185,48 @@ class TrainValidationSplitSuite } } + test("TrainValidationSplit expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val subPath = new File(tempDir, "testTrainValidationSplitSubModels") + + val tvs = new TrainValidationSplit() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setParallelism(1) + .setCollectSubModels(true) + + val tvsModel = tvs.fit(dataset) + + assert(tvsModel.hasSubModels && tvsModel.subModels.length == lrParamMaps.length) + + // Test the default value for option "persistSubModel" to be "true" + val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath + tvsModel.save(savingPathWithSubModels) + val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + assert(tvsModel3.hasSubModels && tvsModel3.subModels.length == lrParamMaps.length) + + val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath + tvsModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + assert(!tvsModel2.hasSubModels) + + for (i <- 0 until lrParamMaps.length) { + assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === + tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid) + } + + val 
savingPathTestingIllegalParam = new File(subPath, "tvsModel4").getPath + intercept[IllegalArgumentException] { + tvsModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } + } + test("read/write: TrainValidationSplit with nested estimator") { val ova = new OneVsRest() .setClassifier(new LogisticRegression) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7f18b40f9d960..915c7e2e2fda3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -78,7 +78,11 @@ object MimaExcludes { // [SPARK-14280] Support Scala 2.12 ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"), + + // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") ) // Exclude rules for 2.2.x From 1e6f760593d81def059c514d34173bf2777d71ec Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 14 Nov 2017 16:58:18 -0800 Subject: [PATCH 247/712] [SPARK-12375][ML] VectorIndexerModel support handle unseen categories via handleInvalid ## What changes were proposed in this pull request? Support skip/error/keep strategy, similar to `StringIndexer`. Implemented via `try...catch`, so that it can avoid possible performance impact. ## How was this patch tested? Unit test added. Author: WeichenXu Closes #19588 from WeichenXu123/handle_invalid_for_vector_indexer. --- .../spark/ml/feature/VectorIndexer.scala | 92 +++++++++++++++++-- .../spark/ml/feature/VectorIndexerSuite.scala | 39 ++++++++ 2 files changed, 121 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index d371da762c55d..3403ec4259b86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.feature import java.lang.{Double => JDouble, Integer => JInt} -import java.util.{Map => JMap} +import java.util.{Map => JMap, NoSuchElementException} import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ @@ -37,7 +38,27 @@ import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet /** Private trait for params for VectorIndexer and VectorIndexerModel */ -private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { + + /** + * Param for how to handle invalid data (unseen labels or NULL values). + * Note: this param only applies to categorical features, not continuous ones. + * Options are: + * 'skip': filter out rows with invalid data. + * 'error': throw an error. + * 'keep': put invalid data in a special additional bucket, at index numCategories. 
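A usage sketch of the new param (the column names and the two DataFrames are placeholders; with "keep", an unseen value of a categorical feature is mapped to index numCategories of that feature):

    import org.apache.spark.ml.feature.VectorIndexer

    val indexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexed")
      .setMaxCategories(4)
      .setHandleInvalid("keep")   // alternatives: "error" (the default) and "skip"

    val model = indexer.fit(trainDF)        // trainDF, testDF: DataFrames assumed to exist
    val indexed = model.transform(testDF)   // rows with unseen categories are kept, not dropped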
+ * Default value: "error" + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + ParamValidators.inArray(VectorIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorIndexer.ERROR_INVALID) /** * Threshold for the number of values a categorical feature can take. @@ -113,6 +134,10 @@ class VectorIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) @@ -148,6 +173,11 @@ class VectorIndexer @Since("1.4.0") ( @Since("1.6.0") object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): VectorIndexer = super.load(path) @@ -287,9 +317,15 @@ class VectorIndexerModel private[ml] ( while (featureIndex < numFeatures) { if (categoryMaps.contains(featureIndex)) { // categorical feature - val featureValues: Array[String] = + val rawFeatureValues: Array[String] = categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) - if (featureValues.length == 2) { + + val featureValues = if (getHandleInvalid == VectorIndexer.KEEP_INVALID) { + (rawFeatureValues.toList :+ "__unknown").toArray + } else { + rawFeatureValues + } + if (featureValues.length == 2 && getHandleInvalid != VectorIndexer.KEEP_INVALID) { attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), values = Some(featureValues)) } else { @@ -311,22 +347,39 @@ class VectorIndexerModel private[ml] ( // TODO: Check more carefully about whether this whole class will be included in a closure. /** Per-vector transform function */ - private val transformFunc: Vector => Vector = { + private lazy val transformFunc: Vector => Vector = { val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted val localVectorMap = categoryMaps val localNumFeatures = numFeatures + val localHandleInvalid = getHandleInvalid val f: Vector => Vector = { (v: Vector) => assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" + s" $numFeatures but found length ${v.size}") v match { case dv: DenseVector => + var hasInvalid = false val tmpv = dv.copy localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => - tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + try { + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } catch { + case _: NoSuchElementException => + localHandleInvalid match { + case VectorIndexer.ERROR_INVALID => + throw new SparkException(s"VectorIndexer encountered invalid value " + + s"${tmpv(featureIndex)} on feature index ${featureIndex}. 
To handle " + + s"or skip invalid value, try setting VectorIndexer.handleInvalid.") + case VectorIndexer.KEEP_INVALID => + tmpv.values(featureIndex) = categoryMap.size + case VectorIndexer.SKIP_INVALID => + hasInvalid = true + } + } } - tmpv + if (hasInvalid) null else tmpv case sv: SparseVector => // We use the fact that categorical value 0 is always mapped to index 0. + var hasInvalid = false val tmpv = sv.copy var catFeatureIdx = 0 // index into sortedCatFeatureIndices var k = 0 // index into non-zero elements of sparse vector @@ -337,12 +390,26 @@ class VectorIndexerModel private[ml] ( } else if (featureIndex > tmpv.indices(k)) { k += 1 } else { - tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + try { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + } catch { + case _: NoSuchElementException => + localHandleInvalid match { + case VectorIndexer.ERROR_INVALID => + throw new SparkException(s"VectorIndexer encountered invalid value " + + s"${tmpv.values(k)} on feature index ${featureIndex}. To handle " + + s"or skip invalid value, try setting VectorIndexer.handleInvalid.") + case VectorIndexer.KEEP_INVALID => + tmpv.values(k) = localVectorMap(featureIndex).size + case VectorIndexer.SKIP_INVALID => + hasInvalid = true + } + } catFeatureIdx += 1 k += 1 } } - tmpv + if (hasInvalid) null else tmpv } } f @@ -362,7 +429,12 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol, newField.metadata) + val ds = dataset.withColumn($(outputCol), newCol, newField.metadata) + if (getHandleInvalid == VectorIndexer.SKIP_INVALID) { + ds.na.drop(Array($(outputCol))) + } else { + ds + } } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index f2cca8aa82e85..69a7b75e32eb7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -38,6 +38,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext // identical, of length 3 @transient var densePoints1: DataFrame = _ @transient var sparsePoints1: DataFrame = _ + @transient var densePoints1TestInvalid: DataFrame = _ + @transient var sparsePoints1TestInvalid: DataFrame = _ @transient var point1maxes: Array[Double] = _ // identical, of length 2 @@ -55,11 +57,19 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext Vectors.dense(0.0, 1.0, 2.0), Vectors.dense(0.0, 0.0, -1.0), Vectors.dense(1.0, 3.0, 2.0)) + val densePoints1SeqTestInvalid = densePoints1Seq ++ Seq( + Vectors.dense(10.0, 2.0, 0.0), + Vectors.dense(0.0, 10.0, 2.0), + Vectors.dense(1.0, 3.0, 10.0)) val sparsePoints1Seq = Seq( Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), Vectors.sparse(3, Array(2), Array(-1.0)), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + val sparsePoints1SeqTestInvalid = sparsePoints1Seq ++ Seq( + Vectors.sparse(3, Array(0, 1), Array(10.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(10.0, 2.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 10.0))) point1maxes = Array(1.0, 3.0, 2.0) val densePoints2Seq = Seq( @@ -88,6 +98,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext densePoints1 = 
densePoints1Seq.map(FeatureData).toDF() sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() + densePoints1TestInvalid = densePoints1SeqTestInvalid.map(FeatureData).toDF() + sparsePoints1TestInvalid = sparsePoints1SeqTestInvalid.map(FeatureData).toDF() densePoints2 = densePoints2Seq.map(FeatureData).toDF() sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() badPoints = badPointsSeq.map(FeatureData).toDF() @@ -219,6 +231,33 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) } + test("handle invalid") { + for ((points, pointsTestInvalid) <- Seq((densePoints1, densePoints1TestInvalid), + (sparsePoints1, sparsePoints1TestInvalid))) { + val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") + val model = vectorIndexer.fit(points) + intercept[SparkException] { + model.transform(pointsTestInvalid).collect() + } + val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") + val model1 = vectorIndexer1.fit(points) + val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") + .collect().map(_(0)) + val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) + assert(transformed1 === invalidTransformed1) + + val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") + val model2 = vectorIndexer2.fit(points) + val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") + .collect().map(_(0)) + assert(invalidTransformed2 === transformed1 ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors.dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0)) + ) + } + } + test("Maintain sparsity for sparse vectors") { def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { val points = data.collect().map(_.getAs[Vector](0)) From dce1610ae376af00712ba7f4c99bfb4c006dbaec Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Nov 2017 14:42:37 +0100 Subject: [PATCH 248/712] [SPARK-22514][SQL] move ColumnVector.Array and ColumnarBatch.Row to individual files ## What changes were proposed in this pull request? Logically the `Array` doesn't belong to `ColumnVector`, and `Row` doesn't belong to `ColumnarBatch`. e.g. `ColumnVector` needs to return `Array` for `getArray`, and `Row` for `getStruct`. `Array` and `Row` can return each other with the `getArray`/`getStruct` methods. This is also a step to make `ColumnVector` public, it's cleaner to have `Array` and `Row` as top-level classes. This PR is just code moving around, with 2 renaming: `Array` -> `VectorBasedArray`, `Row` -> `VectorBasedRow`. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19740 from cloud-fan/vector. 
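For callers the change is visible only in the types; a sketch, assuming a batch whose column 0 is an array column and whose column 1 is a two-field struct column, and using the class names as they appear in the diff below:

    import org.apache.spark.sql.execution.vectorized.{ColumnarArray, ColumnarBatch, ColumnarRow}

    def inspect(batch: ColumnarBatch): Unit = {
      val it = batch.rowIterator()                    // now yields ColumnarRow, was ColumnarBatch.Row
      while (it.hasNext) {
        val row: ColumnarRow = it.next()
        val arr: ColumnarArray = row.getArray(0)      // was ColumnVector.Array
        val nested: ColumnarRow = row.getStruct(1, 2) // Array and Row still return each other
      }
    }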
--- .../vectorized/AggregateHashMap.java | 2 +- .../vectorized/ArrowColumnVector.java | 6 +- .../execution/vectorized/ColumnVector.java | 202 +---------- .../vectorized/ColumnVectorUtils.java | 2 +- .../execution/vectorized/ColumnarArray.java | 208 +++++++++++ .../execution/vectorized/ColumnarBatch.java | 326 +---------------- .../sql/execution/vectorized/ColumnarRow.java | 327 ++++++++++++++++++ .../vectorized/OffHeapColumnVector.java | 2 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../vectorized/WritableColumnVector.java | 14 +- .../aggregate/HashAggregateExec.scala | 10 +- .../VectorizedHashMapGenerator.scala | 12 +- .../vectorized/ColumnVectorSuite.scala | 40 +-- 13 files changed, 597 insertions(+), 556 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index cb3ad4eab1f60..9467435435d1f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -72,7 +72,7 @@ public AggregateHashMap(StructType schema) { this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS); } - public ColumnarBatch.Row findOrInsert(long key) { + public ColumnarRow findOrInsert(long key) { int idx = find(key); if (idx != -1 && buckets[idx] == -1) { columnVectors[0].putLong(numRows, key); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 51ea719f8c4a6..949035bfb177c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -251,7 +251,7 @@ public int getArrayOffset(int rowId) { } @Override - public void loadBytes(ColumnVector.Array array) { + public void loadBytes(ColumnarArray array) { throw new UnsupportedOperationException(); } @@ -330,7 +330,7 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - resultArray = new ColumnVector.Array(childColumns[0]); + resultArray = new ColumnarArray(childColumns[0]); } else if (vector instanceof MapVector) { MapVector mapVector = (MapVector) vector; accessor = new StructAccessor(mapVector); @@ -339,7 +339,7 @@ public ArrowColumnVector(ValueVector vector) { for (int i = 0; i < childColumns.length; ++i) { childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); } - resultStruct = new ColumnarBatch.Row(childColumns); + resultStruct = new ColumnarRow(childColumns); } else { throw new UnsupportedOperationException(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index c4b519f0b153f..666fd63fdcf2f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,11 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; 
-import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.UTF8String; /** @@ -42,190 +40,6 @@ * ColumnVectors are intended to be reused. */ public abstract class ColumnVector implements AutoCloseable { - - /** - * Holder object to return an array. This object is intended to be reused. Callers should - * copy the data out if it needs to be stored. - */ - public static final class Array extends ArrayData { - // The data for this array. This array contains elements from - // data[offset] to data[offset + length). - public final ColumnVector data; - public int length; - public int offset; - - // Populate if binary data is required for the Array. This is stored here as an optimization - // for string data. - public byte[] byteArray; - public int byteArrayOffset; - - // Reused staging buffer, used for loading from offheap. - protected byte[] tmpByteArray = new byte[1]; - - protected Array(ColumnVector data) { - this.data = data; - } - - @Override - public int numElements() { return length; } - - @Override - public ArrayData copy() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean[] toBooleanArray() { return data.getBooleans(offset, length); } - - @Override - public byte[] toByteArray() { return data.getBytes(offset, length); } - - @Override - public short[] toShortArray() { return data.getShorts(offset, length); } - - @Override - public int[] toIntArray() { return data.getInts(offset, length); } - - @Override - public long[] toLongArray() { return data.getLongs(offset, length); } - - @Override - public float[] toFloatArray() { return data.getFloats(offset, length); } - - @Override - public double[] toDoubleArray() { return data.getDoubles(offset, length); } - - // TODO: this is extremely expensive. 
- @Override - public Object[] array() { - DataType dt = data.dataType(); - Object[] list = new Object[length]; - try { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = get(i, dt); - } - } - return list; - } catch(Exception e) { - throw new RuntimeException("Could not get the array", e); - } - } - - @Override - public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } - - @Override - public boolean getBoolean(int ordinal) { - return data.getBoolean(offset + ordinal); - } - - @Override - public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } - - @Override - public short getShort(int ordinal) { - return data.getShort(offset + ordinal); - } - - @Override - public int getInt(int ordinal) { return data.getInt(offset + ordinal); } - - @Override - public long getLong(int ordinal) { return data.getLong(offset + ordinal); } - - @Override - public float getFloat(int ordinal) { - return data.getFloat(offset + ordinal); - } - - @Override - public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - return data.getDecimal(offset + ordinal, precision, scale); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - return data.getUTF8String(offset + ordinal); - } - - @Override - public byte[] getBinary(int ordinal) { - return data.getBinary(offset + ordinal); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); - } - - @Override - public InternalRow getStruct(int ordinal, int numFields) { - return data.getStruct(offset + ordinal); - } - - @Override - public ArrayData getArray(int ordinal) { - return data.getArray(offset + ordinal); - } - - @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - if (dataType instanceof BooleanType) { - return getBoolean(ordinal); - } else if (dataType instanceof ByteType) { - return getByte(ordinal); - } else if (dataType instanceof ShortType) { - return getShort(ordinal); - } else if (dataType instanceof IntegerType) { - return getInt(ordinal); - } else if (dataType instanceof LongType) { - return getLong(ordinal); - } else if (dataType instanceof FloatType) { - return getFloat(ordinal); - } else if (dataType instanceof DoubleType) { - return getDouble(ordinal); - } else if (dataType instanceof StringType) { - return getUTF8String(ordinal); - } else if (dataType instanceof BinaryType) { - return getBinary(ordinal); - } else if (dataType instanceof DecimalType) { - DecimalType t = (DecimalType) dataType; - return getDecimal(ordinal, t.precision(), t.scale()); - } else if (dataType instanceof DateType) { - return getInt(ordinal); - } else if (dataType instanceof TimestampType) { - return getLong(ordinal); - } else if (dataType instanceof ArrayType) { - return getArray(ordinal); - } else if (dataType instanceof StructType) { - return getStruct(ordinal, ((StructType)dataType).fields().length); - } else if (dataType instanceof MapType) { - return getMap(ordinal); - } else if (dataType instanceof CalendarIntervalType) { - return getInterval(ordinal); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dataType); - } - } - - @Override - public 
void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } - - @Override - public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } - } - /** * Returns the data type of this column. */ @@ -350,7 +164,7 @@ public Object get(int ordinal, DataType dataType) { /** * Returns a utility object to get structs. */ - public ColumnarBatch.Row getStruct(int rowId) { + public ColumnarRow getStruct(int rowId) { resultStruct.rowId = rowId; return resultStruct; } @@ -359,7 +173,7 @@ public ColumnarBatch.Row getStruct(int rowId) { * Returns a utility object to get structs. * provided to keep API compatibility with InternalRow for code generation */ - public ColumnarBatch.Row getStruct(int rowId, int size) { + public ColumnarRow getStruct(int rowId, int size) { resultStruct.rowId = rowId; return resultStruct; } @@ -367,7 +181,7 @@ public ColumnarBatch.Row getStruct(int rowId, int size) { /** * Returns the array at rowid. */ - public final ColumnVector.Array getArray(int rowId) { + public final ColumnarArray getArray(int rowId) { resultArray.length = getArrayLength(rowId); resultArray.offset = getArrayOffset(rowId); return resultArray; @@ -376,7 +190,7 @@ public final ColumnVector.Array getArray(int rowId) { /** * Loads the data into array.byteArray. */ - public abstract void loadBytes(ColumnVector.Array array); + public abstract void loadBytes(ColumnarArray array); /** * Returns the value for rowId. @@ -423,12 +237,12 @@ public MapData getMap(int ordinal) { /** * Reusable Array holder for getArray(). */ - protected ColumnVector.Array resultArray; + protected ColumnarArray resultArray; /** * Reusable Struct holder for getStruct(). */ - protected ColumnarBatch.Row resultStruct; + protected ColumnarRow resultStruct; /** * The Dictionary for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index adb859ed17757..b4b5f0a265934 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -98,7 +98,7 @@ public static void populate(WritableColumnVector col, InternalRow row, int field * For example, an array of IntegerType will return an int[]. * Throws exceptions for unhandled schemas. */ - public static Object toPrimitiveJavaArray(ColumnVector.Array array) { + public static Object toPrimitiveJavaArray(ColumnarArray array) { DataType dt = array.data.dataType(); if (dt instanceof IntegerType) { int[] result = new int[array.length]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java new file mode 100644 index 0000000000000..5e88ce0321084 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Array abstraction in {@link ColumnVector}. The instance of this class is intended + * to be reused, callers should copy the data out if it needs to be stored. + */ +public final class ColumnarArray extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + public final ColumnVector data; + public int length; + public int offset; + + // Populate if binary data is required for the Array. This is stored here as an optimization + // for string data. + public byte[] byteArray; + public int byteArrayOffset; + + // Reused staging buffer, used for loading from offheap. + protected byte[] tmpByteArray = new byte[1]; + + protected ColumnarArray(ColumnVector data) { + this.data = data; + } + + @Override + public int numElements() { + return length; + } + + @Override + public ArrayData copy() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean[] toBooleanArray() { return data.getBooleans(offset, length); } + + @Override + public byte[] toByteArray() { return data.getBytes(offset, length); } + + @Override + public short[] toShortArray() { return data.getShorts(offset, length); } + + @Override + public int[] toIntArray() { return data.getInts(offset, length); } + + @Override + public long[] toLongArray() { return data.getLongs(offset, length); } + + @Override + public float[] toFloatArray() { return data.getFloats(offset, length); } + + @Override + public double[] toDoubleArray() { return data.getDoubles(offset, length); } + + // TODO: this is extremely expensive. 
+ @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch(Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { return data.getInt(offset + ordinal); } + + @Override + public long getLong(int ordinal) { return data.getLong(offset + ordinal); } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + int month = data.getChildColumn(0).getInt(offset + ordinal); + long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + return new CalendarInterval(month, microseconds); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + } + + @Override + 
public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } + + @Override + public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index bc546c7c425b1..8849a20d6ceb5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,17 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; import java.util.*; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.sql.types.StructType; /** * This class is the in memory representation of rows as they are streamed through operators. It @@ -48,7 +40,7 @@ public final class ColumnarBatch { private final StructType schema; private final int capacity; private int numRows; - private final ColumnVector[] columns; + final ColumnVector[] columns; // True if the row is filtered. private final boolean[] filteredRows; @@ -60,7 +52,7 @@ public final class ColumnarBatch { private int numRowsFiltered = 0; // Staging row returned from getRow. - final Row row; + final ColumnarRow row; /** * Called to close all the columns in this batch. It is not valid to access the data after @@ -72,313 +64,13 @@ public void close() { } } - /** - * Adapter class to interop with existing components that expect internal row. A lot of - * performance is lost with this translation. - */ - public static final class Row extends InternalRow { - protected int rowId; - private final ColumnarBatch parent; - private final int fixedLenRowSize; - private final ColumnVector[] columns; - private final WritableColumnVector[] writableColumns; - - // Ctor used if this is a top level row. - private Row(ColumnarBatch parent) { - this.parent = parent; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); - this.columns = parent.columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } - } - - // Ctor used if this is a struct. - protected Row(ColumnVector[] columns) { - this.parent = null; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); - this.columns = columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } - } - - /** - * Marks this row as being filtered out. This means a subsequent iteration over the rows - * in this batch will not include this row. 
- */ - public void markFiltered() { - parent.markFiltered(rowId); - } - - public ColumnVector[] columns() { return columns; } - - @Override - public int numFields() { return columns.length; } - - @Override - /** - * Revisit this. This is expensive. This is currently only used in test paths. - */ - public InternalRow copy() { - GenericInternalRow row = new GenericInternalRow(columns.length); - for (int i = 0; i < numFields(); i++) { - if (isNullAt(i)) { - row.setNullAt(i); - } else { - DataType dt = columns[i].dataType(); - if (dt instanceof BooleanType) { - row.setBoolean(i, getBoolean(i)); - } else if (dt instanceof ByteType) { - row.setByte(i, getByte(i)); - } else if (dt instanceof ShortType) { - row.setShort(i, getShort(i)); - } else if (dt instanceof IntegerType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof LongType) { - row.setLong(i, getLong(i)); - } else if (dt instanceof FloatType) { - row.setFloat(i, getFloat(i)); - } else if (dt instanceof DoubleType) { - row.setDouble(i, getDouble(i)); - } else if (dt instanceof StringType) { - row.update(i, getUTF8String(i).copy()); - } else if (dt instanceof BinaryType) { - row.update(i, getBinary(i)); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType)dt; - row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); - } else if (dt instanceof DateType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof TimestampType) { - row.setLong(i, getLong(i)); - } else { - throw new RuntimeException("Not implemented. " + dt); - } - } - } - return row; - } - - @Override - public boolean anyNull() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } - - @Override - public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } - - @Override - public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } - - @Override - public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } - - @Override - public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } - - @Override - public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } - - @Override - public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } - - @Override - public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getDecimal(rowId, precision, scale); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getUTF8String(rowId); - } - - @Override - public byte[] getBinary(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getBinary(rowId); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); - return new CalendarInterval(months, microseconds); - } - - @Override - public InternalRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getStruct(rowId); - } - - @Override - public ArrayData getArray(int ordinal) { - if 
(columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getArray(rowId); - } - - @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - if (dataType instanceof BooleanType) { - return getBoolean(ordinal); - } else if (dataType instanceof ByteType) { - return getByte(ordinal); - } else if (dataType instanceof ShortType) { - return getShort(ordinal); - } else if (dataType instanceof IntegerType) { - return getInt(ordinal); - } else if (dataType instanceof LongType) { - return getLong(ordinal); - } else if (dataType instanceof FloatType) { - return getFloat(ordinal); - } else if (dataType instanceof DoubleType) { - return getDouble(ordinal); - } else if (dataType instanceof StringType) { - return getUTF8String(ordinal); - } else if (dataType instanceof BinaryType) { - return getBinary(ordinal); - } else if (dataType instanceof DecimalType) { - DecimalType t = (DecimalType) dataType; - return getDecimal(ordinal, t.precision(), t.scale()); - } else if (dataType instanceof DateType) { - return getInt(ordinal); - } else if (dataType instanceof TimestampType) { - return getLong(ordinal); - } else if (dataType instanceof ArrayType) { - return getArray(ordinal); - } else if (dataType instanceof StructType) { - return getStruct(ordinal, ((StructType)dataType).fields().length); - } else if (dataType instanceof MapType) { - return getMap(ordinal); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dataType); - } - } - - @Override - public void update(int ordinal, Object value) { - if (value == null) { - setNullAt(ordinal); - } else { - DataType dt = columns[ordinal].dataType(); - if (dt instanceof BooleanType) { - setBoolean(ordinal, (boolean) value); - } else if (dt instanceof IntegerType) { - setInt(ordinal, (int) value); - } else if (dt instanceof ShortType) { - setShort(ordinal, (short) value); - } else if (dt instanceof LongType) { - setLong(ordinal, (long) value); - } else if (dt instanceof FloatType) { - setFloat(ordinal, (float) value); - } else if (dt instanceof DoubleType) { - setDouble(ordinal, (double) value); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType) dt; - setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), - t.precision()); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dt); - } - } - } - - @Override - public void setNullAt(int ordinal) { - getWritableColumn(ordinal).putNull(rowId); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putBoolean(rowId, value); - } - - @Override - public void setByte(int ordinal, byte value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putByte(rowId, value); - } - - @Override - public void setShort(int ordinal, short value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putShort(rowId, value); - } - - @Override - public void setInt(int ordinal, int value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putInt(rowId, value); - } - - @Override - public void setLong(int ordinal, long value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putLong(rowId, value); - } - - @Override - public 
void setFloat(int ordinal, float value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putFloat(rowId, value); - } - - @Override - public void setDouble(int ordinal, double value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDouble(rowId, value); - } - - @Override - public void setDecimal(int ordinal, Decimal value, int precision) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDecimal(rowId, value, precision); - } - - private WritableColumnVector getWritableColumn(int ordinal) { - WritableColumnVector column = writableColumns[ordinal]; - assert (!column.isConstant); - return column; - } - } - /** * Returns an iterator over the rows in this batch. This skips rows that are filtered out. */ - public Iterator rowIterator() { + public Iterator rowIterator() { final int maxRows = ColumnarBatch.this.numRows(); - final Row row = new Row(this); - return new Iterator() { + final ColumnarRow row = new ColumnarRow(this); + return new Iterator() { int rowId = 0; @Override @@ -390,7 +82,7 @@ public boolean hasNext() { } @Override - public Row next() { + public ColumnarRow next() { while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { ++rowId; } @@ -491,7 +183,7 @@ public void setColumn(int ordinal, ColumnVector column) { /** * Returns the row in this batch at `rowId`. Returned row is reused across calls. */ - public ColumnarBatch.Row getRow(int rowId) { + public ColumnarRow getRow(int rowId) { assert(rowId >= 0); assert(rowId < numRows); row.rowId = rowId; @@ -522,6 +214,6 @@ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.capacity = capacity; this.nullFilteredColumns = new HashSet<>(); this.filteredRows = new boolean[capacity]; - this.row = new Row(this); + this.row = new ColumnarRow(this); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java new file mode 100644 index 0000000000000..c75adafd69461 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Row abstraction in {@link ColumnVector}. The instance of this class is intended + * to be reused, callers should copy the data out if it needs to be stored. + */ +public final class ColumnarRow extends InternalRow { + protected int rowId; + private final ColumnarBatch parent; + private final int fixedLenRowSize; + private final ColumnVector[] columns; + private final WritableColumnVector[] writableColumns; + + // Ctor used if this is a top level row. + ColumnarRow(ColumnarBatch parent) { + this.parent = parent; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + this.columns = parent.columns; + this.writableColumns = new WritableColumnVector[this.columns.length]; + for (int i = 0; i < this.columns.length; i++) { + if (this.columns[i] instanceof WritableColumnVector) { + this.writableColumns[i] = (WritableColumnVector) this.columns[i]; + } + } + } + + // Ctor used if this is a struct. + ColumnarRow(ColumnVector[] columns) { + this.parent = null; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); + this.columns = columns; + this.writableColumns = new WritableColumnVector[this.columns.length]; + for (int i = 0; i < this.columns.length; i++) { + if (this.columns[i] instanceof WritableColumnVector) { + this.writableColumns[i] = (WritableColumnVector) this.columns[i]; + } + } + } + + /** + * Marks this row as being filtered out. This means a subsequent iteration over the rows + * in this batch will not include this row. + */ + public void markFiltered() { + parent.markFiltered(rowId); + } + + public ColumnVector[] columns() { return columns; } + + @Override + public int numFields() { return columns.length; } + + /** + * Revisit this. This is expensive. This is currently only used in test paths. + */ + @Override + public InternalRow copy() { + GenericInternalRow row = new GenericInternalRow(columns.length); + for (int i = 0; i < numFields(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = columns[i].dataType(); + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); + } else if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i).copy()); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); + } else { + throw new RuntimeException("Not implemented. 
" + dt); + } + } + } + return row; + } + + @Override + public boolean anyNull() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + + @Override + public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + + @Override + public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + + @Override + public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + + @Override + public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + + @Override + public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + + @Override + public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getBinary(rowId); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getStruct(rowId); + } + + @Override + public ColumnarArray getArray(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getArray(rowId); + } + + @Override + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not 
supported " + dataType); + } + } + + @Override + public void update(int ordinal, Object value) { + if (value == null) { + setNullAt(ordinal); + } else { + DataType dt = columns[ordinal].dataType(); + if (dt instanceof BooleanType) { + setBoolean(ordinal, (boolean) value); + } else if (dt instanceof IntegerType) { + setInt(ordinal, (int) value); + } else if (dt instanceof ShortType) { + setShort(ordinal, (short) value); + } else if (dt instanceof LongType) { + setLong(ordinal, (long) value); + } else if (dt instanceof FloatType) { + setFloat(ordinal, (float) value); + } else if (dt instanceof DoubleType) { + setDouble(ordinal, (double) value); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType) dt; + setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), + t.precision()); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dt); + } + } + } + + @Override + public void setNullAt(int ordinal) { + getWritableColumn(ordinal).putNull(rowId); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putBoolean(rowId, value); + } + + @Override + public void setByte(int ordinal, byte value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putByte(rowId, value); + } + + @Override + public void setShort(int ordinal, short value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putShort(rowId, value); + } + + @Override + public void setInt(int ordinal, int value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putInt(rowId, value); + } + + @Override + public void setLong(int ordinal, long value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putLong(rowId, value); + } + + @Override + public void setFloat(int ordinal, float value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putFloat(rowId, value); + } + + @Override + public void setDouble(int ordinal, double value) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDouble(rowId, value); + } + + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDecimal(rowId, value, precision); + } + + private WritableColumnVector getWritableColumn(int ordinal) { + WritableColumnVector column = writableColumns[ordinal]; + assert (!column.isConstant); + return column; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index a7522ebf5821a..2bf523b7e7198 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -523,7 +523,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { } @Override - public void loadBytes(ColumnVector.Array array) { + public void loadBytes(ColumnarArray array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; Platform.copyMemory( null, data + array.offset, array.tmpByteArray, 
Platform.BYTE_ARRAY_OFFSET, array.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 166a39e0fabd9..d699d292711dc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -494,7 +494,7 @@ public void putArray(int rowId, int offset, int length) { } @Override - public void loadBytes(ColumnVector.Array array) { + public void loadBytes(ColumnarArray array) { array.byteArray = byteData; array.byteArrayOffset = array.offset; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index d3a14b9d8bd74..96cfeed34f300 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -283,8 +283,8 @@ public final int putByteArray(int rowId, byte[] value) { /** * Returns the value for rowId. */ - private ColumnVector.Array getByteArray(int rowId) { - ColumnVector.Array array = getArray(rowId); + private ColumnarArray getByteArray(int rowId) { + ColumnarArray array = getArray(rowId); array.data.loadBytes(array); return array; } @@ -324,7 +324,7 @@ public void putDecimal(int rowId, Decimal value, int precision) { @Override public UTF8String getUTF8String(int rowId) { if (dictionary == null) { - ColumnVector.Array a = getByteArray(rowId); + ColumnarArray a = getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } else { byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); @@ -338,7 +338,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { if (dictionary == null) { - ColumnVector.Array array = getByteArray(rowId); + ColumnarArray array = getByteArray(rowId); byte[] bytes = new byte[array.length]; System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); return bytes; @@ -685,7 +685,7 @@ protected WritableColumnVector(int capacity, DataType type) { } this.childColumns = new WritableColumnVector[1]; this.childColumns[0] = reserveNewColumn(childCapacity, childType); - this.resultArray = new ColumnVector.Array(this.childColumns[0]); + this.resultArray = new ColumnarArray(this.childColumns[0]); this.resultStruct = null; } else if (type instanceof StructType) { StructType st = (StructType)type; @@ -694,14 +694,14 @@ protected WritableColumnVector(int capacity, DataType type) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } this.resultArray = null; - this.resultStruct = new ColumnarBatch.Row(this.childColumns); + this.resultStruct = new ColumnarRow(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. 
this.childColumns = new WritableColumnVector[2]; this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); this.resultArray = null; - this.resultStruct = new ColumnarBatch.Row(this.childColumns); + this.resultStruct = new ColumnarRow(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2a208a2722550..51f7c9e22b902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -595,7 +595,7 @@ case class HashAggregateExec( ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName();") ctx.addMutableState( - "java.util.Iterator", + "java.util.Iterator", iterTermForFastHashMap, "") } else { ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, @@ -681,7 +681,7 @@ case class HashAggregateExec( """ } - // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow + // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow def outputFromVectorizedMap: String = { val row = ctx.freshName("fastHashMapRow") ctx.currentVars = null @@ -697,8 +697,8 @@ case class HashAggregateExec( s""" | while ($iterTermForFastHashMap.hasNext()) { | $numOutput.add(1); - | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = - | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) + | org.apache.spark.sql.execution.vectorized.ColumnarRow $row = + | (org.apache.spark.sql.execution.vectorized.ColumnarRow) | $iterTermForFastHashMap.next(); | ${generateKeyRow.code} | ${generateBufferRow.code} @@ -892,7 +892,7 @@ case class HashAggregateExec( ${ if (isVectorizedHashMapEnabled) { s""" - | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $fastRowBuffer = null; + | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null; """.stripMargin } else { s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 812d405d5ebfe..fd783d905b776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ /** @@ -142,14 +142,14 @@ class VectorizedHashMapGenerator( /** * Generates a method that returns a mutable - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the + * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the * aggregate value(s) for a given set of keys. 
If the corresponding row doesn't exist, the * generated method adds the corresponding row in the associated * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we * have 2 long group-by keys, the generated function would be of the form: * * {{{ - * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert( + * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert( * long agg_key, long agg_key1) { * long h = hash(agg_key, agg_key1); * int step = 0; @@ -189,7 +189,7 @@ class VectorizedHashMapGenerator( } s""" - |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ + |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${ groupingKeySignature}) { | long h = hash(${groupingKeys.map(_.name).mkString(", ")}); | int step = 0; @@ -229,7 +229,7 @@ class VectorizedHashMapGenerator( protected def generateRowIterator(): String = { s""" - |public java.util.Iterator + |public java.util.Iterator | rowIterator() { | return batch.rowIterator(); |} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index c5c8ae3a17c6c..3c76ca79f5dda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -57,7 +57,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendBoolean(i % 2 == 0) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, BooleanType) === (i % 2 == 0)) @@ -69,7 +69,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByte(i.toByte) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, ByteType) === i.toByte) @@ -81,7 +81,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendShort(i.toShort) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, ShortType) === i.toShort) @@ -93,7 +93,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendInt(i) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, IntegerType) === i) @@ -105,7 +105,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendLong(i) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, LongType) === i) @@ -117,7 +117,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendFloat(i.toFloat) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, FloatType) === i.toFloat) @@ -129,7 +129,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendDouble(i.toDouble) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, DoubleType) === i.toDouble) @@ -142,7 +142,7 @@ class 
ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => assert(array.get(i, StringType) === UTF8String.fromString(s"str$i")) @@ -155,7 +155,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") @@ -179,12 +179,12 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.putArray(2, 3, 0) testVector.putArray(3, 3, 3) - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) - assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0)) - assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2)) - assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int]) - assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) + assert(array.getArray(0).toIntArray() === Array(0)) + assert(array.getArray(1).toIntArray() === Array(1, 2)) + assert(array.getArray(2).toIntArray() === Array.empty[Int]) + assert(array.getArray(3).toIntArray() === Array(3, 4, 5)) } val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) @@ -196,12 +196,12 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { c1.putInt(1, 456) c2.putDouble(1, 5.67) - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) - assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) - assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) - assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) - assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + assert(array.getStruct(0, structType.length).get(0, IntegerType) === 123) + assert(array.getStruct(0, structType.length).get(1, DoubleType) === 3.45) + assert(array.getStruct(1, structType.length).get(0, IntegerType) === 456) + assert(array.getStruct(1, structType.length).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { @@ -214,7 +214,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.reserve(16) // Check that none of the values got lost/overwritten. - val array = new ColumnVector.Array(testVector) + val array = new ColumnarArray(testVector) (0 until 8).foreach { i => assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) } From 8f0e88df03a06a91bb61c6e0d69b1b19e2bfb3f7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 15 Nov 2017 23:35:13 +0900 Subject: [PATCH 249/712] [SPARK-20791][PYTHON][FOLLOWUP] Check for unicode column names in createDataFrame with Arrow ## What changes were proposed in this pull request? If schema is passed as a list of unicode strings for column names, they should be re-encoded to 'utf-8' to be consistent. This is similar to the #13097 but for creation of DataFrame using Arrow. ## How was this patch tested? Added new test of using unicode names for schema. 
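As a rough illustration of the re-encoding rule described above (not part of the patch itself; Python 2 semantics, where `str` is the byte-string type):

```python
# Hedged sketch: mirrors the list comprehension used for schema and pandas column names.
# Unicode names are re-encoded to UTF-8 byte strings; plain str names pass through unchanged.
def normalize_column_names(names):
    return [n.encode('utf-8') if not isinstance(n, str) else n for n in names]

# e.g. normalize_column_names([u'a', 'b']) returns ['a', 'b'] with both entries as str
```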
Author: Bryan Cutler Closes #19738 from BryanCutler/arrow-createDataFrame-followup-unicode-SPARK-20791. --- python/pyspark/sql/session.py | 7 ++++--- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 589365b083012..dbbcfff6db91b 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -592,6 +592,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr if isinstance(schema, basestring): schema = _parse_datatype_string(schema) + elif isinstance(schema, (list, tuple)): + # Must re-encode any unicode strings to be consistent with StructField names + schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] try: import pandas @@ -602,7 +605,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [str(x) for x in data.columns] + schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: @@ -630,8 +633,6 @@ def prepare(obj): verify_func(obj) return obj, else: - if isinstance(schema, (list, tuple)): - schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] prepare = lambda obj: obj if isinstance(data, RDD): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6356d938db26a..ef592c2356a8c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3225,6 +3225,16 @@ def test_createDataFrame_with_names(self): df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + def test_createDataFrame_column_name_encoding(self): + import pandas as pd + pdf = pd.DataFrame({u'a': [1]}) + columns = self.spark.createDataFrame(pdf).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'a') + columns = self.spark.createDataFrame(pdf, [u'b']).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'b') + def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): From 7f99a05e6ff258fc2192130451aa8aa1304bfe93 Mon Sep 17 00:00:00 2001 From: test Date: Wed, 15 Nov 2017 10:13:01 -0600 Subject: [PATCH 250/712] [SPARK-22422][ML] Add Adjusted R2 to RegressionMetrics ## What changes were proposed in this pull request? I added adjusted R2 as a regression metric which was implemented in all major statistical analysis tools. In practice, no one looks at R2 alone. The reason is R2 itself is misleading. If we add more parameters, R2 will not decrease but only increase (or stay the same). This leads to overfitting. Adjusted R2 addressed this issue by using number of parameters as "weight" for the sum of errors. ## How was this patch tested? - Added a new unit test and passed. - ./dev/run-tests all passed. Author: test Author: tengpeng Closes #19638 from tengpeng/master. 
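The adjustment simply rescales (1 - R^2) by the remaining degrees of freedom; a minimal sketch of the expression added to `LinearRegressionSummary` below, with purely illustrative numbers:

```python
# Hedged sketch of the adjusted R^2 formula introduced by this patch:
# r2adj = 1 - (1 - R^2) * (n - d) / (n - p - d), where p is the number of
# coefficients and d is 1 when an intercept is fitted, 0 otherwise.
def adjusted_r2(r2, n, p, fit_intercept=True):
    d = 1 if fit_intercept else 0
    return 1.0 - (1.0 - r2) * (n - d) / (n - p - d)

# Illustrative values only: adjusted_r2(0.98, n=100, p=5) ~= 0.9789,
# slightly below the raw R^2 because the five predictors are penalized.
```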
--- .../spark/ml/regression/LinearRegression.scala | 15 +++++++++++++++ .../ml/regression/LinearRegressionSuite.scala | 6 ++++++ 2 files changed, 21 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index df1aa609c1b71..da6bcf07e4742 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -722,6 +722,21 @@ class LinearRegressionSummary private[regression] ( @Since("1.5.0") val r2: Double = metrics.r2 + /** + * Returns Adjusted R^2^, the adjusted coefficient of determination. + * Reference: + * Wikipedia coefficient of determination + * + * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. + * This will change in later Spark versions. + */ + @Since("2.3.0") + val r2adj: Double = { + val interceptDOF = if (privateModel.getFitIntercept) 1 else 0 + 1 - (1 - r2) * (numInstances - interceptDOF) / + (numInstances - privateModel.coefficients.size - interceptDOF) + } + /** Residuals (label - predicted value) */ @Since("1.5.0") @transient lazy val residuals: DataFrame = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index f470dca7dbd0a..0e0be58dbf022 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -764,6 +764,11 @@ class LinearRegressionSuite (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** V2 4.6982442 0.0011805 3980 <2e-16 *** V3 7.1994344 0.0009044 7961 <2e-16 *** + + # R code for r2adj + lm_fit <- lm(V1 ~ V2 + V3, data = d1) + summary(lm_fit)$adj.r.squared + [1] 0.9998736 --- .... @@ -771,6 +776,7 @@ class LinearRegressionSuite assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) + assert(model.summary.r2adj ~== 0.9998736 relTol 1E-4) // Normal solver uses "WeightedLeastSquares". If no regularization is applied or only L2 // regularization is applied, this algorithm uses a direct solver and does not generate an From aa88b8dbbb7e71b282f31ae775140c783e83b4d6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 15 Nov 2017 08:59:29 -0800 Subject: [PATCH 251/712] [SPARK-22490][DOC] Add PySpark doc for SparkSession.builder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In PySpark API Document, [SparkSession.build](http://spark.apache.org/docs/2.2.0/api/python/pyspark.sql.html) is not documented and shows default value description. ``` SparkSession.builder =
    >
    > builder
    >

    A class attribute having a Builder to construct SparkSession instances

    >
    > 212,216d217 <
    < builder = <pyspark.sql.session.SparkSession.Builder object>
    <
    < <
    ``` ## How was this patch tested? Manual. ``` cd python/docs make html open _build/html/pyspark.sql.html ``` Author: Dongjoon Hyun Closes #19726 from dongjoon-hyun/SPARK-22490. --- python/docs/pyspark.sql.rst | 3 +++ python/pyspark/sql/session.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 09848b880194d..5c3b7e274857a 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -7,6 +7,9 @@ Module Context .. automodule:: pyspark.sql :members: :undoc-members: + :exclude-members: builder +.. We need `exclude-members` to prevent default description generations + as a workaround for old Sphinx (< 1.6.6). pyspark.sql.types module ------------------------ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index dbbcfff6db91b..47c58bb28221c 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -72,6 +72,9 @@ class SparkSession(object): ... .appName("Word Count") \\ ... .config("spark.some.config.option", "some-value") \\ ... .getOrCreate() + + .. autoattribute:: builder + :annotation: """ class Builder(object): @@ -183,6 +186,7 @@ def getOrCreate(self): return session builder = Builder() + """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances""" _instantiatedSession = None From bc0848b4c1ab84ccef047363a70fd11df240dbbf Mon Sep 17 00:00:00 2001 From: liutang123 Date: Wed, 15 Nov 2017 09:02:54 -0800 Subject: [PATCH 252/712] [SPARK-22469][SQL] Accuracy problem in comparison with string and numeric ## What changes were proposed in this pull request? This fixes a problem caused by #15880 `select '1.5' > 0.5; // Result is NULL in Spark but is true in Hive. ` When compare string and numeric, cast them as double like Hive. Author: liutang123 Closes #19692 from liutang123/SPARK-22469. --- .../sql/catalyst/analysis/TypeCoercion.scala | 7 + .../catalyst/analysis/TypeCoercionSuite.scala | 3 + .../sql-tests/inputs/predicate-functions.sql | 5 + .../results/predicate-functions.sql.out | 140 +++++++++++------- 4 files changed, 105 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 532d22dbf2321..074eda56199e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -137,6 +137,13 @@ object TypeCoercion { case (DateType, TimestampType) => Some(StringType) case (StringType, NullType) => Some(StringType) case (NullType, StringType) => Some(StringType) + + // There is no proper decimal type we can pick, + // using double type is the best we can do. + // See SPARK-22469 for details. 
+ case (n: DecimalType, s: StringType) => Some(DoubleType) + case (s: StringType, n: DecimalType) => Some(DoubleType) + case (l: StringType, r: AtomicType) if r != StringType => Some(r) case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) case (l, r) => None diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 793e04f66f0f9..5dcd653e9b341 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1152,6 +1152,9 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(PromoteStrings, EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) + ruleTest(PromoteStrings, + GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))), + GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType))) } test("cast WindowFrame boundaries to the type they operate upon") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index 3b3d4ad64b3ec..e99d5cef81f64 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -2,12 +2,14 @@ select 1 = 1; select 1 = '1'; select 1.0 = '1'; +select 1.5 = '1.51'; -- GreaterThan select 1 > '1'; select 2 > '1.0'; select 2 > '2.0'; select 2 > '2.2'; +select '1.5' > 0.5; select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52'; @@ -16,6 +18,7 @@ select 1 >= '1'; select 2 >= '1.0'; select 2 >= '2.0'; select 2.0 >= '2.2'; +select '1.5' >= 0.5; select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52'; @@ -24,6 +27,7 @@ select 1 < '1'; select 2 < '1.0'; select 2 < '2.0'; select 2.0 < '2.2'; +select 0.5 < '1.5'; select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52'; @@ -32,5 +36,6 @@ select 1 <= '1'; select 2 <= '1.0'; select 2 <= '2.0'; select 2.0 <= '2.2'; +select 0.5 <= '1.5'; select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index 8e7e04c8e1c4f..8cd0d51da64f5 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 31 -- !query 0 @@ -21,11 +21,19 @@ true -- !query 2 select 1.0 = '1' -- !query 2 schema -struct<(1.0 = CAST(1 AS DECIMAL(2,1))):boolean> +struct<(CAST(1.0 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> -- !query 2 output true +-- !query 3 +select 1.5 = '1.51' +-- !query 3 schema +struct<(CAST(1.5 AS DOUBLE) = CAST(1.51 AS DOUBLE)):boolean> +-- !query 3 output +false + + -- !query 3 select 1 > '1' -- !query 3 schema @@ -59,160 +67,192 @@ false -- !query 7 -select to_date('2009-07-30 04:17:52') > 
to_date('2009-07-30 04:17:52') +select '1.5' > 0.5 -- !query 7 schema -struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(1.5 AS DOUBLE) > CAST(0.5 AS DOUBLE)):boolean> -- !query 7 output -false +true -- !query 8 -select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' +select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') -- !query 8 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> +struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> -- !query 8 output false -- !query 9 -select 1 >= '1' +select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' -- !query 9 schema -struct<(1 >= CAST(1 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> -- !query 9 output -true +false -- !query 10 -select 2 >= '1.0' +select 1 >= '1' -- !query 10 schema -struct<(2 >= CAST(1.0 AS INT)):boolean> +struct<(1 >= CAST(1 AS INT)):boolean> -- !query 10 output true -- !query 11 -select 2 >= '2.0' +select 2 >= '1.0' -- !query 11 schema -struct<(2 >= CAST(2.0 AS INT)):boolean> +struct<(2 >= CAST(1.0 AS INT)):boolean> -- !query 11 output true -- !query 12 -select 2.0 >= '2.2' +select 2 >= '2.0' -- !query 12 schema -struct<(2.0 >= CAST(2.2 AS DECIMAL(2,1))):boolean> +struct<(2 >= CAST(2.0 AS INT)):boolean> -- !query 12 output -false +true -- !query 13 -select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') +select 2.0 >= '2.2' -- !query 13 schema -struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(2.0 AS DOUBLE) >= CAST(2.2 AS DOUBLE)):boolean> -- !query 13 output -true +false -- !query 14 -select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' +select '1.5' >= 0.5 -- !query 14 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> +struct<(CAST(1.5 AS DOUBLE) >= CAST(0.5 AS DOUBLE)):boolean> -- !query 14 output -false +true -- !query 15 -select 1 < '1' +select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') -- !query 15 schema -struct<(1 < CAST(1 AS INT)):boolean> +struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> -- !query 15 output -false +true -- !query 16 -select 2 < '1.0' +select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' -- !query 16 schema -struct<(2 < CAST(1.0 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> -- !query 16 output false -- !query 17 -select 2 < '2.0' +select 1 < '1' -- !query 17 schema -struct<(2 < CAST(2.0 AS INT)):boolean> +struct<(1 < CAST(1 AS INT)):boolean> -- !query 17 output false -- !query 18 -select 2.0 < '2.2' +select 2 < '1.0' -- !query 18 schema -struct<(2.0 < CAST(2.2 AS DECIMAL(2,1))):boolean> +struct<(2 < CAST(1.0 AS INT)):boolean> -- !query 18 output -true +false -- !query 19 -select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') +select 2 < '2.0' -- !query 19 schema -struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> +struct<(2 < CAST(2.0 AS INT)):boolean> -- !query 19 output false -- !query 20 -select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' +select 2.0 < '2.2' -- !query 20 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> +struct<(CAST(2.0 AS DOUBLE) < CAST(2.2 AS DOUBLE)):boolean> -- !query 20 output true -- !query 21 -select 1 <= '1' +select 0.5 
< '1.5' -- !query 21 schema -struct<(1 <= CAST(1 AS INT)):boolean> +struct<(CAST(0.5 AS DOUBLE) < CAST(1.5 AS DOUBLE)):boolean> -- !query 21 output true -- !query 22 -select 2 <= '1.0' +select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') -- !query 22 schema -struct<(2 <= CAST(1.0 AS INT)):boolean> +struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> -- !query 22 output false -- !query 23 -select 2 <= '2.0' +select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' -- !query 23 schema -struct<(2 <= CAST(2.0 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> -- !query 23 output true -- !query 24 -select 2.0 <= '2.2' +select 1 <= '1' -- !query 24 schema -struct<(2.0 <= CAST(2.2 AS DECIMAL(2,1))):boolean> +struct<(1 <= CAST(1 AS INT)):boolean> -- !query 24 output true -- !query 25 -select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') +select 2 <= '1.0' -- !query 25 schema -struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +struct<(2 <= CAST(1.0 AS INT)):boolean> -- !query 25 output -true +false -- !query 26 -select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +select 2 <= '2.0' -- !query 26 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +struct<(2 <= CAST(2.0 AS INT)):boolean> -- !query 26 output true + + +-- !query 27 +select 2.0 <= '2.2' +-- !query 27 schema +struct<(CAST(2.0 AS DOUBLE) <= CAST(2.2 AS DOUBLE)):boolean> +-- !query 27 output +true + + +-- !query 28 +select 0.5 <= '1.5' +-- !query 28 schema +struct<(CAST(0.5 AS DOUBLE) <= CAST(1.5 AS DOUBLE)):boolean> +-- !query 28 output +true + + +-- !query 29 +select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') +-- !query 29 schema +struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +-- !query 29 output +true + + +-- !query 30 +select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +-- !query 30 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +-- !query 30 output +true From 39b3f10dda73f4a1f735f17467e5c6c45c44e977 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 15 Nov 2017 15:41:53 -0600 Subject: [PATCH 253/712] [SPARK-20649][CORE] Simplify REST API resource structure. With the new UI store, the API resource classes have a lot less code, since there's no need for complicated translations between the UI types and the API types. So the code ended up with a bunch of files with a single method declared in them. This change re-structures the API code so that it uses less classes; mainly, most sub-resources were removed, and the code to deal with single-attempt and multi-attempt apps was simplified. The only change was the addition of a method to return a single attempt's information; that was missing in the old API, so trying to retrieve "/v1/applications/appId/attemptId" would result in a 404 even if the attempt existed (and URIs under that one would return valid data). The streaming API resources also overtook the same treatment, even though the data is not stored in the new UI store. Author: Marcelo Vanzin Closes #19748 from vanzin/SPARK-20649. 
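To make the behavioral fix concrete, a hedged sketch of querying the attempt-level endpoint mentioned above (host, port, and identifiers are placeholder assumptions, not values from this change):

```python
# Hedged example: fetch a single attempt's information via the REST API.
# Before this change the attempt-level URI returned 404 even when the attempt existed.
import json
from urllib.request import urlopen

base = "http://localhost:18080/api/v1"  # assumed history server address
app_id, attempt_id = "app-20171115000000-0001", "1"  # placeholder identifiers
with urlopen("{}/applications/{}/{}".format(base, app_id, attempt_id)) as resp:
    print(json.load(resp))
```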
--- .../api/v1/AllExecutorListResource.scala | 30 --- .../spark/status/api/v1/AllJobsResource.scala | 35 --- .../spark/status/api/v1/AllRDDResource.scala | 31 --- .../status/api/v1/AllStagesResource.scala | 33 --- .../spark/status/api/v1/ApiRootResource.scala | 203 ++---------------- .../v1/ApplicationEnvironmentResource.scala | 32 --- .../api/v1/ApplicationListResource.scala | 2 +- .../api/v1/EventLogDownloadResource.scala | 71 ------ .../status/api/v1/ExecutorListResource.scala | 30 --- .../api/v1/OneApplicationResource.scala | 146 ++++++++++++- .../spark/status/api/v1/OneJobResource.scala | 38 ---- .../spark/status/api/v1/OneRDDResource.scala | 38 ---- ...ageResource.scala => StagesResource.scala} | 34 +-- .../spark/status/api/v1/VersionResource.scala | 30 --- .../api/v1/streaming/AllBatchesResource.scala | 78 ------- .../AllOutputOperationsResource.scala | 66 ------ .../v1/streaming/AllReceiversResource.scala | 76 ------- .../api/v1/streaming/ApiStreamingApp.scala | 31 ++- .../streaming/ApiStreamingRootResource.scala | 172 ++++++++++++--- .../api/v1/streaming/OneBatchResource.scala | 35 --- .../OneOutputOperationResource.scala | 39 ---- .../v1/streaming/OneReceiverResource.scala | 35 --- .../StreamingStatisticsResource.scala | 64 ------ 23 files changed, 349 insertions(+), 1000 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala rename core/src/main/scala/org/apache/spark/status/api/v1/{OneStageResource.scala => StagesResource.scala} (77%) delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala deleted file mode 100644 index 5522f4cebd773..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllExecutorListResource.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed 
to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllExecutorListResource(ui: SparkUI) { - - @GET - def executorList(): Seq[ExecutorSummary] = ui.store.executorList(false) - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala deleted file mode 100644 index b4fa3e633f6c1..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.util.{Arrays, Date, List => JList} -import javax.ws.rs._ -import javax.ws.rs.core.MediaType - -import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.JobUIData - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllJobsResource(ui: SparkUI) { - - @GET - def jobsList(@QueryParam("status") statuses: JList[JobExecutionStatus]): Seq[JobData] = { - ui.store.jobsList(statuses) - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala deleted file mode 100644 index 2189e1da91841..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.storage.{RDDInfo, StorageStatus, StorageUtils} -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllRDDResource(ui: SparkUI) { - - @GET - def rddList(): Seq[RDDStorageInfo] = ui.store.rddList() - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala deleted file mode 100644 index e1c91cb527a51..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.status.api.v1 - -import java.util.{List => JList} -import javax.ws.rs.{GET, Produces, QueryParam} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllStagesResource(ui: SparkUI) { - - @GET - def stageList(@QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { - ui.store.stageList(statuses) - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 9d3833086172f..ed9bdc6e1e3c2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -44,189 +44,14 @@ import org.apache.spark.ui.SparkUI private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications") - def getApplicationList(): ApplicationListResource = { - new ApplicationListResource(uiRoot) - } + def applicationList(): Class[ApplicationListResource] = classOf[ApplicationListResource] @Path("applications/{appId}") - def getApplication(): OneApplicationResource = { - new OneApplicationResource(uiRoot) - } - - @Path("applications/{appId}/{attemptId}/jobs") - def getJobs( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllJobsResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllJobsResource(ui) - } - } - - @Path("applications/{appId}/jobs") - def getJobs(@PathParam("appId") appId: String): AllJobsResource = { - withSparkUI(appId, None) { ui => - new AllJobsResource(ui) - } - } - - @Path("applications/{appId}/jobs/{jobId: \\d+}") - def getJob(@PathParam("appId") appId: String): OneJobResource = { - withSparkUI(appId, None) { ui => - new OneJobResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/jobs/{jobId: \\d+}") - def getJob( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): OneJobResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new OneJobResource(ui) - } - } - - @Path("applications/{appId}/executors") - def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { - withSparkUI(appId, None) { ui => - new ExecutorListResource(ui) - } - } - - @Path("applications/{appId}/allexecutors") - def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { - withSparkUI(appId, None) { ui => - new AllExecutorListResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/executors") - def getExecutors( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): ExecutorListResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new ExecutorListResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/allexecutors") - def getAllExecutors( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllExecutorListResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllExecutorListResource(ui) - } - } - - @Path("applications/{appId}/stages") - def getStages(@PathParam("appId") appId: String): AllStagesResource = { - withSparkUI(appId, None) { ui => - new AllStagesResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/stages") - def getStages( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllStagesResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllStagesResource(ui) - } - } - - 
@Path("applications/{appId}/stages/{stageId: \\d+}") - def getStage(@PathParam("appId") appId: String): OneStageResource = { - withSparkUI(appId, None) { ui => - new OneStageResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/stages/{stageId: \\d+}") - def getStage( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): OneStageResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new OneStageResource(ui) - } - } - - @Path("applications/{appId}/storage/rdd") - def getRdds(@PathParam("appId") appId: String): AllRDDResource = { - withSparkUI(appId, None) { ui => - new AllRDDResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/storage/rdd") - def getRdds( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllRDDResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new AllRDDResource(ui) - } - } - - @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") - def getRdd(@PathParam("appId") appId: String): OneRDDResource = { - withSparkUI(appId, None) { ui => - new OneRDDResource(ui) - } - } - - @Path("applications/{appId}/{attemptId}/storage/rdd/{rddId: \\d+}") - def getRdd( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): OneRDDResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new OneRDDResource(ui) - } - } - - @Path("applications/{appId}/logs") - def getEventLogs( - @PathParam("appId") appId: String): EventLogDownloadResource = { - try { - // withSparkUI will throw NotFoundException if attemptId exists for this application. - // So we need to try again with attempt id "1". - withSparkUI(appId, None) { _ => - new EventLogDownloadResource(uiRoot, appId, None) - } - } catch { - case _: NotFoundException => - withSparkUI(appId, Some("1")) { _ => - new EventLogDownloadResource(uiRoot, appId, None) - } - } - } - - @Path("applications/{appId}/{attemptId}/logs") - def getEventLogs( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - withSparkUI(appId, Some(attemptId)) { _ => - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) - } - } + def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] @Path("version") - def getVersion(): VersionResource = { - new VersionResource(uiRoot) - } - - @Path("applications/{appId}/environment") - def getEnvironment(@PathParam("appId") appId: String): ApplicationEnvironmentResource = { - withSparkUI(appId, None) { ui => - new ApplicationEnvironmentResource(ui) - } - } + def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) - @Path("applications/{appId}/{attemptId}/environment") - def getEnvironment( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): ApplicationEnvironmentResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new ApplicationEnvironmentResource(ui) - } - } } private[spark] object ApiRootResource { @@ -293,23 +118,29 @@ private[v1] trait ApiRequestContext { def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) +} - /** - * Get the spark UI with the given appID, and apply a function - * to it. If there is no such app, throw an appropriate exception - */ - def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { +/** + * Base class for resource handlers that use app-specific data. Abstracts away dealing with + * application and attempt IDs, and finding the app's UI. 
+ */ +private[v1] trait BaseAppResource extends ApiRequestContext { + + @PathParam("appId") protected[this] var appId: String = _ + @PathParam("attemptId") protected[this] var attemptId: String = _ + + protected def withUI[T](fn: SparkUI => T): T = { try { - uiRoot.withSparkUI(appId, attemptId) { ui => + uiRoot.withSparkUI(appId, Option(attemptId)) { ui => val user = httpRequest.getRemoteUser() if (!ui.securityManager.checkUIViewPermissions(user)) { throw new ForbiddenException(raw"""user "$user" is not authorized""") } - f(ui) + fn(ui) } } catch { case _: NoSuchElementException => - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + val appKey = Option(attemptId).map(appId + "/" + _).getOrElse(appId) throw new NotFoundException(s"no such app: $appKey") } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala deleted file mode 100644 index e702f8aa2ef2d..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs._ -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class ApplicationEnvironmentResource(ui: SparkUI) { - - @GET - def getEnvironmentInfo(): ApplicationEnvironmentInfo = { - ui.store.environmentInfo() - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index f039744e7f67f..91660a524ca93 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -23,7 +23,7 @@ import javax.ws.rs.core.MediaType import org.apache.spark.deploy.history.ApplicationHistoryInfo @Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class ApplicationListResource(uiRoot: UIRoot) { +private[v1] class ApplicationListResource extends ApiRequestContext { @GET def appList( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala deleted file mode 100644 index c84022ddfeef0..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.io.OutputStream -import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.{MediaType, Response, StreamingOutput} - -import scala.util.control.NonFatal - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.Logging - -@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) -private[v1] class EventLogDownloadResource( - val uIRoot: UIRoot, - val appId: String, - val attemptId: Option[String]) extends Logging { - val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) - - @GET - def getEventLogs(): Response = { - try { - val fileName = { - attemptId match { - case Some(id) => s"eventLogs-$appId-$id.zip" - case None => s"eventLogs-$appId.zip" - } - } - - val stream = new StreamingOutput { - override def write(output: OutputStream): Unit = { - val zipStream = new ZipOutputStream(output) - try { - uIRoot.writeEventLogs(appId, attemptId, zipStream) - } finally { - zipStream.close() - } - - } - } - - Response.ok(stream) - .header("Content-Disposition", s"attachment; filename=$fileName") - .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) - .build() - } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala deleted file mode 100644 index 975101c33c59c..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class ExecutorListResource(ui: SparkUI) { - - @GET - def executorList(): Seq[ExecutorSummary] = ui.store.executorList(true) - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 18c3e2f407360..bd4df07e7afc6 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -16,16 +16,150 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType +import java.io.OutputStream +import java.util.{List => JList} +import java.util.zip.ZipOutputStream +import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs.core.{MediaType, Response, StreamingOutput} + +import scala.util.control.NonFatal + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.SparkUI @Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneApplicationResource(uiRoot: UIRoot) { +private[v1] class AbstractApplicationResource extends BaseAppResource { + + @GET + @Path("jobs") + def jobsList(@QueryParam("status") statuses: JList[JobExecutionStatus]): Seq[JobData] = { + withUI(_.store.jobsList(statuses)) + } + + @GET + @Path("jobs/{jobId: \\d+}") + def oneJob(@PathParam("jobId") jobId: Int): JobData = withUI { ui => + try { + ui.store.job(jobId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException("unknown job: " + jobId) + } + } + + @GET + @Path("executors") + def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true)) + + @GET + @Path("allexecutors") + def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false)) + + @Path("stages") + def stages(): Class[StagesResource] = classOf[StagesResource] + + @GET + @Path("storage/rdd") + def rddList(): Seq[RDDStorageInfo] = withUI(_.store.rddList()) + + @GET + @Path("storage/rdd/{rddId: \\d+}") + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = withUI { ui => + try { + ui.store.rdd(rddId) + } catch { + case _: NoSuchElementException => + throw new NotFoundException(s"no rdd found w/ id $rddId") + } + } + + @GET + @Path("environment") + def environmentInfo(): ApplicationEnvironmentInfo = withUI(_.store.environmentInfo()) + + @GET + @Path("logs") + @Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) + def getEventLogs(): Response = { + // Retrieve the UI for the application just to do access permission checks. For backwards + // compatibility, this code also tries with attemptId "1" if the UI without an attempt ID does + // not exist. 
+ try { + withUI { _ => } + } catch { + case _: NotFoundException if attemptId == null => + attemptId = "1" + withUI { _ => } + attemptId = null + } + + try { + val fileName = if (attemptId != null) { + s"eventLogs-$appId-$attemptId.zip" + } else { + s"eventLogs-$appId.zip" + } + + val stream = new StreamingOutput { + override def write(output: OutputStream): Unit = { + val zipStream = new ZipOutputStream(output) + try { + uiRoot.writeEventLogs(appId, Option(attemptId), zipStream) + } finally { + zipStream.close() + } + + } + } + + Response.ok(stream) + .header("Content-Disposition", s"attachment; filename=$fileName") + .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) + .build() + } catch { + case NonFatal(e) => + Response.serverError() + .entity(s"Event logs are not available for app: $appId.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + } + + /** + * This method needs to be last, otherwise it clashes with the paths for the above methods + * and causes JAX-RS to not find things. + */ + @Path("{attemptId}") + def applicationAttempt(): Class[OneApplicationAttemptResource] = { + if (attemptId != null) { + throw new NotFoundException(httpRequest.getRequestURI()) + } + classOf[OneApplicationAttemptResource] + } + +} + +private[v1] class OneApplicationResource extends AbstractApplicationResource { + + @GET + def getApp(): ApplicationInfo = { + val app = uiRoot.getApplicationInfo(appId) + app.getOrElse(throw new NotFoundException("unknown app: " + appId)) + } + +} + +private[v1] class OneApplicationAttemptResource extends AbstractApplicationResource { @GET - def getApp(@PathParam("appId") appId: String): ApplicationInfo = { - val apps = uiRoot.getApplicationInfo(appId) - apps.getOrElse(throw new NotFoundException("unknown app: " + appId)) + def getAttempt(): ApplicationAttemptInfo = { + uiRoot.getApplicationInfo(appId) + .flatMap { app => + app.attempts.filter(_.attemptId == attemptId).headOption + } + .getOrElse { + throw new NotFoundException(s"unknown app $appId, attempt $attemptId") + } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala deleted file mode 100644 index 3ee884e084c12..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.status.api.v1 - -import java.util.NoSuchElementException -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneJobResource(ui: SparkUI) { - - @GET - def oneJob(@PathParam("jobId") jobId: Int): JobData = { - try { - ui.store.job(jobId) - } catch { - case _: NoSuchElementException => - throw new NotFoundException("unknown job: " + jobId) - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala deleted file mode 100644 index ca9758cf0d109..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import java.util.NoSuchElementException -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.ui.SparkUI - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneRDDResource(ui: SparkUI) { - - @GET - def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { - try { - ui.store.rdd(rddId) - } catch { - case _: NoSuchElementException => - throw new NotFoundException(s"no rdd found w/ id $rddId") - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala similarity index 77% rename from core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala rename to core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 20dd73e916613..bd4dfe3c68885 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.{List => JList} import javax.ws.rs._ import javax.ws.rs.core.MediaType @@ -27,27 +28,34 @@ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.UIData.StageUIData @Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneStageResource(ui: SparkUI) { +private[v1] class StagesResource extends BaseAppResource { @GET - @Path("") + def stageList(@QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { + withUI(_.store.stageList(statuses)) + } + + @GET + @Path("{stageId: \\d+}") def stageData( @PathParam("stageId") stageId: Int, @QueryParam("details") @DefaultValue("true") details: Boolean): Seq[StageData] = { - val ret = ui.store.stageData(stageId, details = details) - if (ret.nonEmpty) { - ret - } 
else { - throw new NotFoundException(s"unknown stage: $stageId") + withUI { ui => + val ret = ui.store.stageData(stageId, details = details) + if (ret.nonEmpty) { + ret + } else { + throw new NotFoundException(s"unknown stage: $stageId") + } } } @GET - @Path("/{stageAttemptId: \\d+}") + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}") def oneAttemptData( @PathParam("stageId") stageId: Int, @PathParam("stageAttemptId") stageAttemptId: Int, - @QueryParam("details") @DefaultValue("true") details: Boolean): StageData = { + @QueryParam("details") @DefaultValue("true") details: Boolean): StageData = withUI { ui => try { ui.store.stageAttempt(stageId, stageAttemptId, details = details) } catch { @@ -57,12 +65,12 @@ private[v1] class OneStageResource(ui: SparkUI) { } @GET - @Path("/{stageAttemptId: \\d+}/taskSummary") + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}/taskSummary") def taskSummary( @PathParam("stageId") stageId: Int, @PathParam("stageAttemptId") stageAttemptId: Int, @DefaultValue("0.05,0.25,0.5,0.75,0.95") @QueryParam("quantiles") quantileString: String) - : TaskMetricDistributions = { + : TaskMetricDistributions = withUI { ui => val quantiles = quantileString.split(",").map { s => try { s.toDouble @@ -76,14 +84,14 @@ private[v1] class OneStageResource(ui: SparkUI) { } @GET - @Path("/{stageAttemptId: \\d+}/taskList") + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}/taskList") def taskList( @PathParam("stageId") stageId: Int, @PathParam("stageAttemptId") stageAttemptId: Int, @DefaultValue("0") @QueryParam("offset") offset: Int, @DefaultValue("20") @QueryParam("length") length: Int, @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = { - ui.store.taskList(stageId, stageAttemptId, offset, length, sortBy) + withUI(_.store.taskList(stageId, stageAttemptId, offset, length, sortBy)) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala deleted file mode 100644 index 673da1ce36b57..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs._ -import javax.ws.rs.core.MediaType - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class VersionResource(ui: UIRoot) { - - @GET - def getVersionInfo(): VersionInfo = new VersionInfo( - org.apache.spark.SPARK_VERSION - ) - -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala deleted file mode 100644 index 3a51ae609303a..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.{ArrayList => JArrayList, Arrays => JArrays, Date, List => JList} -import javax.ws.rs.{GET, Produces, QueryParam} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.streaming.AllBatchesResource._ -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllBatchesResource(listener: StreamingJobProgressListener) { - - @GET - def batchesList(@QueryParam("status") statusParams: JList[BatchStatus]): Seq[BatchInfo] = { - batchInfoList(listener, statusParams).sortBy(- _.batchId) - } -} - -private[v1] object AllBatchesResource { - - def batchInfoList( - listener: StreamingJobProgressListener, - statusParams: JList[BatchStatus] = new JArrayList[BatchStatus]()): Seq[BatchInfo] = { - - listener.synchronized { - val statuses = - if (statusParams.isEmpty) JArrays.asList(BatchStatus.values(): _*) else statusParams - val statusToBatches = Seq( - BatchStatus.COMPLETED -> listener.retainedCompletedBatches, - BatchStatus.QUEUED -> listener.waitingBatches, - BatchStatus.PROCESSING -> listener.runningBatches - ) - - val batchInfos = for { - (status, batches) <- statusToBatches - batch <- batches if statuses.contains(status) - } yield { - val batchId = batch.batchTime.milliseconds - val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption - - new BatchInfo( - batchId = batchId, - batchTime = new Date(batchId), - status = status.toString, - batchDuration = listener.batchDuration, - inputSize = batch.numRecords, - schedulingDelay = batch.schedulingDelay, - processingTime = batch.processingDelay, - totalDelay = batch.totalDelay, - numActiveOutputOps = batch.numActiveOutputOp, - numCompletedOutputOps = batch.numCompletedOutputOp, - numFailedOutputOps = batch.numFailedOutputOp, - numTotalOutputOps = batch.outputOperations.size, - firstFailureReason = firstFailureReason - ) - } - - batchInfos - } - } -} diff --git 
a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala deleted file mode 100644 index 0eb649f0e1b72..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.Date -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.status.api.v1.streaming.AllOutputOperationsResource._ -import org.apache.spark.streaming.Time -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllOutputOperationsResource(listener: StreamingJobProgressListener) { - - @GET - def operationsList(@PathParam("batchId") batchId: Long): Seq[OutputOperationInfo] = { - outputOperationInfoList(listener, batchId).sortBy(_.outputOpId) - } -} - -private[v1] object AllOutputOperationsResource { - - def outputOperationInfoList( - listener: StreamingJobProgressListener, - batchId: Long): Seq[OutputOperationInfo] = { - - listener.synchronized { - listener.getBatchUIData(Time(batchId)) match { - case Some(batch) => - for ((opId, op) <- batch.outputOperations) yield { - val jobIds = batch.outputOpIdSparkJobIdPairs - .filter(_.outputOpId == opId).map(_.sparkJobId).toSeq.sorted - - new OutputOperationInfo( - outputOpId = opId, - name = op.name, - description = op.description, - startTime = op.startTime.map(new Date(_)), - endTime = op.endTime.map(new Date(_)), - duration = op.duration, - failureReason = op.failureReason, - jobIds = jobIds - ) - } - case None => throw new NotFoundException("unknown batch: " + batchId) - } - }.toSeq - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala deleted file mode 100644 index 5a276a9236a0f..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.Date -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.streaming.AllReceiversResource._ -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class AllReceiversResource(listener: StreamingJobProgressListener) { - - @GET - def receiversList(): Seq[ReceiverInfo] = { - receiverInfoList(listener).sortBy(_.streamId) - } -} - -private[v1] object AllReceiversResource { - - def receiverInfoList(listener: StreamingJobProgressListener): Seq[ReceiverInfo] = { - listener.synchronized { - listener.receivedRecordRateWithBatchTime.map { case (streamId, eventRates) => - - val receiverInfo = listener.receiverInfo(streamId) - val streamName = receiverInfo.map(_.name) - .orElse(listener.streamName(streamId)).getOrElse(s"Stream-$streamId") - val avgEventRate = - if (eventRates.isEmpty) None else Some(eventRates.map(_._2).sum / eventRates.size) - - val (errorTime, errorMessage, error) = receiverInfo match { - case None => (None, None, None) - case Some(info) => - val someTime = - if (info.lastErrorTime >= 0) Some(new Date(info.lastErrorTime)) else None - val someMessage = - if (info.lastErrorMessage.length > 0) Some(info.lastErrorMessage) else None - val someError = - if (info.lastError.length > 0) Some(info.lastError) else None - - (someTime, someMessage, someError) - } - - new ReceiverInfo( - streamId = streamId, - streamName = streamName, - isActive = receiverInfo.map(_.active), - executorId = receiverInfo.map(_.executorId), - executorHost = receiverInfo.map(_.location), - lastErrorTime = errorTime, - lastErrorMessage = errorMessage, - lastError = error, - avgEventRate = avgEventRate, - eventRates = eventRates - ) - }.toSeq - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala index aea75d5a9c8d0..07d8164e1d2c0 100644 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala @@ -19,24 +19,39 @@ package org.apache.spark.status.api.v1.streaming import javax.ws.rs.{Path, PathParam} -import org.apache.spark.status.api.v1.ApiRequestContext +import org.apache.spark.status.api.v1._ +import org.apache.spark.streaming.ui.StreamingJobProgressListener @Path("/v1") private[v1] class ApiStreamingApp extends ApiRequestContext { @Path("applications/{appId}/streaming") - def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = { - withSparkUI(appId, None) { ui => - new ApiStreamingRootResource(ui) - } + def getStreamingRoot(@PathParam("appId") appId: String): Class[ApiStreamingRootResource] = { + classOf[ApiStreamingRootResource] } @Path("applications/{appId}/{attemptId}/streaming") def getStreamingRoot( @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): 
ApiStreamingRootResource = { - withSparkUI(appId, Some(attemptId)) { ui => - new ApiStreamingRootResource(ui) + @PathParam("attemptId") attemptId: String): Class[ApiStreamingRootResource] = { + classOf[ApiStreamingRootResource] + } +} + +/** + * Base class for streaming API handlers, provides easy access to the streaming listener that + * holds the app's information. + */ +private[v1] trait BaseStreamingAppResource extends BaseAppResource { + + protected def withListener[T](fn: StreamingJobProgressListener => T): T = withUI { ui => + val listener = ui.getStreamingJobProgressListener match { + case Some(listener) => listener.asInstanceOf[StreamingJobProgressListener] + case None => throw new NotFoundException("no streaming listener attached to " + ui.getAppName) + } + listener.synchronized { + fn(listener) } } + } diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala index 1ccd586c848bd..a2571b910f615 100644 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala @@ -17,58 +17,180 @@ package org.apache.spark.status.api.v1.streaming -import javax.ws.rs.Path +import java.util.{Arrays => JArrays, Collections, Date, List => JList} +import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs.core.MediaType import org.apache.spark.status.api.v1.NotFoundException +import org.apache.spark.streaming.Time import org.apache.spark.streaming.ui.StreamingJobProgressListener +import org.apache.spark.streaming.ui.StreamingJobProgressListener._ import org.apache.spark.ui.SparkUI -private[v1] class ApiStreamingRootResource(ui: SparkUI) { - - import org.apache.spark.status.api.v1.streaming.ApiStreamingRootResource._ +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class ApiStreamingRootResource extends BaseStreamingAppResource { + @GET @Path("statistics") - def getStreamingStatistics(): StreamingStatisticsResource = { - new StreamingStatisticsResource(getListener(ui)) + def streamingStatistics(): StreamingStatistics = withListener { listener => + val batches = listener.retainedBatches + val avgInputRate = avgRate(batches.map(_.numRecords * 1000.0 / listener.batchDuration)) + val avgSchedulingDelay = avgTime(batches.flatMap(_.schedulingDelay)) + val avgProcessingTime = avgTime(batches.flatMap(_.processingDelay)) + val avgTotalDelay = avgTime(batches.flatMap(_.totalDelay)) + + new StreamingStatistics( + startTime = new Date(listener.startTime), + batchDuration = listener.batchDuration, + numReceivers = listener.numReceivers, + numActiveReceivers = listener.numActiveReceivers, + numInactiveReceivers = listener.numInactiveReceivers, + numTotalCompletedBatches = listener.numTotalCompletedBatches, + numRetainedCompletedBatches = listener.retainedCompletedBatches.size, + numActiveBatches = listener.numUnprocessedBatches, + numProcessedRecords = listener.numTotalProcessedRecords, + numReceivedRecords = listener.numTotalReceivedRecords, + avgInputRate = avgInputRate, + avgSchedulingDelay = avgSchedulingDelay, + avgProcessingTime = avgProcessingTime, + avgTotalDelay = avgTotalDelay + ) } + @GET @Path("receivers") - def getReceivers(): AllReceiversResource = { - new AllReceiversResource(getListener(ui)) + def receiversList(): Seq[ReceiverInfo] = withListener { listener => + 
listener.receivedRecordRateWithBatchTime.map { case (streamId, eventRates) => + val receiverInfo = listener.receiverInfo(streamId) + val streamName = receiverInfo.map(_.name) + .orElse(listener.streamName(streamId)).getOrElse(s"Stream-$streamId") + val avgEventRate = + if (eventRates.isEmpty) None else Some(eventRates.map(_._2).sum / eventRates.size) + + val (errorTime, errorMessage, error) = receiverInfo match { + case None => (None, None, None) + case Some(info) => + val someTime = + if (info.lastErrorTime >= 0) Some(new Date(info.lastErrorTime)) else None + val someMessage = + if (info.lastErrorMessage.length > 0) Some(info.lastErrorMessage) else None + val someError = + if (info.lastError.length > 0) Some(info.lastError) else None + + (someTime, someMessage, someError) + } + + new ReceiverInfo( + streamId = streamId, + streamName = streamName, + isActive = receiverInfo.map(_.active), + executorId = receiverInfo.map(_.executorId), + executorHost = receiverInfo.map(_.location), + lastErrorTime = errorTime, + lastErrorMessage = errorMessage, + lastError = error, + avgEventRate = avgEventRate, + eventRates = eventRates + ) + }.toSeq.sortBy(_.streamId) } + @GET @Path("receivers/{streamId: \\d+}") - def getReceiver(): OneReceiverResource = { - new OneReceiverResource(getListener(ui)) + def oneReceiver(@PathParam("streamId") streamId: Int): ReceiverInfo = { + receiversList().find { _.streamId == streamId }.getOrElse( + throw new NotFoundException("unknown receiver: " + streamId)) } + @GET @Path("batches") - def getBatches(): AllBatchesResource = { - new AllBatchesResource(getListener(ui)) + def batchesList(@QueryParam("status") statusParams: JList[BatchStatus]): Seq[BatchInfo] = { + withListener { listener => + val statuses = + if (statusParams.isEmpty) JArrays.asList(BatchStatus.values(): _*) else statusParams + val statusToBatches = Seq( + BatchStatus.COMPLETED -> listener.retainedCompletedBatches, + BatchStatus.QUEUED -> listener.waitingBatches, + BatchStatus.PROCESSING -> listener.runningBatches + ) + + val batchInfos = for { + (status, batches) <- statusToBatches + batch <- batches if statuses.contains(status) + } yield { + val batchId = batch.batchTime.milliseconds + val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption + + new BatchInfo( + batchId = batchId, + batchTime = new Date(batchId), + status = status.toString, + batchDuration = listener.batchDuration, + inputSize = batch.numRecords, + schedulingDelay = batch.schedulingDelay, + processingTime = batch.processingDelay, + totalDelay = batch.totalDelay, + numActiveOutputOps = batch.numActiveOutputOp, + numCompletedOutputOps = batch.numCompletedOutputOp, + numFailedOutputOps = batch.numFailedOutputOp, + numTotalOutputOps = batch.outputOperations.size, + firstFailureReason = firstFailureReason + ) + } + + batchInfos.sortBy(- _.batchId) + } } + @GET @Path("batches/{batchId: \\d+}") - def getBatch(): OneBatchResource = { - new OneBatchResource(getListener(ui)) + def oneBatch(@PathParam("batchId") batchId: Long): BatchInfo = { + batchesList(Collections.emptyList()).find { _.batchId == batchId }.getOrElse( + throw new NotFoundException("unknown batch: " + batchId)) } + @GET @Path("batches/{batchId: \\d+}/operations") - def getOutputOperations(): AllOutputOperationsResource = { - new AllOutputOperationsResource(getListener(ui)) + def operationsList(@PathParam("batchId") batchId: Long): Seq[OutputOperationInfo] = { + withListener { listener => + val ops = listener.getBatchUIData(Time(batchId)) match { + case 
Some(batch) => + for ((opId, op) <- batch.outputOperations) yield { + val jobIds = batch.outputOpIdSparkJobIdPairs + .filter(_.outputOpId == opId).map(_.sparkJobId).toSeq.sorted + + new OutputOperationInfo( + outputOpId = opId, + name = op.name, + description = op.description, + startTime = op.startTime.map(new Date(_)), + endTime = op.endTime.map(new Date(_)), + duration = op.duration, + failureReason = op.failureReason, + jobIds = jobIds + ) + } + case None => throw new NotFoundException("unknown batch: " + batchId) + } + ops.toSeq + } } + @GET @Path("batches/{batchId: \\d+}/operations/{outputOpId: \\d+}") - def getOutputOperation(): OneOutputOperationResource = { - new OneOutputOperationResource(getListener(ui)) + def oneOperation( + @PathParam("batchId") batchId: Long, + @PathParam("outputOpId") opId: OutputOpId): OutputOperationInfo = { + operationsList(batchId).find { _.outputOpId == opId }.getOrElse( + throw new NotFoundException("unknown output operation: " + opId)) } -} + private def avgRate(data: Seq[Double]): Option[Double] = { + if (data.isEmpty) None else Some(data.sum / data.size) + } -private[v1] object ApiStreamingRootResource { - def getListener(ui: SparkUI): StreamingJobProgressListener = { - ui.getStreamingJobProgressListener match { - case Some(listener) => listener.asInstanceOf[StreamingJobProgressListener] - case None => throw new NotFoundException("no streaming listener attached to " + ui.getAppName) - } + private def avgTime(data: Seq[Long]): Option[Long] = { + if (data.isEmpty) None else Some(data.sum / data.size) } + } diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala deleted file mode 100644 index d3c689c790cfc..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.status.api.v1.streaming - -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneBatchResource(listener: StreamingJobProgressListener) { - - @GET - def oneBatch(@PathParam("batchId") batchId: Long): BatchInfo = { - val someBatch = AllBatchesResource.batchInfoList(listener) - .find { _.batchId == batchId } - someBatch.getOrElse(throw new NotFoundException("unknown batch: " + batchId)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala deleted file mode 100644 index aabcdb29b0d4c..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.streaming.ui.StreamingJobProgressListener -import org.apache.spark.streaming.ui.StreamingJobProgressListener._ - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneOutputOperationResource(listener: StreamingJobProgressListener) { - - @GET - def oneOperation( - @PathParam("batchId") batchId: Long, - @PathParam("outputOpId") opId: OutputOpId): OutputOperationInfo = { - - val someOutputOp = AllOutputOperationsResource.outputOperationInfoList(listener, batchId) - .find { _.outputOpId == opId } - someOutputOp.getOrElse(throw new NotFoundException("unknown output operation: " + opId)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala deleted file mode 100644 index c0cc99da3a9c7..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status.api.v1.streaming - -import javax.ws.rs.{GET, PathParam, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.status.api.v1.NotFoundException -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class OneReceiverResource(listener: StreamingJobProgressListener) { - - @GET - def oneReceiver(@PathParam("streamId") streamId: Int): ReceiverInfo = { - val someReceiver = AllReceiversResource.receiverInfoList(listener) - .find { _.streamId == streamId } - someReceiver.getOrElse(throw new NotFoundException("unknown receiver: " + streamId)) - } -} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala deleted file mode 100644 index 6cff87be59ca8..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.status.api.v1.streaming - -import java.util.Date -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.MediaType - -import org.apache.spark.streaming.ui.StreamingJobProgressListener - -@Produces(Array(MediaType.APPLICATION_JSON)) -private[v1] class StreamingStatisticsResource(listener: StreamingJobProgressListener) { - - @GET - def streamingStatistics(): StreamingStatistics = { - listener.synchronized { - val batches = listener.retainedBatches - val avgInputRate = avgRate(batches.map(_.numRecords * 1000.0 / listener.batchDuration)) - val avgSchedulingDelay = avgTime(batches.flatMap(_.schedulingDelay)) - val avgProcessingTime = avgTime(batches.flatMap(_.processingDelay)) - val avgTotalDelay = avgTime(batches.flatMap(_.totalDelay)) - - new StreamingStatistics( - startTime = new Date(listener.startTime), - batchDuration = listener.batchDuration, - numReceivers = listener.numReceivers, - numActiveReceivers = listener.numActiveReceivers, - numInactiveReceivers = listener.numInactiveReceivers, - numTotalCompletedBatches = listener.numTotalCompletedBatches, - numRetainedCompletedBatches = listener.retainedCompletedBatches.size, - numActiveBatches = listener.numUnprocessedBatches, - numProcessedRecords = listener.numTotalProcessedRecords, - numReceivedRecords = listener.numTotalReceivedRecords, - avgInputRate = avgInputRate, - avgSchedulingDelay = avgSchedulingDelay, - avgProcessingTime = avgProcessingTime, - avgTotalDelay = avgTotalDelay - ) - } - } - - private def avgRate(data: Seq[Double]): Option[Double] = { - if (data.isEmpty) None else Some(data.sum / data.size) - } - - private def avgTime(data: Seq[Long]): Option[Long] = { - if (data.isEmpty) None else Some(data.sum / data.size) - } -} From 2014e7a789d36e376ca62b1e24636d79c1b19745 Mon Sep 17 00:00:00 2001 From: osatici Date: Wed, 15 Nov 2017 14:08:51 -0800 Subject: [PATCH 254/712] [SPARK-22479][SQL] Exclude credentials from SaveintoDataSourceCommand.simpleString ## What changes were proposed in this pull request? Do not include jdbc properties which may contain credentials in logging a logical plan with `SaveIntoDataSourceCommand` in it. ## How was this patch tested? 
building locally and trying to reproduce (per the steps in https://issues.apache.org/jira/browse/SPARK-22479): ``` == Parsed Logical Plan == SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) == Analyzed Logical Plan == SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) == Optimized Logical Plan == SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) == Physical Plan == Execute SaveIntoDataSourceCommand +- SaveIntoDataSourceCommand org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider570127fa, Map(dbtable -> test20, driver -> org.postgresql.Driver, url -> *********(redacted), password -> *********(redacted)), ErrorIfExists +- Range (0, 100, step=1, splits=Some(8)) ``` Author: osatici Closes #19708 from onursatici/os/redact-jdbc-creds. --- .../spark/internal/config/package.scala | 2 +- .../SaveIntoDataSourceCommand.scala | 7 +++ .../SaveIntoDataSourceCommandSuite.scala | 48 +++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 57e2da8353d6d..84315f55a59ad 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -307,7 +307,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password".r) + .createWithDefault("(?i)secret|password|url|user|username".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 96c84eab1c894..568e953a5db66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.CreatableRelationProvider +import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. 
@@ -46,4 +48,9 @@ case class SaveIntoDataSourceCommand( Seq.empty[Row] } + + override def simpleString: String = { + val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala new file mode 100644 index 0000000000000..4b3ca8e60cab6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.test.SharedSQLContext + +class SaveIntoDataSourceCommandSuite extends SharedSQLContext { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.regex", "(?i)password|url") + + test("simpleString is redacted") { + val URL = "connection.url" + val PASS = "123" + val DRIVER = "mydriver" + + val dataSource = DataSource( + sparkSession = spark, + className = "jdbc", + partitionColumns = Nil, + options = Map("password" -> PASS, "url" -> URL, "driver" -> DRIVER)) + + val logicalPlanString = dataSource + .planForWriting(SaveMode.ErrorIfExists, spark.range(1).logicalPlan) + .treeString(true) + + assert(!logicalPlanString.contains(URL)) + assert(!logicalPlanString.contains(PASS)) + assert(logicalPlanString.contains(DRIVER)) + } +} From 1e82335413bc2384073ead0d6d581c862036d0f5 Mon Sep 17 00:00:00 2001 From: ArtRand Date: Wed, 15 Nov 2017 15:53:05 -0800 Subject: [PATCH 255/712] [SPARK-21842][MESOS] Support Kerberos ticket renewal and creation in Mesos ## What changes were proposed in this pull request? tl;dr: Add a class, `MesosHadoopDelegationTokenManager` that updates delegation tokens on a schedule on the behalf of Spark Drivers. Broadcast renewed credentials to the executors. ## The problem We recently added Kerberos support to Mesos-based Spark jobs as well as Secrets support to the Mesos Dispatcher (SPARK-16742, SPARK-20812, respectively). However the delegation tokens have a defined expiration. This poses a problem for long running Spark jobs (e.g. Spark Streaming applications). YARN has a solution for this where a thread is scheduled to renew the tokens they reach 75% of their way to expiration. It then writes the tokens to HDFS for the executors to find (uses a monotonically increasing suffix). 
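For reference, the renewal point can be computed directly from a token's expiration date. The sketch below is illustrative only and is not part of the patch itself; it mirrors the `getDateOfNextUpdate` helper added to `SparkHadoopUtil` further down in this diff, and the 0.75 fraction is assumed here purely to match the 75% figure mentioned above.

```scala
// Illustrative sketch: schedule renewal after a fraction of the remaining token lifetime.
// The 0.75 default is an assumption for illustration (the 75% point described above).
def nextRenewalTime(expirationDate: Long, fraction: Double = 0.75): Long = {
  val now = System.currentTimeMillis()
  // current time + (fraction * (time until expiration))
  (now + (fraction * (expirationDate - now))).toLong
}

// Example: a token expiring 10 hours from now would be scheduled for renewal in roughly 7.5 hours.
val renewAt = nextRenewalTime(System.currentTimeMillis() + 10L * 60 * 60 * 1000)
```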
## This solution We replace the current method in `CoarseGrainedSchedulerBackend` which used to discard the token renewal time with a protected method `fetchHadoopDelegationTokens`. Now the individual cluster backends are responsible for overriding this method to fetch and manage token renewal. The delegation tokens themselves, are still part of the `CoarseGrainedSchedulerBackend` as before. In the case of Mesos renewed Credentials are broadcasted to the executors. This maintains all transfer of Credentials within Spark (as opposed to Spark-to-HDFS). It also does not require any writing of Credentials to disk. It also does not require any GC of old files. ## How was this patch tested? Manually against a Kerberized HDFS cluster. Thank you for the reviews. Author: ArtRand Closes #19272 from ArtRand/spark-21842-450-kerberos-ticket-renewal. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 28 +++- .../HadoopDelegationTokenManager.scala | 3 + .../CoarseGrainedExecutorBackend.scala | 9 +- .../cluster/CoarseGrainedClusterMessage.scala | 3 + .../CoarseGrainedSchedulerBackend.scala | 30 +--- .../cluster/mesos/MesosClusterManager.scala | 2 +- .../MesosCoarseGrainedSchedulerBackend.scala | 19 ++- .../MesosHadoopDelegationTokenManager.scala | 157 ++++++++++++++++++ .../yarn/security/AMCredentialRenewer.scala | 21 +-- 9 files changed, 228 insertions(+), 44 deletions(-) create mode 100644 resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 1fa10ab943f34..17c7319b40f24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -140,12 +140,23 @@ class SparkHadoopUtil extends Logging { if (!new File(keytabFilename).exists()) { throw new SparkException(s"Keytab file: ${keytabFilename} does not exist") } else { - logInfo("Attempting to login to Kerberos" + - s" using principal: ${principalName} and keytab: ${keytabFilename}") + logInfo("Attempting to login to Kerberos " + + s"using principal: ${principalName} and keytab: ${keytabFilename}") UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) } } + /** + * Add or overwrite current user's credentials with serialized delegation tokens, + * also confirms correct hadoop configuration is set. + */ + private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) { + UserGroupInformation.setConfiguration(newConfiguration(sparkConf)) + val creds = deserialize(tokens) + logInfo(s"Adding/updating delegation tokens ${dumpTokens(creds)}") + addCurrentUserCredentials(creds) + } + /** * Returns a function that can be called to find Hadoop FileSystem bytes read. If * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will @@ -462,6 +473,19 @@ object SparkHadoopUtil { } } + /** + * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date + * when a given fraction of the duration until the expiration date has passed. 
+ * Formula: current time + (fraction * (time until expiration)) + * @param expirationDate Drop-dead expiration date + * @param fraction fraction of the time until expiration return + * @return Date when the fraction of the time until expiration has passed + */ + private[spark] def getDateOfNextUpdate(expirationDate: Long, fraction: Double): Long = { + val ct = System.currentTimeMillis + (ct + (fraction * (expirationDate - ct))).toLong + } + /** * Returns a Configuration object with Spark configuration applied on top. Unlike * the instance method, this will always return a Configuration instance, and not a diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 483d0deec8070..116a686fe1480 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -109,6 +109,8 @@ private[spark] class HadoopDelegationTokenManager( * Writes delegation tokens to creds. Delegation tokens are fetched from all registered * providers. * + * @param hadoopConf hadoop Configuration + * @param creds Credentials that will be updated in place (overwritten) * @return Time after which the fetched delegation tokens should be renewed. */ def obtainDelegationTokens( @@ -125,3 +127,4 @@ private[spark] class HadoopDelegationTokenManager( }.foldLeft(Long.MaxValue)(math.min) } } + diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index d27362ae85bea..acefc9d2436d0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -123,6 +123,10 @@ private[spark] class CoarseGrainedExecutorBackend( executor.stop() } }.start() + + case UpdateDelegationTokens(tokenBytes) => + logInfo(s"Received tokens of ${tokenBytes.length} bytes") + SparkHadoopUtil.get.addDelegationTokens(tokenBytes, env.conf) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -219,9 +223,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { SparkHadoopUtil.get.startCredentialUpdater(driverConf) } - cfg.hadoopDelegationCreds.foreach { hadoopCreds => - val creds = SparkHadoopUtil.get.deserialize(hadoopCreds) - SparkHadoopUtil.get.addCurrentUserCredentials(creds) + cfg.hadoopDelegationCreds.foreach { tokens => + SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } val env = SparkEnv.createExecutorEnv( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 5d65731dfc30e..e8b7fc0ef100a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -54,6 +54,9 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage with RegisterExecutorResponse + case class UpdateDelegationTokens(tokens: Array[Byte]) + extends CoarseGrainedClusterMessage + // Executors to driver case class RegisterExecutor( executorId: String, diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 424e43b25c77a..22d9c4cf81c55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -24,11 +24,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future -import org.apache.hadoop.security.UserGroupInformation - import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -99,12 +95,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 - // hadoop token manager used by some sub-classes (e.g. Mesos) - def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = None - - // Hadoop delegation tokens to be sent to the executors. - val hadoopDelegationCreds: Option[Array[Byte]] = getHadoopDelegationCreds() - class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -159,6 +149,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.getExecutorsAliveOnHost(host).foreach { exec => killExecutors(exec.toSeq, replace = true, force = true) } + + case UpdateDelegationTokens(newDelegationTokens) => + executorDataMap.values.foreach { ed => + ed.executorEndpoint.send(UpdateDelegationTokens(newDelegationTokens)) + } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -236,7 +231,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val reply = SparkAppConfig( sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey(), - hadoopDelegationCreds) + fetchHadoopDelegationTokens()) context.reply(reply) } @@ -686,18 +681,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp true } - protected def getHadoopDelegationCreds(): Option[Array[Byte]] = { - if (UserGroupInformation.isSecurityEnabled && hadoopDelegationTokenManager.isDefined) { - hadoopDelegationTokenManager.map { manager => - val creds = UserGroupInformation.getCurrentUser.getCredentials - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - manager.obtainDelegationTokens(hadoopConf, creds) - SparkHadoopUtil.get.serialize(creds) - } - } else { - None - } - } + protected def fetchHadoopDelegationTokens(): Option[Array[Byte]] = { None } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index 911a0857917ef..da71f8f9e407c 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -17,7 +17,7 
@@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkContext import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 104ed01d293ce..c392061fdb358 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -22,15 +22,16 @@ import java.util.{Collections, List => JList} import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import java.util.concurrent.locks.ReentrantLock -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.mesos.SchedulerDriver import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future +import org.apache.hadoop.security.UserGroupInformation +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver + import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ -import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf @@ -58,8 +59,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with org.apache.mesos.Scheduler with MesosSchedulerUtils { - override def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = - Some(new HadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration)) + private lazy val hadoopDelegationTokenManager: MesosHadoopDelegationTokenManager = + new MesosHadoopDelegationTokenManager(conf, sc.hadoopConfiguration, driverEndpoint) // Blacklist a slave after this many failures private val MAX_SLAVE_FAILURES = 2 @@ -772,6 +773,14 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( offer.getHostname } } + + override def fetchHadoopDelegationTokens(): Option[Array[Byte]] = { + if (UserGroupInformation.isSecurityEnabled) { + Some(hadoopDelegationTokenManager.getTokens()) + } else { + None + } + } } private class Slave(val hostname: String) { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..325dc179d63ea --- /dev/null +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.security.PrivilegedExceptionAction +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.UserGroupInformation + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.util.ThreadUtils + + +/** + * The MesosHadoopDelegationTokenManager fetches and updates Hadoop delegation tokens on the behalf + * of the MesosCoarseGrainedSchedulerBackend. It is modeled after the YARN AMCredentialRenewer, + * and similarly will renew the Credentials when 75% of the renewal interval has passed. + * The principal difference is that instead of writing the new credentials to HDFS and + * incrementing the timestamp of the file, the new credentials (called Tokens when they are + * serialized) are broadcast to all running executors. On the executor side, when new Tokens are + * received they overwrite the current credentials. 
+ */ +private[spark] class MesosHadoopDelegationTokenManager( + conf: SparkConf, + hadoopConfig: Configuration, + driverEndpoint: RpcEndpointRef) + extends Logging { + + require(driverEndpoint != null, "DriverEndpoint is not initialized") + + private val credentialRenewerThread: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Renewal Thread") + + private val tokenManager: HadoopDelegationTokenManager = + new HadoopDelegationTokenManager(conf, hadoopConfig) + + private val principal: String = conf.get(config.PRINCIPAL).orNull + + private var (tokens: Array[Byte], timeOfNextRenewal: Long) = { + try { + val creds = UserGroupInformation.getCurrentUser.getCredentials + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) + logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") + (SparkHadoopUtil.get.serialize(creds), rt) + } catch { + case e: Exception => + logError(s"Failed to fetch Hadoop delegation tokens $e") + throw e + } + } + + private val keytabFile: Option[String] = conf.get(config.KEYTAB) + + scheduleTokenRenewal() + + private def scheduleTokenRenewal(): Unit = { + if (keytabFile.isDefined) { + require(principal != null, "Principal is required for Keytab-based authentication") + logInfo(s"Using keytab: ${keytabFile.get} and principal $principal") + } else { + logInfo("Using ticket cache for Kerberos authentication, no token renewal.") + return + } + + def scheduleRenewal(runnable: Runnable): Unit = { + val remainingTime = timeOfNextRenewal - System.currentTimeMillis() + if (remainingTime <= 0) { + logInfo("Credentials have expired, creating new ones now.") + runnable.run() + } else { + logInfo(s"Scheduling login from keytab in $remainingTime millis.") + credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + } + } + + val credentialRenewerRunnable = + new Runnable { + override def run(): Unit = { + try { + getNewDelegationTokens() + broadcastDelegationTokens(tokens) + } catch { + case e: Exception => + // Log the error and try to write new tokens back in an hour + logWarning("Couldn't broadcast tokens, trying again in an hour", e) + credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) + return + } + scheduleRenewal(this) + } + } + scheduleRenewal(credentialRenewerRunnable) + } + + private def getNewDelegationTokens(): Unit = { + logInfo(s"Attempting to login to KDC with principal ${principal}") + // Get new delegation tokens by logging in with a new UGI inspired by AMCredentialRenewer.scala + // Don't protect against keytabFile being empty because it's guarded above. 
+ val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytabFile.get) + logInfo("Successfully logged into KDC") + val tempCreds = ugi.getCredentials + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + val nextRenewalTime = ugi.doAs(new PrivilegedExceptionAction[Long] { + override def run(): Long = { + tokenManager.obtainDelegationTokens(hadoopConf, tempCreds) + } + }) + + val currTime = System.currentTimeMillis() + timeOfNextRenewal = if (nextRenewalTime <= currTime) { + logWarning(s"Next credential renewal time ($nextRenewalTime) is earlier than " + + s"current time ($currTime), which is unexpected, please check your credential renewal " + + "related configurations in the target services.") + currTime + } else { + SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75) + } + logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") + + // Add the temp credentials back to the original ones. + UserGroupInformation.getCurrentUser.addCredentials(tempCreds) + // update tokens for late or dynamically added executors + tokens = SparkHadoopUtil.get.serialize(tempCreds) + } + + private def broadcastDelegationTokens(tokens: Array[Byte]) = { + logInfo("Sending new tokens to all executors") + driverEndpoint.send(UpdateDelegationTokens(tokens)) + } + + def getTokens(): Array[Byte] = { + tokens + } +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index 68a2e9e70a78b..6134757a82fdc 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn.security import java.security.PrivilegedExceptionAction -import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -25,6 +25,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging @@ -58,9 +59,8 @@ private[yarn] class AMCredentialRenewer( private var lastCredentialsFileSuffix = 0 - private val credentialRenewer = - Executors.newSingleThreadScheduledExecutor( - ThreadUtils.namedThreadFactory("Credential Refresh Thread")) + private val credentialRenewerThread: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") private val hadoopUtil = YarnSparkHadoopUtil.get @@ -70,7 +70,7 @@ private[yarn] class AMCredentialRenewer( private val freshHadoopConf = hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) - @volatile private var timeOfNextRenewal = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -95,7 +95,7 @@ private[yarn] class AMCredentialRenewer( runnable.run() } else { logInfo(s"Scheduling login from keytab 
in $remainingTime millis.") - credentialRenewer.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) } } @@ -111,7 +111,7 @@ private[yarn] class AMCredentialRenewer( // Log the error and try to write new tokens back in an hour logWarning("Failed to write out new credentials to HDFS, will try again in an " + "hour! If this happens too often tasks will fail.", e) - credentialRenewer.schedule(this, 1, TimeUnit.HOURS) + credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) return } scheduleRenewal(this) @@ -195,8 +195,9 @@ private[yarn] class AMCredentialRenewer( } else { // Next valid renewal time is about 75% of credential renewal time, and update time is // slightly later than valid renewal time (80% of renewal time). - timeOfNextRenewal = ((nearestNextRenewalTime - currTime) * 0.75 + currTime).toLong - ((nearestNextRenewalTime - currTime) * 0.8 + currTime).toLong + timeOfNextRenewal = + SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75) + SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8) } // Add the temp credentials back to the original ones. @@ -232,6 +233,6 @@ private[yarn] class AMCredentialRenewer( } def stop(): Unit = { - credentialRenewer.shutdown() + credentialRenewerThread.shutdown() } } From 03f2b7bff7e537ec747b41ad22e448e1c141f0dd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 16 Nov 2017 14:22:25 +0900 Subject: [PATCH 256/712] [SPARK-22535][PYSPARK] Sleep before killing the python worker in PythonRunner.MonitorThread ## What changes were proposed in this pull request? `PythonRunner.MonitorThread` should give the task a little time to finish before forcibly killing the python worker. This will reduce the chance of the race condition a lot. I also improved the log a bit to find out the task to blame when it's stuck. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #19762 from zsxwing/SPARK-22535. --- .../spark/api/python/PythonRunner.scala | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index d417303bb147d..9989f68f8508c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -337,6 +337,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) extends Thread(s"Worker Monitor for $pythonExec") { + /** How long to wait before killing the python worker if a task cannot be interrupted. */ + private val taskKillTimeout = env.conf.getTimeAsMs("spark.python.task.killTimeout", "2s") + setDaemon(true) override def run() { @@ -346,12 +349,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( Thread.sleep(2000) } if (!context.isCompleted) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) + Thread.sleep(taskKillTimeout) + if (!context.isCompleted) { + try { + // Mimic the task name used in `Executor` to help the user find out the task to blame. 
+ val taskName = s"${context.partitionId}.${context.taskAttemptId} " + + s"in stage ${context.stageId} (TID ${context.taskAttemptId})" + logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } } } } From ed885e7a6504c439ffb6730e6963efbd050d43dd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 16 Nov 2017 17:56:21 +0100 Subject: [PATCH 257/712] [SPARK-22499][SQL] Fix 64KB JVM bytecode limit problem with least and greatest ## What changes were proposed in this pull request? This PR changes `least` and `greatest` code generation to place generated code for expression for arguments into separated methods if these size could be large. This PR resolved two cases: * `least` with a lot of argument * `greatest` with a lot of argument ## How was this patch tested? Added a new test case into `ArithmeticExpressionsSuite` Author: Kazuaki Ishizaki Closes #19729 from kiszk/SPARK-22499. --- .../sql/catalyst/expressions/arithmetic.scala | 24 +++++++++---------- .../ArithmeticExpressionSuite.scala | 9 +++++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7559852a2ac45..72d5889d2f202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val first = evalChildren(0) - val rest = evalChildren.drop(1) + ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value, "") def updateEval(eval: ExprCode): String = { s""" ${eval.code} @@ -614,11 +614,11 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } + val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) ev.copy(code = s""" - ${first.code} - boolean ${ev.isNull} = ${first.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")}""") + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + $codes""") } } @@ -668,8 +668,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val first = evalChildren(0) - val rest = evalChildren.drop(1) + ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value, "") def updateEval(eval: ExprCode): String = { s""" ${eval.code} @@ -680,10 +680,10 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } + val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) ev.copy(code = s""" - ${first.code} - boolean ${ev.isNull} = ${first.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")}""") + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + $codes""") } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 031053727f08e..fb759eba6a9e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -334,4 +334,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } + + test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { + val N = 3000 + val strings = (1 to N).map(x => "s" * x) + val inputsExpr = strings.map(Literal.create(_, StringType)) + + checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow) + checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow) + } } From 4e7f07e2550fa995cc37406173a937033135cf3b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 16 Nov 2017 18:19:13 +0100 Subject: [PATCH 258/712] [SPARK-22494][SQL] Fix 64KB limit exception with Coalesce and AtleastNNonNulls ## What changes were proposed in this pull request? Both `Coalesce` and `AtLeastNNonNulls` can cause the 64KB limit exception when used with a lot of arguments and/or complex expressions. This PR splits their expressions in order to avoid the issue. ## How was this patch tested? Added UTs Author: Marco Gaido Author: Marco Gaido Closes #19720 from mgaido91/SPARK-22494. --- .../expressions/nullExpressions.scala | 42 ++++++++++++++----- .../expressions/NullExpressionsSuite.scala | 10 +++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 62786e13bda2c..4aeab2c3ad0a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,14 +72,10 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val first = children(0) - val rest = children.drop(1) - val firstEval = first.genCode(ctx) - ev.copy(code = s""" - ${firstEval.code} - boolean ${ev.isNull} = ${firstEval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value};""" + - rest.map { e => + ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + + val evals = children.map { e => val eval = e.genCode(ctx) s""" if (${ev.isNull}) { @@ -90,7 +86,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } """ - }.mkString("\n")) + } + + ev.copy(code = s""" + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + ${ctx.splitExpressions(ctx.INPUT_ROW, evals)}""") } } @@ -357,7 +358,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val nonnull = ctx.freshName("nonnull") - val code = children.map { e => + val evals = children.map { e => val eval = e.genCode(ctx) e.dataType match { case DoubleType | FloatType => @@ -379,7 +380,26 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } """ } - }.mkString("\n") + } + + 
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + evals.mkString("\n") + } else { + ctx.splitExpressions(evals, "atLeastNNonNulls", + ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, + returnType = "int", + makeSplitFunction = { body => + s""" + $body + return $nonnull; + """ + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n") + } + ) + } + ev.copy(code = s""" int $nonnull = 0; $code diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 394c0a091e390..40ef7770da33f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -149,4 +149,14 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } + + test("Coalesce should not throw 64kb exception") { + val inputs = (1 to 2500).map(x => Literal(s"x_$x")) + checkEvaluation(Coalesce(inputs), "x_1") + } + + test("AtLeastNNonNulls should not throw 64kb exception") { + val inputs = (1 to 4000).map(x => Literal(s"x_$x")) + checkEvaluation(AtLeastNNonNulls(1, inputs), true) + } } From 7f2e62ee6b9d1f32772a18d626fb9fd907aa7733 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 16 Nov 2017 18:24:49 +0100 Subject: [PATCH 259/712] [SPARK-22501][SQL] Fix 64KB JVM bytecode limit problem with in ## What changes were proposed in this pull request? This PR changes `In` code generation to place generated code for expression for expressions for arguments into separated methods if these size could be large. ## How was this patch tested? Added new test cases into `PredicateSuite` Author: Kazuaki Ishizaki Closes #19733 from kiszk/SPARK-22501. 
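For illustration only (not part of the patch), here is a minimal sketch of the query shape this change targets: an `IN` predicate with thousands of values. The object name, local master, and value count are assumptions; the count mirrors the `N = 3000` used in the new `PredicateSuite` test, and the conversion-threshold setting is only there to keep the predicate on the `In` codegen path rather than `InSet`.

```
import org.apache.spark.sql.SparkSession

object LargeInListSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("large-in-list").getOrCreate()
    import spark.implicits._

    // Keep a long literal list as an In expression instead of letting the
    // optimizer rewrite it to InSet, so the In codegen path is exercised.
    spark.conf.set("spark.sql.optimizer.inSetConversionThreshold", "10000")

    val df = Seq(1.0, 2.0, 3.0).toDF("v")
    // 3000 values, mirroring N = 3000 in the new test. Before this change,
    // codegen emitted every comparison into a single method, which could blow
    // past the 64KB JVM bytecode limit for a list this wide.
    val inList = (1 to 3000).map(_.toDouble)
    println(df.filter($"v".isin(inList: _*)).count())

    spark.stop()
  }
}
```

With the fix, the per-value comparisons are emitted through `ctx.splitExpressions` into separate helper methods that take the row and the evaluated left-hand value as arguments, so each generated method stays well under the 64KB limit.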
--- .../sql/catalyst/expressions/predicates.scala | 20 ++++++++++++++----- .../catalyst/expressions/PredicateSuite.scala | 6 ++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 61df5e053a374..5d75c6004bfe4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -236,24 +236,34 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) + ctx.addMutableState("boolean", ev.value, "") + ctx.addMutableState("boolean", ev.isNull, "") + val valueArg = ctx.freshName("valueArg") val listCode = listGen.map(x => s""" if (!${ev.value}) { ${x.code} if (${x.isNull}) { ${ev.isNull} = true; - } else if (${ctx.genEqual(value.dataType, valueGen.value, x.value)}) { + } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { ${ev.isNull} = false; ${ev.value} = true; } } - """).mkString("\n") + """) + val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil + ctx.splitExpressions(listCode, "valueIn", args) + } else { + listCode.mkString("\n") + } ev.copy(code = s""" ${valueGen.code} - boolean ${ev.value} = false; - boolean ${ev.isNull} = ${valueGen.isNull}; + ${ev.value} = false; + ${ev.isNull} = ${valueGen.isNull}; if (!${ev.isNull}) { - $listCode + ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value}; + $listCodes } """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1438a88c19e0b..865092a659f26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -239,6 +239,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-22501: In should not generate codes beyond 64KB") { + val N = 3000 + val sets = (1 to N).map(i => Literal(i.toDouble)) + checkEvaluation(In(Literal(1.0D), sets), true) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null From b9dcbe5e1ba81135a51b486240662674728dda84 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 16 Nov 2017 18:23:00 -0800 Subject: [PATCH 260/712] [SPARK-22542][SQL] remove unused features in ColumnarBatch ## What changes were proposed in this pull request? `ColumnarBatch` provides features to do fast filter and project in a columnar fashion, however this feature is never used by Spark, as Spark uses whole stage codegen and processes the data in a row fashion. This PR proposes to remove these unused features as we won't switch to columnar execution in the near future. Even we do, I think this part needs a proper redesign. This is also a step to make `ColumnVector` public, as we don't wanna expose these features to users. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19766 from cloud-fan/vector. 
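A small illustration of the rationale above (object name and local setup are assumptions, not from the patch): ordinary filters and projections already run as row-at-a-time generated code under whole-stage codegen, which is why the columnar `markFiltered`/`filterNullsInColumn` helpers being removed here had no callers in normal query execution.

```
import org.apache.spark.sql.SparkSession

object WholeStageFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("wscg-filter").getOrCreate()
    import spark.implicits._

    // A plain filter + project is compiled by whole-stage codegen into row-at-a-time
    // Java code (the operators appear under a WholeStageCodegen node in the plan),
    // rather than being filtered column-by-column inside ColumnarBatch.
    val df = spark.range(0, 1000).filter($"id" % 2 === 0).select($"id" * 2)
    df.explain()
    println(df.count())

    spark.stop()
  }
}
```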
--- .../sql/catalyst/expressions/UnsafeRow.java | 4 - .../execution/vectorized/ColumnarArray.java | 2 +- .../execution/vectorized/ColumnarBatch.java | 78 +------------------ .../sql/execution/vectorized/ColumnarRow.java | 26 ------- .../arrow/ArrowConvertersSuite.scala | 2 +- .../parquet/ParquetReadBenchmark.scala | 24 ------ .../vectorized/ColumnarBatchSuite.scala | 52 ------------- 7 files changed, 6 insertions(+), 182 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ec947d7580282..71c086029cc5b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -69,10 +69,6 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields + 63)/ 64) * 8; } - public static int calculateFixedPortionByteSize(int numFields) { - return 8 * numFields + calculateBitSetWidthInBytes(numFields); - } - /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java index 5e88ce0321084..34bde3e14d378 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java @@ -41,7 +41,7 @@ public final class ColumnarArray extends ArrayData { // Reused staging buffer, used for loading from offheap. protected byte[] tmpByteArray = new byte[1]; - protected ColumnarArray(ColumnVector data) { + ColumnarArray(ColumnVector data) { this.data = data; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 8849a20d6ceb5..2f5fb360b226f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -42,15 +42,6 @@ public final class ColumnarBatch { private int numRows; final ColumnVector[] columns; - // True if the row is filtered. - private final boolean[] filteredRows; - - // Column indices that cannot have null values. - private final Set nullFilteredColumns; - - // Total number of rows that have been filtered. - private int numRowsFiltered = 0; - // Staging row returned from getRow. final ColumnarRow row; @@ -68,24 +59,18 @@ public void close() { * Returns an iterator over the rows in this batch. This skips rows that are filtered out. 
*/ public Iterator rowIterator() { - final int maxRows = ColumnarBatch.this.numRows(); - final ColumnarRow row = new ColumnarRow(this); + final int maxRows = numRows; + final ColumnarRow row = new ColumnarRow(columns); return new Iterator() { int rowId = 0; @Override public boolean hasNext() { - while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { - ++rowId; - } return rowId < maxRows; } @Override public ColumnarRow next() { - while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { - ++rowId; - } if (rowId >= maxRows) { throw new NoSuchElementException(); } @@ -109,31 +94,15 @@ public void reset() { ((WritableColumnVector) columns[i]).reset(); } } - if (this.numRowsFiltered > 0) { - Arrays.fill(filteredRows, false); - } this.numRows = 0; - this.numRowsFiltered = 0; } /** - * Sets the number of rows that are valid. Additionally, marks all rows as "filtered" if one or - * more of their attributes are part of a non-nullable column. + * Sets the number of rows that are valid. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); this.numRows = numRows; - - for (int ordinal : nullFilteredColumns) { - if (columns[ordinal].numNulls() != 0) { - for (int rowId = 0; rowId < numRows; rowId++) { - if (!filteredRows[rowId] && columns[ordinal].isNullAt(rowId)) { - filteredRows[rowId] = true; - ++numRowsFiltered; - } - } - } - } } /** @@ -146,14 +115,6 @@ public void setNumRows(int numRows) { */ public int numRows() { return numRows; } - /** - * Returns the number of valid rows. - */ - public int numValidRows() { - assert(numRowsFiltered <= numRows); - return numRows - numRowsFiltered; - } - /** * Returns the schema that makes up this batch. */ @@ -169,17 +130,6 @@ public int numValidRows() { */ public ColumnVector column(int ordinal) { return columns[ordinal]; } - /** - * Sets (replaces) the column at `ordinal` with column. This can be used to do very efficient - * projections. - */ - public void setColumn(int ordinal, ColumnVector column) { - if (column instanceof OffHeapColumnVector) { - throw new UnsupportedOperationException("Need to ref count columns."); - } - columns[ordinal] = column; - } - /** * Returns the row in this batch at `rowId`. Returned row is reused across calls. */ @@ -190,30 +140,10 @@ public ColumnarRow getRow(int rowId) { return row; } - /** - * Marks this row as being filtered out. This means a subsequent iteration over the rows - * in this batch will not include this row. - */ - public void markFiltered(int rowId) { - assert(!filteredRows[rowId]); - filteredRows[rowId] = true; - ++numRowsFiltered; - } - - /** - * Marks a given column as non-nullable. Any row that has a NULL value for the corresponding - * attribute is filtered out. 
- */ - public void filterNullsInColumn(int ordinal) { - nullFilteredColumns.add(ordinal); - } - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.schema = schema; this.columns = columns; this.capacity = capacity; - this.nullFilteredColumns = new HashSet<>(); - this.filteredRows = new boolean[capacity]; - this.row = new ColumnarRow(this); + this.row = new ColumnarRow(columns); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index c75adafd69461..98a907322713b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -32,28 +31,11 @@ */ public final class ColumnarRow extends InternalRow { protected int rowId; - private final ColumnarBatch parent; - private final int fixedLenRowSize; private final ColumnVector[] columns; private final WritableColumnVector[] writableColumns; - // Ctor used if this is a top level row. - ColumnarRow(ColumnarBatch parent) { - this.parent = parent; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); - this.columns = parent.columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } - } - // Ctor used if this is a struct. ColumnarRow(ColumnVector[] columns) { - this.parent = null; - this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); this.columns = columns; this.writableColumns = new WritableColumnVector[this.columns.length]; for (int i = 0; i < this.columns.length; i++) { @@ -63,14 +45,6 @@ public final class ColumnarRow extends InternalRow { } } - /** - * Marks this row as being filtered out. This means a subsequent iteration over the rows - * in this batch will not include this row. 
- */ - public void markFiltered() { - parent.markFiltered(rowId); - } - public ColumnVector[] columns() { return columns; } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index ba2903babbba8..57958f7239224 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1731,7 +1731,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) - assert(schema.equals(outputRowIter.schema)) + assert(schema == outputRowIter.schema) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 0917f188b9799..de7a5795b4796 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -295,48 +295,24 @@ object ParquetReadBenchmark { } } - benchmark.addCase("PR Vectorized (Null Filtering)") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - batch.filterNullsInColumn(0) - batch.filterNullsInColumn(1) - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - sum += rowIterator.next().getUTF8String(0).numBytes() - } - } - } finally { - reader.close() - } - } - } - /* Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X PR Vectorized 833 / 846 12.6 79.4 1.5X - PR Vectorized (Null Filtering) 732 / 782 14.3 69.8 1.7X Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X PR Vectorized 732 / 772 14.3 69.8 1.4X - PR Vectorized (Null Filtering) 725 / 790 14.5 69.1 1.4X Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X PR Vectorized 190 / 200 55.1 18.2 1.7X - PR Vectorized (Null Filtering) 168 / 172 62.2 16.1 1.9X */ benchmark.run() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 4cfc776e51db1..4a6c8f5521d18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -919,7 +919,6 @@ class ColumnarBatchSuite extends SparkFunSuite { val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) assert(batch.numCols() == 4) assert(batch.numRows() == 0) - assert(batch.numValidRows() == 0) assert(batch.capacity() > 0) assert(batch.rowIterator().hasNext == false) @@ -933,7 +932,6 @@ class ColumnarBatchSuite extends SparkFunSuite { // Verify the results of the row. assert(batch.numCols() == 4) assert(batch.numRows() == 1) - assert(batch.numValidRows() == 1) assert(batch.rowIterator().hasNext == true) assert(batch.rowIterator().hasNext == true) @@ -957,16 +955,9 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) assert(it.hasNext == false) - // Filter out the row. - row.markFiltered() - assert(batch.numRows() == 1) - assert(batch.numValidRows() == 0) - assert(batch.rowIterator().hasNext == false) - // Reset and add 3 rows batch.reset() assert(batch.numRows() == 0) - assert(batch.numValidRows() == 0) assert(batch.rowIterator().hasNext == false) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] @@ -1002,26 +993,12 @@ class ColumnarBatchSuite extends SparkFunSuite { // Verify assert(batch.numRows() == 3) - assert(batch.numValidRows() == 3) val it2 = batch.rowIterator() rowEquals(it2.next(), Row(null, 2.2, 2, "abc")) rowEquals(it2.next(), Row(3, null, 3, "")) rowEquals(it2.next(), Row(4, 4.4, 4, "world")) assert(!it.hasNext) - // Filter out some rows and verify - batch.markFiltered(1) - assert(batch.numValidRows() == 2) - val it3 = batch.rowIterator() - rowEquals(it3.next(), Row(null, 2.2, 2, "abc")) - rowEquals(it3.next(), Row(4, 4.4, 4, "world")) - assert(!it.hasNext) - - batch.markFiltered(2) - assert(batch.numValidRows() == 1) - val it4 = batch.rowIterator() - rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) - batch.close() }} } @@ -1176,35 +1153,6 @@ class ColumnarBatchSuite extends SparkFunSuite { testRandomRows(false, 30) } - test("null filtered columns") { - val NUM_ROWS = 10 - val schema = new StructType() - .add("key", IntegerType, nullable = false) - .add("value", StringType, nullable = true) - for (numNulls <- List(0, NUM_ROWS / 2, NUM_ROWS)) { - val rows = mutable.ArrayBuffer.empty[Row] - for (i <- 0 until NUM_ROWS) { - val row = if (i < numNulls) Row.fromSeq(Seq(i, null)) else Row.fromSeq(Seq(i, i.toString)) - rows += row - } - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) - batch.filterNullsInColumn(1) - batch.setNumRows(NUM_ROWS) - assert(batch.numRows() == NUM_ROWS) - val it = batch.rowIterator() - // Top numNulls rows should be filtered - var k = numNulls - while (it.hasNext) { - assert(it.next().getInt(0) == k) - k += 1 - } - assert(k == NUM_ROWS) - batch.close() - }} - } - } - test("mutable ColumnarBatch rows") { val NUM_ITERS = 10 val types = Array( From d00b55d4b25ba0bf92983ff1bb47d8528e943737 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 17 Nov 2017 07:53:53 -0600 Subject: [PATCH 261/712] [SPARK-22540][SQL] Ensure HighlyCompressedMapStatus calculates correct avgSize ## What changes were proposed in this pull request? Ensure HighlyCompressedMapStatus calculates correct avgSize ## How was this patch tested? New unit test added. Author: yucai Closes #19765 from yucai/avgsize. 
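To make the arithmetic behind the fix concrete, a standalone sketch follows (illustrative only; the object name is an assumption and the size range mirrors the new unit test). The old code summed only the blocks below the threshold but divided by the count of all non-empty blocks, so huge blocks dragged the reported average well below the true small-block mean.

```
object AvgSizeSketch {
  def main(args: Array[String]): Unit = {
    val threshold = 1000L
    val sizes = (0L to 3000L).toArray                // mirrors the new MapStatusSuite test
    val nonEmpty = sizes.filter(_ > 0)
    val small = nonEmpty.filter(_ < threshold)

    val oldAvg = small.sum / nonEmpty.length         // pre-fix: small-block sum / all non-empty blocks
    val newAvg = small.sum / small.length            // post-fix: small-block sum / small blocks only

    println(s"old avg = $oldAvg, new avg = $newAvg") // old avg underestimates the typical small block
  }
}
```

With the patch, both the numerator and the denominator cover only blocks below `SHUFFLE_ACCURATE_BLOCK_THRESHOLD`, so `avgSize` reflects the typical small block while huge blocks keep their exact sizes in `hugeBlockSizes`.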
--- .../apache/spark/scheduler/MapStatus.scala | 10 +++++---- .../spark/scheduler/MapStatusSuite.scala | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 5e45b375ddd45..2ec2f2031aa45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -197,7 +197,8 @@ private[spark] object HighlyCompressedMapStatus { // block as being non-empty (or vice-versa) when using the average block size. var i = 0 var numNonEmptyBlocks: Int = 0 - var totalSize: Long = 0 + var numSmallBlocks: Int = 0 + var totalSmallBlockSize: Long = 0 // From a compression standpoint, it shouldn't matter whether we track empty or non-empty // blocks. From a performance standpoint, we benefit from tracking empty blocks because // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. @@ -214,7 +215,8 @@ private[spark] object HighlyCompressedMapStatus { // Huge blocks are not included in the calculation for average size, thus size for smaller // blocks is more accurate. if (size < threshold) { - totalSize += size + totalSmallBlockSize += size + numSmallBlocks += 1 } else { hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) } @@ -223,8 +225,8 @@ private[spark] object HighlyCompressedMapStatus { } i += 1 } - val avgSize = if (numNonEmptyBlocks > 0) { - totalSize / numNonEmptyBlocks + val avgSize = if (numSmallBlocks > 0) { + totalSmallBlockSize / numSmallBlocks } else { 0 } diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 144e5afdcdd78..2155a0f2b6c21 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -98,6 +98,28 @@ class MapStatusSuite extends SparkFunSuite { } } + test("SPARK-22540: ensure HighlyCompressedMapStatus calculates correct avgSize") { + val threshold = 1000 + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, threshold.toString) + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = (0L to 3000L).toArray + val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) + val avg = smallBlockSizes.sum / smallBlockSizes.length + val loc = BlockManagerId("a", "b", 10) + val status = MapStatus(loc, sizes) + val status1 = compressAndDecompressMapStatus(status) + assert(status1.isInstanceOf[HighlyCompressedMapStatus]) + assert(status1.location == loc) + for (i <- 0 until threshold) { + val estimate = status1.getSizeForBlock(i) + if (sizes(i) > 0) { + assert(estimate === avg) + } + } + } + def compressAndDecompressMapStatus(status: MapStatus): MapStatus = { val ser = new JavaSerializer(new SparkConf) val buf = ser.newInstance().serialize(status) From 7d039e0c0af0931a1696d89e76f455ae4adf277d Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 17 Nov 2017 16:43:08 +0100 Subject: [PATCH 262/712] [SPARK-22409] Introduce function type argument in pandas_udf ## What changes were proposed in this pull request? * Add a "function type" argument to pandas_udf. 
* Add a new public enum class `PandasUdfType` in pyspark.sql.functions * Refactor udf related code from pyspark.sql.functions to pyspark.sql.udf * Merge "PythonUdfType" and "PythonEvalType" into a single enum class "PythonEvalType" Example: ``` from pyspark.sql.functions import pandas_udf, PandasUDFType pandas_udf('double', PandasUDFType.SCALAR): def plus_one(v): return v + 1 ``` ## Design doc https://docs.google.com/document/d/1KlLaa-xJ3oz28xlEJqXyCAHU3dwFYkFs_ixcUXrJNTc/edit ## How was this patch tested? Added PandasUDFTests ## TODO: * [x] Implement proper enum type for `PandasUDFType` * [x] Update documentation * [x] Add more tests in PandasUDFTests Author: Li Jin Closes #19630 from icexelloss/spark-22409-pandas-udf-type. --- .../spark/api/python/PythonRunner.scala | 8 +- python/pyspark/rdd.py | 16 ++ python/pyspark/serializers.py | 7 - python/pyspark/sql/catalog.py | 7 +- python/pyspark/sql/functions.py | 222 ++++++------------ python/pyspark/sql/group.py | 49 +--- python/pyspark/sql/tests.py | 221 ++++++++++++----- python/pyspark/sql/udf.py | 161 +++++++++++++ python/pyspark/worker.py | 39 ++- .../spark/sql/RelationalGroupedDataset.scala | 9 +- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ExtractPythonUDFs.scala | 12 +- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../sql/execution/python/PythonUDF.scala | 2 +- .../python/UserDefinedPythonFunction.scala | 13 +- .../python/BatchEvalPythonExecSuite.scala | 4 +- 16 files changed, 490 insertions(+), 284 deletions(-) create mode 100644 python/pyspark/sql/udf.py diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 9989f68f8508c..f524de68fbce0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -34,9 +34,11 @@ import org.apache.spark.util._ */ private[spark] object PythonEvalType { val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 - val SQL_PANDAS_GROUPED_UDF = 3 + + val SQL_BATCHED_UDF = 100 + + val SQL_PANDAS_SCALAR_UDF = 200 + val SQL_PANDAS_GROUP_MAP_UDF = 201 } /** diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ea993c572fafd..340bc3a6b7470 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -56,6 +56,22 @@ __all__ = ["RDD"] +class PythonEvalType(object): + """ + Evaluation type of python rdd. + + These values are internal to PySpark. + + These values should match values in org.apache.spark.api.python.PythonEvalType. 
+ """ + NON_UDF = 0 + + SQL_BATCHED_UDF = 100 + + SQL_PANDAS_SCALAR_UDF = 200 + SQL_PANDAS_GROUP_MAP_UDF = 201 + + def portable_hash(x): """ This function returns consistent hash code for builtin types, especially diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e0afdafbfcd62..b95de2c804394 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -82,13 +82,6 @@ class SpecialLengths(object): START_ARROW_STREAM = -6 -class PythonEvalType(object): - NON_UDF = 0 - SQL_BATCHED_UDF = 1 - SQL_PANDAS_UDF = 2 - SQL_PANDAS_GROUPED_UDF = 3 - - class Serializer(object): def dump_stream(self, iterator, stream): diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 5f25dce161963..659bc65701a0c 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -19,9 +19,9 @@ from collections import namedtuple from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import IntegerType, StringType, StructType @@ -256,7 +256,8 @@ def registerFunction(self, name, f, returnType=StringType()): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - udf = UserDefinedFunction(f, returnType, name) + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 087ce7caa89c8..b631e2041706f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,11 +27,12 @@ from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import StringType, DataType +from pyspark.sql.udf import UserDefinedFunction, _create_udf def _create_function(name, doc=""): @@ -2062,132 +2063,12 @@ def map_values(col): # ---------------------------- User Defined Function ---------------------------------- -def _wrap_function(sc, func, returnType): - command = (func, returnType) - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) - - -class PythonUdfType(object): - # row-at-a-time UDFs - NORMAL_UDF = 0 - # scalar vectorized UDFs - PANDAS_UDF = 1 - # grouped vectorized UDFs - PANDAS_GROUPED_UDF = 2 - - -class UserDefinedFunction(object): - """ - User defined function in Python - - .. 
versionadded:: 1.3 - """ - def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): - if not callable(func): - raise TypeError( - "Not a function or callable (__call__ is not defined): " - "{0}".format(type(func))) - - self.func = func - self._returnType = returnType - # Stores UserDefinedPythonFunctions jobj, once initialized - self._returnType_placeholder = None - self._judf_placeholder = None - self._name = name or ( - func.__name__ if hasattr(func, '__name__') - else func.__class__.__name__) - self.pythonUdfType = pythonUdfType - - @property - def returnType(self): - # This makes sure this is called after SparkContext is initialized. - # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. - if self._returnType_placeholder is None: - if isinstance(self._returnType, DataType): - self._returnType_placeholder = self._returnType - else: - self._returnType_placeholder = _parse_datatype_string(self._returnType) - return self._returnType_placeholder - - @property - def _judf(self): - # It is possible that concurrent access, to newly created UDF, - # will initialize multiple UserDefinedPythonFunctions. - # This is unlikely, doesn't affect correctness, - # and should have a minimal performance impact. - if self._judf_placeholder is None: - self._judf_placeholder = self._create_judf() - return self._judf_placeholder - - def _create_judf(self): - from pyspark.sql import SparkSession - - spark = SparkSession.builder.getOrCreate() - sc = spark.sparkContext - - wrapped_func = _wrap_function(sc, self.func, self.returnType) - jdt = spark._jsparkSession.parseDataType(self.returnType.json()) - judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.pythonUdfType) - return judf - - def __call__(self, *cols): - judf = self._judf - sc = SparkContext._active_spark_context - return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) - - def _wrapped(self): - """ - Wrap this udf with a function and attach docstring from func - """ - - # It is possible for a callable instance without __name__ attribute or/and - # __module__ attribute to be wrapped here. For example, functools.partial. In this case, - # we should avoid wrapping the attributes from the wrapped function to the wrapper - # function. So, we take out these attribute names from the default names to set and - # then manually assign it after being wrapped. - assignments = tuple( - a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') - - @functools.wraps(self.func, assigned=assignments) - def wrapper(*args): - return self(*args) - - wrapper.__name__ = self._name - wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') - else self.func.__class__.__module__) - - wrapper.func = self.func - wrapper.returnType = self.returnType - wrapper.pythonUdfType = self.pythonUdfType - - return wrapper - - -def _create_udf(f, returnType, pythonUdfType): - - def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): - if pythonUdfType == PythonUdfType.PANDAS_UDF: - import inspect - argspec = inspect.getargspec(f) - if len(argspec.args) == 0 and argspec.varargs is None: - raise ValueError( - "0-arg pandas_udfs are not supported. " - "Instead, create a 1-arg pandas_udf and ignore the arg in your function." 
- ) - udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) - return udf_obj._wrapped() - - # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf - if f is None or isinstance(f, (str, DataType)): - # If DataType has been passed as a positional argument - # for decorator use it as a returnType - return_type = f or returnType - return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) - else: - return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) +class PandasUDFType(object): + """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. + """ + SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF @since(1.3) @@ -2228,33 +2109,47 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) + # decorator @udf, @udf(), @udf(dataType()) + if f is None or isinstance(f, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + return functools.partial(_create_udf, returnType=return_type, + evalType=PythonEvalType.SQL_BATCHED_UDF) + else: + return _create_udf(f=f, returnType=returnType, + evalType=PythonEvalType.SQL_BATCHED_UDF) @since(2.3) -def pandas_udf(f=None, returnType=StringType()): +def pandas_udf(f=None, returnType=None, functionType=None): """ Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object + :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. + Default: SCALAR. - The user-defined function can define one of the following transformations: + The function type of the UDF can be one of the following: - 1. One or more `pandas.Series` -> A `pandas.Series` + 1. SCALAR - This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. + A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. The returnType should be a primitive data type, e.g., `DoubleType()`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql.types import IntegerType, StringType >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) + >>> @pandas_udf(StringType()) ... def to_upper(s): ... return s.str.upper() ... - >>> @pandas_udf(returnType="integer") + >>> @pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... @@ -2267,20 +2162,24 @@ def pandas_udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ - 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + 2. GROUP_MAP - This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. + The length of the returned `pandas.DataFrame` can be arbitrary. + + Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. 
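Pulling the two function types together, a minimal usage sketch in the non-decorator form (a sketch only; it assumes an active SparkSession bound to `spark` and that pandas and pyarrow are available on the workers, none of which is shown in this diff):

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    # SCALAR: one or more pandas.Series in, a pandas.Series of the same length out.
    plus_one = pandas_udf(lambda s: s + 1, "long", PandasUDFType.SCALAR)

    # GROUP_MAP: a pandas.DataFrame per group in, a pandas.DataFrame out.
    subtract_mean = pandas_udf(
        lambda pdf: pdf.assign(v=pdf.v - pdf.v.mean()),
        "id long, v double",
        PandasUDFType.GROUP_MAP)

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))
    df.select(plus_one(df.id)).show()
    df.groupby("id").apply(subtract_mean).show()
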
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -2291,10 +2190,6 @@ def pandas_udf(f=None, returnType=StringType()): | 2| 1.1094003924504583| +---+-------------------+ - .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` - because it defines a `DataFrame` transformation rather than a `Column` - transformation. - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. @@ -2306,7 +2201,44 @@ def pandas_udf(f=None, returnType=StringType()): rows that do not satisfy the conditions, the suggested workaround is to incorporate the condition logic into the functions. """ - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) + # decorator @pandas_udf(returnType, functionType) + is_decorator = f is None or isinstance(f, (str, DataType)) + + if is_decorator: + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + + if functionType is not None: + # @pandas_udf(dataType, functionType=functionType) + # @pandas_udf(returnType=dataType, functionType=functionType) + eval_type = functionType + elif returnType is not None and isinstance(returnType, int): + # @pandas_udf(dataType, functionType) + eval_type = returnType + else: + # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + else: + return_type = returnType + + if functionType is not None: + eval_type = functionType + else: + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + if return_type is None: + raise ValueError("Invalid returnType: returnType can not be None") + + if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + raise ValueError("Invalid functionType: " + "functionType must be one the values from PandasUDFType") + + if is_decorator: + return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) + else: + return _create_udf(f=f, returnType=return_type, evalType=eval_type) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index e11388d604312..4d47dd6a3e878 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -16,10 +16,10 @@ # from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import PythonUdfType, UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -214,15 +214,15 @@ def apply(self, udf): :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` - >>> from pyspark.sql.functions import pandas_udf + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = 
spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -236,44 +236,13 @@ def apply(self, udf): .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - import inspect - # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ - or len(inspect.getargspec(udf.func).args) != 1: - raise ValueError("The argument to apply must be a 1-arg pandas_udf") - if not isinstance(udf.returnType, StructType): - raise ValueError("The returnType of the pandas_udf must be a StructType") - + or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " + "GROUP_MAP.") df = self._df - func = udf.func - returnType = udf.returnType - - # The python executors expects the function to use pd.Series as input and output - # So we to create a wrapper function that turns that to a pd.DataFrame before passing - # down to the user function, then turn the result pd.DataFrame back into pd.Series - columns = df.columns - - def wrapped(*cols): - from pyspark.sql.types import to_arrow_type - import pandas as pd - result = func(pd.concat(cols, axis=1, keys=columns)) - if not isinstance(result, pd.DataFrame): - raise TypeError("Return type of the user-defined function should be " - "Pandas.DataFrame, but is {}".format(type(result))) - if not len(result.columns) == len(returnType): - raise RuntimeError( - "Number of columns of the returned Pandas.DataFrame " - "doesn't match specified schema. 
" - "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] - - udf_obj = UserDefinedFunction( - wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) - udf_column = udf_obj(*[df[col] for col in df.columns]) + udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ef592c2356a8c..762afe0d730f3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3259,6 +3259,129 @@ def test_schema_conversion_roundtrip(self): self.assertEquals(self.schema, schema_rt) +class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + udf = pandas_udf(lambda x: x, DoubleType()) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), + PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, returnType='v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_pandas_udf_decorator(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.types import StructType, StructField, DoubleType + + @pandas_udf(DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + schema = StructType([StructField("v", DoubleType())]) + + @pandas_udf(schema, PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf('v double', PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + 
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_udf_wrong_arg(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaises(ParseException): + @pandas_udf('blah') + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'): + @pandas_udf(functionType=PandasUDFType.SCALAR) + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid functionType'): + @pandas_udf('double', 100) + def foo(x): + return x + + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR) + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + @pandas_udf(LongType(), PandasUDFType.SCALAR) + def zero_with_type(): + return 1 + + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid function'): + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) + def foo(k, v): + return k + + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): @@ -3355,23 +3478,6 @@ def test_vectorized_udf_null_string(self): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) - def test_vectorized_udf_zero_parameter(self): - from pyspark.sql.functions import pandas_udf - error_str = '0-arg pandas_udfs.*not.*supported' - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, error_str): - pandas_udf(lambda: 1, LongType()) - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf - def zero_no_type(): - return 1 - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf(LongType()) - def zero_with_type(): - return 1 - def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3570,7 +3676,7 @@ def check_records_per_batch(x): return x result = df.select(check_records_per_batch(col("id"))) - self.assertEquals(df.collect(), result.collect()) + self.assertEqual(df.collect(), result.collect()) finally: if orig_value is None: self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") @@ -3595,7 +3701,7 @@ def data(self): .withColumn("v", explode(col('vs'))).drop('vs') def test_simple(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( @@ -3604,21 +3710,22 @@ def test_simple(self): [StructField('id', LongType()), StructField('v', 
IntegerType()), StructField('v1', DoubleType()), - StructField('v2', LongType())])) + StructField('v2', LongType())]), + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_decorator(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('v1', DoubleType()), - StructField('v2', LongType())])) + @pandas_udf( + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) @@ -3627,12 +3734,14 @@ def foo(pdf): self.assertFramesEqual(expected, result) def test_coerce(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + 'id long, v double', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) @@ -3640,13 +3749,13 @@ def test_coerce(self): self.assertFramesEqual(expected, result) def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3659,13 +3768,13 @@ def normalize(pdf): self.assertFramesEqual(expected, result) def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3678,57 +3787,63 @@ def normalize(pdf): self.assertFramesEqual(expected, result) def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - "id long, v int, v1 double, v2 long") + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', StringType())])) + 'id long, v string', + PandasUDFType.GROUP_MAP + ) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.groupby('id').apply(foo).sort('id').toPandas() def test_wrong_args(self): - from 
pyspark.sql.functions import udf, pandas_udf, sum + from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType df = self.data with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(lambda x: x) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(udf(lambda x: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(sum(df.v)) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(df.v + 1) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid function'): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'returnType'): - df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), + PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema) + f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py new file mode 100644 index 0000000000000..c3301a41ccd5a --- /dev/null +++ b/python/pyspark/sql/udf.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +User-defined function related classes and functions +""" +import functools + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string + + +def _wrap_function(sc, func, returnType): + command = (func, returnType) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + +def _create_udf(f, returnType, evalType): + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "Invalid function: 0-arg pandas_udfs are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) + + elif evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) != 1: + raise ValueError( + "Invalid function: pandas_udfs with function type GROUP_MAP " + "must take a single arg that is a pandas DataFrame." + ) + + # Set the name of the UserDefinedFunction object to be the name of function f + udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + return udf_obj._wrapped() + + +class UserDefinedFunction(object): + """ + User defined function in Python + + .. versionadded:: 1.3 + """ + def __init__(self, func, + returnType=StringType(), name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF): + if not callable(func): + raise TypeError( + "Invalid function: not a function or callable (__call__ is not defined): " + "{0}".format(type(func))) + + if not isinstance(returnType, (DataType, str)): + raise TypeError( + "Invalid returnType: returnType should be DataType or str " + "but is {}".format(returnType)) + + if not isinstance(evalType, int): + raise TypeError( + "Invalid evalType: evalType should be an int but is {}".format(evalType)) + + self.func = func + self._returnType = returnType + # Stores UserDefinedPythonFunctions jobj, once initialized + self._returnType_placeholder = None + self._judf_placeholder = None + self._name = name or ( + func.__name__ if hasattr(func, '__name__') + else func.__class__.__name__) + self.evalType = evalType + + @property + def returnType(self): + # This makes sure this is called after SparkContext is initialized. + # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._returnType_placeholder is None: + if isinstance(self._returnType, DataType): + self._returnType_placeholder = self._returnType + else: + self._returnType_placeholder = _parse_datatype_string(self._returnType) + + if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + and not isinstance(self._returnType_placeholder, StructType): + raise ValueError("Invalid returnType: returnType must be a StructType for " + "pandas_udf with function type GROUP_MAP") + + return self._returnType_placeholder + + @property + def _judf(self): + # It is possible that concurrent access, to newly created UDF, + # will initialize multiple UserDefinedPythonFunctions. + # This is unlikely, doesn't affect correctness, + # and should have a minimal performance impact. 
+ if self._judf_placeholder is None: + self._judf_placeholder = self._create_judf() + return self._judf_placeholder + + def _create_judf(self): + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + wrapped_func = _wrap_function(sc, self.func, self.returnType) + jdt = spark._jsparkSession.parseDataType(self.returnType.json()) + judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( + self._name, wrapped_func, jdt, self.evalType) + return judf + + def __call__(self, *cols): + judf = self._judf + sc = SparkContext._active_spark_context + return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. + assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') + + @functools.wraps(self.func, assigned=assignments) + def wrapper(*args): + return self(*args) + + wrapper.__name__ = self._name + wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') + else self.func.__class__.__module__) + + wrapper.func = self.func + wrapper.returnType = self.returnType + wrapper.evalType = self.evalType + + return wrapper diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5e100e0a9a95d..939643071943a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -29,8 +29,9 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles +from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type from pyspark import shuffle @@ -73,7 +74,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_udf(f, return_type): +def wrap_pandas_scalar_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -89,6 +90,26 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) +def wrap_pandas_group_map_udf(f, return_type): + def wrapped(*series): + import pandas as pd + + result = f(pd.concat(series, axis=1)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. 
" + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + return wrapped + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -99,12 +120,12 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = f else: row_func = chain(row_func, f) + # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_UDF: - return arg_offsets, wrap_pandas_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: - # a groupby apply udf has already been wrapped under apply() - return arg_offsets, row_func + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) else: return arg_offsets, wrap_udf(row_func, return_type) @@ -127,8 +148,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 3e4edd4ea8cf3..a009c00b0abc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -449,10 +450,10 @@ class RelationalGroupedDataset protected[sql]( * workers. 
*/ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, - "Must pass a grouped vectorized python udf") + require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + "Must pass a group map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the vectorized python udf must be a StructType") + "The returnType of the udf must be a StructType") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index bcda2dae92e53..e27210117a1e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e15e760136e81..f5a4cbc4793e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -148,15 +149,18 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { - throw new IllegalArgumentException("Can not use grouped vectorized UDFs") - } + require(validUdfs.forall(udf => + udf.evalType == PythonEvalType.SQL_BATCHED_UDF || + udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ), "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { + val evaluation = validUdfs.partition( + _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index e1e04e34e0c71..ee495814b8255 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -95,7 +95,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9c07c7638de57..ef27fbc2db7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - pythonUdfType: Int) + evalType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index b2fe6c300846a..348e49e473ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,15 +22,6 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType -private[spark] object PythonUdfType { - // row-at-a-time UDFs - val NORMAL_UDF = 0 - // scalar vectorized UDFs - val PANDAS_UDF = 1 - // grouped vectorized UDFs - val PANDAS_GROUPED_UDF = 2 -} - /** * A user-defined Python function. This is used by the Python API. */ @@ -38,10 +29,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - pythonUdfType: Int) { + pythonEvalType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, pythonUdfType) + PythonUDF(name, func, dataType, e, pythonEvalType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 95b21fc9f16ae..53d3f34567518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.api.python.PythonFunction +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - pythonUdfType = PythonUdfType.NORMAL_UDF) + pythonEvalType = PythonEvalType.SQL_BATCHED_UDF) From fccb337f9d1e44a83cfcc00ce33eae1fad367695 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 17 Nov 2017 17:43:40 +0100 Subject: [PATCH 263/712] [SPARK-22538][ML] SQLTransformer should not unpersist possibly cached input dataset ## What changes were proposed in this pull request? `SQLTransformer.transform` unpersists input dataset when dropping temporary view. We should not change input dataset's cache status. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19772 from viirya/SPARK-22538. --- .../org/apache/spark/ml/feature/SQLTransformer.scala | 3 ++- .../spark/ml/feature/SQLTransformerSuite.scala | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 62c1972aab12c..0fb1d8c5dc579 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -70,7 +70,8 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) dataset.createOrReplaceTempView(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) val result = dataset.sparkSession.sql(realStatement) - dataset.sparkSession.catalog.dropTempView(tableName) + // Call SessionCatalog.dropTempView to avoid unpersisting the possibly cached dataset. 
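The guarantee this change preserves can also be sketched from PySpark (a sketch under the assumption of an active SparkSession bound to `spark`): after the fix, `transform` leaves the input dataset's cache status untouched.

    from pyspark.ml.feature import SQLTransformer

    df = spark.range(10)
    df.cache()
    df.count()                                  # materialize the cache

    SQLTransformer(statement="SELECT id + 1 AS id1 FROM __THIS__").transform(df)

    assert df.storageLevel.useMemory            # input stays cached after transform()
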
+ dataset.sparkSession.sessionState.catalog.dropTempView(tableName) result } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 753f890c48301..673a146e619f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.storage.StorageLevel class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -60,4 +61,15 @@ class SQLTransformerSuite val expected = StructType(Seq(StructField("id1", LongType, nullable = false))) assert(outputSchema === expected) } + + test("SPARK-22538: SQLTransformer should not unpersist given dataset") { + val df = spark.range(10) + df.cache() + df.count() + assert(df.storageLevel != StorageLevel.NONE) + new SQLTransformer() + .setStatement("SELECT id + 1 AS id1 FROM __THIS__") + .transform(df) + assert(df.storageLevel != StorageLevel.NONE) + } } From bf0c0ae2dcc7fd1ce92cd0fb4809bb3d65b2e309 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 17 Nov 2017 15:35:24 -0800 Subject: [PATCH 264/712] [SPARK-22544][SS] FileStreamSource should use its own hadoop conf to call globPathIfNecessary ## What changes were proposed in this pull request? Pass the FileSystem created using the correct Hadoop conf into `globPathIfNecessary` so that it can pick up user's hadoop configurations, such as credentials. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #19771 from zsxwing/fix-file-stream-conf. --- .../spark/sql/execution/streaming/FileStreamSource.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index f17417343e289..0debd7db84757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -47,8 +47,9 @@ class FileStreamSource( private val hadoopConf = sparkSession.sessionState.newHadoopConf() + @transient private val fs = new Path(path).getFileSystem(hadoopConf) + private val qualifiedBasePath: Path = { - val fs = new Path(path).getFileSystem(hadoopConf) fs.makeQualified(new Path(path)) // can contains glob patterns } @@ -187,7 +188,7 @@ class FileStreamSource( if (SparkHadoopUtil.get.isGlobPath(new Path(path))) Some(false) else None private def allFilesUsingInMemoryFileIndex() = { - val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) + val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(fs, qualifiedBasePath) val fileIndex = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) fileIndex.allFiles() } From d54bfec2e07f2eb934185402f915558fe27b9312 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 18 Nov 2017 19:40:06 +0100 Subject: [PATCH 265/712] [SPARK-22498][SQL] Fix 64KB JVM bytecode limit problem with concat ## What changes were proposed in this pull request? 
This PR changes `concat` code generation to place generated code for expression for arguments into separated methods if these size could be large. This PR resolved the case of `concat` with a lot of argument ## How was this patch tested? Added new test cases into `StringExpressionsSuite` Author: Kazuaki Ishizaki Closes #19728 from kiszk/SPARK-22498. --- .../expressions/stringExpressions.scala | 30 +++++++++++++------ .../expressions/StringExpressionsSuite.scala | 6 ++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c341943187820..d5bb7e95769e5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -63,15 +63,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) - val inputs = evals.map { eval => - s"${eval.isNull} ? null : ${eval.value}" - }.mkString(", ") - ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - UTF8String ${ev.value} = UTF8String.concat($inputs); - if (${ev.value} == null) { - ${ev.isNull} = true; - } + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions(inputs, "valueConcat", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + } else { + inputs.mkString("\n") + } + ev.copy(s""" + UTF8String[] $args = new UTF8String[${evals.length}]; + $codes + UTF8String ${ev.value} = UTF8String.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; """) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 18ef4bc37c2b5..aa9d5a0aa95e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -45,6 +45,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("SPARK-22498: Concat should not generate codes beyond 64KB") { + val N = 5000 + val strs = (1 to N).map(x => s"s$x") + checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From b10837ab1a7bef04bf7a2773b9e44ed9206643fe Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 20 Nov 2017 13:32:01 +0900 Subject: [PATCH 266/712] [SPARK-22557][TEST] Use ThreadSignaler explicitly ## What changes were proposed in this pull request? ScalaTest 3.0 uses an implicit `Signaler`. This PR makes it sure all Spark tests uses `ThreadSignaler` explicitly which has the same default behavior of interrupting a thread on the JVM like ScalaTest 2.2.x. This will reduce potential flakiness. ## How was this patch tested? 
This is testsuite-only update. This should passes the Jenkins tests. Author: Dongjoon Hyun Closes #19784 from dongjoon-hyun/use_thread_signaler. --- .../test/scala/org/apache/spark/DistributedSuite.scala | 7 +++++-- core/src/test/scala/org/apache/spark/DriverSuite.scala | 5 ++++- .../test/scala/org/apache/spark/UnpersistSuite.scala | 8 ++++++-- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 9 ++++++++- .../org/apache/spark/rdd/AsyncRDDActionsSuite.scala | 5 ++++- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 5 ++++- .../OutputCommitCoordinatorIntegrationSuite.scala | 5 ++++- .../org/apache/spark/storage/BlockManagerSuite.scala | 10 ++++++++-- .../scala/org/apache/spark/util/EventLoopSuite.scala | 5 ++++- .../streaming/ProcessingTimeExecutorSuite.scala | 8 +++++--- .../org/apache/spark/sql/streaming/StreamTest.scala | 2 ++ .../apache/spark/sql/hive/SparkSubmitTestUtils.scala | 5 ++++- .../org/apache/spark/streaming/ReceiverSuite.scala | 5 +++-- .../apache/spark/streaming/StreamingContextSuite.scala | 5 +++-- .../spark/streaming/receiver/BlockGeneratorSuite.scala | 7 ++++--- 15 files changed, 68 insertions(+), 23 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index f8005610f7e4f..ea9f6d2fc20f4 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.scalatest.Matchers -import org.scalatest.concurrent.TimeLimits._ +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} import org.apache.spark.security.EncryptionFunSuite @@ -30,7 +30,10 @@ class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() { class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext - with EncryptionFunSuite { + with EncryptionFunSuite with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler val clusterUrl = "local-cluster[2,1,1024]" diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index be80d278fcea8..962945e5b6bb1 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io.File -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatest.time.SpanSugar._ @@ -27,6 +27,9 @@ import org.apache.spark.util.Utils class DriverSuite extends SparkFunSuite with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val masters = Table("master", "local", "local-cluster[2,1,1024]") diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index bc3f58cf2a35d..b58a3ebe6e4c9 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ 
-17,10 +17,14 @@ package org.apache.spark -import org.scalatest.concurrent.TimeLimits._ +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} -class UnpersistSuite extends SparkFunSuite with LocalSparkContext { +class UnpersistSuite extends SparkFunSuite with LocalSparkContext with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + test("unpersist RDD") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index cfbf56fb8c369..d0a34c5cdcf57 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path} import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ @@ -102,6 +102,9 @@ class SparkSubmitSuite import SparkSubmitSuite._ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + override def beforeEach() { super.beforeEach() System.setProperty("spark.testing", "true") @@ -1016,6 +1019,10 @@ class SparkSubmitSuite } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
def runSparkSubmit(args: Seq[String], root: String = ".."): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index de0e71a332f23..24b0144a38bd2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -24,7 +24,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ @@ -34,6 +34,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim @transient private var sc: SparkContext = _ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + override def beforeAll() { super.beforeAll() sc = new SparkContext("local[2]", "test") diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6222e576d1ce9..d395e09969453 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ @@ -102,6 +102,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi import DAGSchedulerSuite._ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. 
*/ val taskSets = scala.collection.mutable.Buffer[TaskSet]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index a27dadcf49bfc..d6ff5bb33055c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Seconds, Span} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext} @@ -34,6 +34,9 @@ class OutputCommitCoordinatorIntegrationSuite with LocalSparkContext with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + override def beforeAll(): Unit = { super.beforeAll() val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index d45c194d31adc..f3e8a2ed1d562 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,8 +31,8 @@ import org.apache.commons.lang3.RandomUtils import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.TimeLimits._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager @@ -57,10 +57,13 @@ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with LocalSparkContext with ResetSystemProperties - with EncryptionFunSuite { + with EncryptionFunSuite with TimeLimits { import BlockManagerSuite._ + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + var conf: SparkConf = null var store: BlockManager = null var store2: BlockManager = null @@ -1450,6 +1453,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private object BlockManagerSuite { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + private implicit class BlockManagerTestUtils(store: BlockManager) { def dropFromMemoryIfExists( diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index f4f8388f5f19f..550745771750c 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -23,13 +23,16 @@ import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.TimeLimits import 
org.apache.spark.SparkFunSuite class EventLoopSuite extends SparkFunSuite with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + test("EventLoop") { val buffer = new ConcurrentLinkedQueue[Int] val eventLoop = new EventLoop[Int]("test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 519e3c01afe8a..80c76915e4c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -22,16 +22,18 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import org.eclipse.jetty.util.ConcurrentHashSet -import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.{Eventually, Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.PatienceConfiguration.Timeout -import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.sql.streaming.util.StreamManualClock -class ProcessingTimeExecutorSuite extends SparkFunSuite { +class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { + + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler val timeout = 10.seconds diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 70b39b934071a..e68fca050571f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -69,7 +69,9 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} */ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with BeforeAndAfterAll { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + override def afterAll(): Unit = { super.afterAll() StateStore.stop() // stop the state store maintenance thread and unload store providers diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala index ede44df4afe11..68ed97d6d1f5a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala @@ -23,7 +23,7 @@ import java.util.Date import scala.collection.mutable.ArrayBuffer -import org.scalatest.concurrent.TimeLimits +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -33,6 +33,9 @@ import org.apache.spark.util.Utils trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val defaultSignaler: Signaler = ThreadSignaler + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
// This is copied from org.apache.spark.deploy.SparkSubmitSuite protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 5fc626c1f78b8..145c48e5a9a72 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -38,6 +38,9 @@ import org.apache.spark.util.Utils /** Testsuite for testing the network receiver behavior */ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val signaler: Signaler = ThreadSignaler + test("receiver life cycle") { val receiver = new FakeReceiver @@ -60,8 +63,6 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { // Verify that the receiver intercept[Exception] { - // Necessary to make failAfter interrupt awaitTermination() in ScalaTest 3.x - implicit val signaler: Signaler = ThreadSignaler failAfter(200 millis) { executingThread.join() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5810e73f4098b..52c8959351fe7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -44,6 +44,9 @@ import org.apache.spark.util.Utils class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeLimits with Logging { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x + implicit val signaler: Signaler = ThreadSignaler + val master = "local[2]" val appName = this.getClass.getSimpleName val batchDuration = Milliseconds(500) @@ -406,8 +409,6 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // test whether awaitTermination() does not exit if not time is given val exception = intercept[Exception] { - // Necessary to make failAfter interrupt awaitTermination() in ScalaTest 3.x - implicit val signaler: Signaler = ThreadSignaler failAfter(1000 millis) { ssc.awaitTermination() throw new Exception("Did not wait for stop") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index 898da4445e464..580f831548cd5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -24,18 +24,19 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ -import org.scalatest.concurrent.{Signaler, ThreadSignaler} +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.TimeLimits._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.util.ManualClock -class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { +class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter with 
TimeLimits { + // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val blockIntervalMs = 10 private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") @volatile private var blockGenerator: BlockGenerator = null From 57c5514de9dba1c14e296f85fb13fef23ce8c73f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 20 Nov 2017 13:34:06 +0900 Subject: [PATCH 267/712] [SPARK-22554][PYTHON] Add a config to control if PySpark should use daemon or not for workers ## What changes were proposed in this pull request? This PR proposes to add a flag to control if PySpark should use daemon or not. Actually, SparkR already has a flag for useDaemon: https://github.com/apache/spark/blob/478fbc866fbfdb4439788583281863ecea14e8af/core/src/main/scala/org/apache/spark/api/r/RRunner.scala#L362 It'd be great if we have this flag too. It makes easier to debug Windows specific issue. ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #19782 from HyukjinKwon/use-daemon-flag. --- .../org/apache/spark/api/python/PythonWorkerFactory.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index fc595ae9e4563..f53c6178047f5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -38,7 +38,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently // only works on UNIX-based systems now because it uses signals for child management, so we can // also fall back to launching workers (pyspark/worker.py) directly. - val useDaemon = !System.getProperty("os.name").startsWith("Windows") + val useDaemon = { + val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) + + // This flag is ignored on Windows as it's unable to fork. + !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled + } var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) From 3c3eebc8734e36e61f4627e2c517fbbe342b3b42 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 20 Nov 2017 12:40:16 +0100 Subject: [PATCH 268/712] [SPARK-20101][SQL] Use OffHeapColumnVector when "spark.sql.columnVector.offheap.enable" is set to "true" This PR enables to use ``OffHeapColumnVector`` when ``spark.sql.columnVector.offheap.enable`` is set to ``true``. While ``ColumnVector`` has two implementations ``OnHeapColumnVector`` and ``OffHeapColumnVector``, only ``OnHeapColumnVector`` is always used. This PR implements the followings - Pass ``OffHeapColumnVector`` to ``ColumnarBatch.allocate()`` when ``spark.sql.columnVector.offheap.enable`` is set to ``true`` - Free all of off-heap memory regions by ``OffHeapColumnVector.close()`` - Ensure to call ``OffHeapColumnVector.close()`` Use existing tests Author: Kazuaki Ishizaki Closes #17436 from kiszk/SPARK-20101. 
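As a usage sketch (hypothetical session setup and path, not part of the patch itself), the new flag is a regular SQL conf, so the off-heap vectors can be exercised like this:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local session; "spark.sql.columnVector.offheap.enable" is the internal
// flag added by this patch and defaults to false.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("offheap-columnvector-sketch")
  .config("spark.sql.columnVector.offheap.enable", "true")
  .getOrCreate()

// With the flag on, the vectorized Parquet reader and cached table scans allocate
// OffHeapColumnVector instead of OnHeapColumnVector for the ColumnarBatch, and the
// off-heap memory is released through task completion listeners.
spark.read.parquet("/tmp/example.parquet").count()
```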
--- .../scala/org/apache/spark/SparkConf.scala | 2 +- .../spark/internal/config/package.scala | 16 ++++++++++++++ .../apache/spark/memory/MemoryManager.scala | 7 +++--- .../memory/StaticMemoryManagerSuite.scala | 3 ++- .../memory/UnifiedMemoryManagerSuite.scala | 7 +++--- .../BlockManagerReplicationSuite.scala | 3 ++- .../spark/storage/BlockManagerSuite.scala | 2 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 3 ++- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++++++ .../VectorizedParquetRecordReader.java | 12 ++++++---- .../sql/execution/DataSourceScanExec.scala | 3 ++- .../columnar/InMemoryTableScanExec.scala | 17 ++++++++++++-- .../execution/datasources/FileFormat.scala | 4 +++- .../parquet/ParquetFileFormat.scala | 22 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 7 +++--- .../UnsafeFixedWidthAggregationMapSuite.scala | 3 ++- .../UnsafeKVExternalSorterSuite.scala | 6 ++--- .../benchmark/AggregateBenchmark.scala | 7 +++--- .../parquet/ParquetEncodingSuite.scala | 6 ++--- .../datasources/parquet/ParquetIOSuite.scala | 11 +++++----- .../parquet/ParquetReadBenchmark.scala | 8 ++++--- .../execution/joins/HashedRelationSuite.scala | 9 ++++---- 22 files changed, 117 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 57b3744e9c30a..ee726df7391f1 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -655,7 +655,7 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.streaming.minRememberDuration", "1.5")), "spark.yarn.max.executor.failures" -> Seq( AlternateConfig("spark.yarn.max.worker.failures", "1.5")), - "spark.memory.offHeap.enabled" -> Seq( + MEMORY_OFFHEAP_ENABLED.key -> Seq( AlternateConfig("spark.unsafe.offHeap", "1.6")), "spark.rpc.message.maxSize" -> Seq( AlternateConfig("spark.akka.frameSize", "1.6")), diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 84315f55a59ad..7be4d6b212d72 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -80,6 +80,22 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val MEMORY_OFFHEAP_ENABLED = ConfigBuilder("spark.memory.offHeap.enabled") + .doc("If true, Spark will attempt to use off-heap memory for certain operations. " + + "If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive.") + .withAlternative("spark.unsafe.offHeap") + .booleanConf + .createWithDefault(false) + + private[spark] val MEMORY_OFFHEAP_SIZE = ConfigBuilder("spark.memory.offHeap.size") + .doc("The absolute amount of memory in bytes which can be used for off-heap allocation. " + + "This setting has no impact on heap memory usage, so if your executors' total memory " + + "consumption must fit within some hard limit then be sure to shrink your JVM heap size " + + "accordingly. 
This must be set to a positive value when spark.memory.offHeap.enabled=true.") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ >= 0, "The off-heap memory size must not be negative") + .createWithDefault(0) + private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal() .booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 82442cf56154c..0641adc2ab699 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -21,6 +21,7 @@ import javax.annotation.concurrent.GuardedBy import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.storage.BlockId import org.apache.spark.storage.memory.MemoryStore import org.apache.spark.unsafe.Platform @@ -54,7 +55,7 @@ private[spark] abstract class MemoryManager( onHeapStorageMemoryPool.incrementPoolSize(onHeapStorageMemory) onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) - protected[this] val maxOffHeapMemory = conf.getSizeAsBytes("spark.memory.offHeap.size", 0) + protected[this] val maxOffHeapMemory = conf.get(MEMORY_OFFHEAP_SIZE) protected[this] val offHeapStorageMemory = (maxOffHeapMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong @@ -194,8 +195,8 @@ private[spark] abstract class MemoryManager( * sun.misc.Unsafe. */ final val tungstenMemoryMode: MemoryMode = { - if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { - require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, + if (conf.get(MEMORY_OFFHEAP_ENABLED)) { + require(conf.get(MEMORY_OFFHEAP_SIZE) > 0, "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") require(Platform.unaligned(), "No support for unaligned Unsafe. 
Set spark.memory.offHeap.enabled to false.") diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 4e31fb5589a9c..0f32fe4059fbb 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.memory import org.mockito.Mockito.when import org.apache.spark.SparkConf +import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE import org.apache.spark.storage.TestBlockId import org.apache.spark.storage.memory.MemoryStore @@ -48,7 +49,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { conf.clone .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString), + .set(MEMORY_OFFHEAP_SIZE.key, maxOffHeapExecutionMemory.toString), maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxOnHeapStorageMemory = 0, numCores = 1) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 02b04cdbb2a5f..d56cfc183d921 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.memory import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf +import org.apache.spark.internal.config._ import org.apache.spark.storage.TestBlockId import org.apache.spark.storage.memory.MemoryStore @@ -43,7 +44,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val conf = new SparkConf() .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString) + .set(MEMORY_OFFHEAP_SIZE.key, maxOffHeapExecutionMemory.toString) .set("spark.memory.storageFraction", storageFraction.toString) UnifiedMemoryManager(conf, numCores = 1) } @@ -305,9 +306,9 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes test("not enough free memory in the storage pool --OFF_HEAP") { val conf = new SparkConf() - .set("spark.memory.offHeap.size", "1000") + .set(MEMORY_OFFHEAP_SIZE.key, "1000") .set("spark.testing.memory", "1000") - .set("spark.memory.offHeap.enabled", "true") + .set(MEMORY_OFFHEAP_ENABLED.key, "true") val taskAttemptId = 0L val mm = UnifiedMemoryManager(conf, numCores = 1) val ms = makeMemoryStore(mm) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2101ba828553..3962bdc27d22c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -69,7 +70,7 @@ trait 
BlockManagerReplicationBehavior extends SparkFunSuite maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { conf.set("spark.testing.memory", maxMem.toString) - conf.set("spark.memory.offHeap.size", maxMem.toString) + conf.set(MEMORY_OFFHEAP_SIZE.key, maxMem.toString) val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f3e8a2ed1d562..629eed49b04cc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -90,7 +90,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE testConf: Option[SparkConf] = None): BlockManager = { val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) bmConf.set("spark.testing.memory", maxMem.toString) - bmConf.set("spark.memory.offHeap.size", maxMem.toString) + bmConf.set(MEMORY_OFFHEAP_SIZE.key, maxMem.toString) val serializer = new KryoSerializer(bmConf) val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { Some(CryptoStreamUtils.createKey(bmConf)) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 6a6c37873e1c2..df5f0b5335e82 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite +import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} @@ -104,7 +105,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") .set("spark.ui.killEnabled", killEnabled.toString) - .set("spark.memory.offHeap.size", "64m") + .set(MEMORY_OFFHEAP_SIZE.key, "64m") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3452a1e715fb9..8485ed4c887d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -140,6 +140,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val COLUMN_VECTOR_OFFHEAP_ENABLED = + buildConf("spark.sql.columnVector.offheap.enable") + .internal() + .doc("When true, use OffHeapColumnVector in ColumnarBatch.") + .booleanConf + .createWithDefault(false) + val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin") .internal() .doc("When true, prefer sort merge join over shuffle hash join.") @@ -1210,6 +1217,8 @@ class SQLConf extends Serializable with Logging { def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) + def offHeapColumnVectorEnabled: Boolean = getConf(COLUMN_VECTOR_OFFHEAP_ENABLED) + def columnNameOfCorruptRecord: String = 
getConf(COLUMN_NAME_OF_CORRUPT_RECORD) def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index e827229dceef8..669d71e60779d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -101,9 +101,13 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private boolean returnColumnarBatch; /** - * The default config on whether columnarBatch should be offheap. + * The memory mode of the columnarBatch */ - private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; + private final MemoryMode MEMORY_MODE; + + public VectorizedParquetRecordReader(boolean useOffHeap) { + MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + } /** * Implementation of RecordReader API. @@ -204,11 +208,11 @@ public void initBatch( } public void initBatch() { - initBatch(DEFAULT_MEMORY_MODE, null, null); + initBatch(MEMORY_MODE, null, null); } public void initBatch(StructType partitionColumns, InternalRow partitionValues) { - initBatch(DEFAULT_MEMORY_MODE, partitionColumns, partitionValues); + initBatch(MEMORY_MODE, partitionColumns, partitionValues); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a607ec0bf8c9b..a477c23140536 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -177,7 +177,8 @@ case class FileSourceScanExec( override def vectorTypes: Option[Seq[String]] = relation.fileFormat.vectorTypes( requiredSchema = requiredSchema, - partitionSchema = relation.partitionSchema) + partitionSchema = relation.partitionSchema, + relation.sparkSession.sessionState.conf) @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 2ae3f35eb1da1..3e73393b12850 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.columnar +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -37,7 +38,13 @@ case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override def vectorTypes: Option[Seq[String]] = - Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName)) + Option(Seq.fill(attributes.length)( + if (!conf.offHeapColumnVectorEnabled) { + classOf[OnHeapColumnVector].getName + } else { + classOf[OffHeapColumnVector].getName + } + )) /** * If true, get data from ColumnVector in 
ColumnarBatch, which are generally faster. @@ -62,7 +69,12 @@ case class InMemoryTableScanExec( private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows - val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + val taskContext = Option(TaskContext.get()) + val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + } else { + OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + } val columnarBatch = new ColumnarBatch( columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) columnarBatch.setNumRows(rowCount) @@ -73,6 +85,7 @@ case class InMemoryTableScanExec( columnarBatch.column(i).asInstanceOf[WritableColumnVector], columnarBatchSchema.fields(i).dataType, rowCount) } + taskContext.foreach(_.addTaskCompletionListener(_ => columnarBatch.close())) columnarBatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index e5a7aee64a4f4..d3874b58bc807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -70,7 +71,8 @@ trait FileFormat { */ def vectorTypes( requiredSchema: StructType, - partitionSchema: StructType): Option[Seq[String]] = { + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 044b1a89d57c9..2b1064955a777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -274,9 +274,15 @@ class ParquetFileFormat override def vectorTypes( requiredSchema: StructType, - partitionSchema: StructType): Option[Seq[String]] = { + partitionSchema: StructType, + sqlConf: SQLConf): Option[Seq[String]] = { Option(Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)( - classOf[OnHeapColumnVector].getName)) + if (!sqlConf.offHeapColumnVectorEnabled) { + classOf[OnHeapColumnVector].getName + } else { + classOf[OffHeapColumnVector].getName + } + )) } override def isSplitable( @@ -332,8 
+338,10 @@ class ParquetFileFormat // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader: Boolean = - sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && + sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) val enableRecordFilter: Boolean = sparkSession.sessionState.conf.parquetRecordFilterEnabled @@ -364,8 +372,10 @@ class ParquetFileFormat if (pushed.isDefined) { ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } + val taskContext = Option(TaskContext.get()) val parquetReader = if (enableVectorizedReader) { - val vectorizedReader = new VectorizedParquetRecordReader() + val vectorizedReader = + new VectorizedParquetRecordReader(enableOffHeapColumnVector && taskContext.isDefined) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) @@ -387,7 +397,7 @@ class ParquetFileFormat } val iter = new RecordReaderIterator(parquetReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b2dcbe5aa9877..d98cf852a1b48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -23,6 +23,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED import org.apache.spark.memory.{MemoryConsumer, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -99,7 +100,7 @@ private[execution] object HashedRelation { val mm = Option(taskMemoryManager).getOrElse { new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -232,7 +233,7 @@ private[joins] class UnsafeHashedRelation( // so that tests compile: val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -403,7 +404,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap this( new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d194f58cd1cdd..232c1beae7998 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,6 +26,7 @@ import scala.util.control.NonFatal import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -63,7 +64,7 @@ class UnsafeFixedWidthAggregationMapSuite } test(name) { - val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false") + val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 359525fcd05a2..604502f2a57d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,7 +22,7 @@ import java.util.Properties import scala.util.Random import org.apache.spark._ -import org.apache.spark.internal.config +import org.apache.spark.internal.config._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -112,7 +112,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { pageSize: Long, spill: Boolean): Unit = { val memoryManager = - new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")) + new TestMemoryManager(new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false")) val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -125,7 +125,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, - pageSize, config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index a834b7cd2c69f..8f4ee8533e599 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.HashMap import org.apache.spark.SparkConf +import org.apache.spark.internal.config._ import org.apache.spark.memory.{StaticMemoryManager, 
TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap @@ -538,7 +539,7 @@ class AggregateBenchmark extends BenchmarkBase { value.setInt(0, 555) val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -569,8 +570,8 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") - .set("spark.memory.offHeap.size", "102400000"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, s"${heap == "off"}") + .set(MEMORY_OFFHEAP_SIZE.key, "102400000"), Long.MaxValue, Long.MaxValue, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index 00799301ca8d9..edb1290ee2eb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -40,7 +40,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -65,7 +65,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -94,7 +94,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) reader.initialize(file, null /* set columns to null to project all columns */) val column = reader.resultBatch().column(0) assert(reader.nextBatch()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 633cfde6ab941..44a8b25c61dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -653,7 +653,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { 
spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -670,7 +670,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -686,7 +686,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -703,7 +703,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) try { reader.initialize(file, List[String]().asJava) var result = 0 @@ -742,7 +742,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) - val vectorizedReader = new VectorizedParquetRecordReader + val vectorizedReader = + new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index de7a5795b4796..86a3c71a3c4f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -75,6 +75,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled spark.range(values).createOrReplaceTempView("t1") spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) @@ -95,7 +96,7 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -118,7 +119,7 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = 
reader.resultBatch() @@ -260,6 +261,7 @@ object ParquetReadBenchmark { def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled spark.range(values).createOrReplaceTempView("t1") spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") @@ -277,7 +279,7 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader + val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ede63fea9606f..51f8c3325fdff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow @@ -36,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val mm = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -85,7 +86,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("test serialization empty hash map") { val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -157,7 +158,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("LongToUnsafeRowMap with very wide range") { val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), @@ -202,7 +203,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("LongToUnsafeRowMap with random keys") { val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), Long.MaxValue, Long.MaxValue, 1), From c13b60e0194c90156e74d10b19f94c70675d21ae Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 20 Nov 2017 12:45:21 +0100 Subject: [PATCH 269/712] [SPARK-22533][CORE] Handle deprecated names in ConfigEntry. This change hooks up the config reader to `SparkConf.getDeprecatedConfig`, so that config constants with deprecated names generate the proper warnings. 
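A minimal sketch of the resulting behaviour, condensed from the `SparkConfSuite` cases added below (the constants come from `org.apache.spark.deploy.history.config` and `org.apache.spark.internal.config`):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.deploy.history.config._ // MAX_LOG_AGE_S
import org.apache.spark.internal.config._ // LISTENER_BUS_EVENT_QUEUE_CAPACITY

// Set only the deprecated key names.
val conf = new SparkConf()
  .set("spark.history.fs.cleaner.maxAge.seconds", "42")
  .set("spark.scheduler.listenerbus.eventqueue.size", "84")

// Reading through the ConfigEntry now falls back to the deprecated key
// and logs the corresponding deprecation warning.
assert(conf.get(MAX_LOG_AGE_S) === 42L)
assert(conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY) === 84)
```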
It also changes two deprecated configs from the new "alternatives" system to the old deprecation system, since they're not yet hooked up to each other. Added a few unit tests to verify the desired behavior. Author: Marcelo Vanzin Closes #19760 from vanzin/SPARK-22533. --- .../main/scala/org/apache/spark/SparkConf.scala | 16 ++++++++++------ .../spark/internal/config/ConfigProvider.scala | 4 +++- .../apache/spark/internal/config/package.scala | 2 -- .../scala/org/apache/spark/SparkConfSuite.scala | 7 +++++++ 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ee726df7391f1..0e08ff65e4784 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,6 +17,7 @@ package org.apache.spark +import java.util.{Map => JMap} import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ @@ -24,6 +25,7 @@ import scala.collection.mutable.LinkedHashSet import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.serializer.KryoSerializer @@ -370,7 +372,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) + Option(settings.get(key)).orElse(getDeprecatedConfig(key, settings)) } /** Get an optional value, applying variable substitution. */ @@ -622,7 +624,7 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.history.updateInterval", "1.3")), "spark.history.fs.cleaner.interval" -> Seq( AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")), - "spark.history.fs.cleaner.maxAge" -> Seq( + MAX_LOG_AGE_S.key -> Seq( AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")), "spark.yarn.am.waitTime" -> Seq( AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", @@ -663,8 +665,10 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.jar", "2.0")), "spark.yarn.access.hadoopFileSystems" -> Seq( AlternateConfig("spark.yarn.access.namenodes", "2.2")), - "spark.maxRemoteBlockSizeFetchToMem" -> Seq( - AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")) + MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key -> Seq( + AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")), + LISTENER_BUS_EVENT_QUEUE_CAPACITY.key -> Seq( + AlternateConfig("spark.scheduler.listenerbus.eventqueue.size", "2.3")) ) /** @@ -704,9 +708,9 @@ private[spark] object SparkConf extends Logging { * Looks for available deprecated keys for the given config option, and return the first * value available. 
*/ - def getDeprecatedConfig(key: String, conf: SparkConf): Option[String] = { + def getDeprecatedConfig(key: String, conf: JMap[String, String]): Option[String] = { configsWithAlternatives.get(key).flatMap { alts => - alts.collectFirst { case alt if conf.contains(alt.key) => + alts.collectFirst { case alt if conf.containsKey(alt.key) => val value = conf.get(alt.key) if (alt.translation != null) alt.translation(value) else value } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala index 5d98a1185f053..392f9d56e7f51 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala @@ -19,6 +19,8 @@ package org.apache.spark.internal.config import java.util.{Map => JMap} +import org.apache.spark.SparkConf + /** * A source of configuration values. */ @@ -53,7 +55,7 @@ private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends Con override def get(key: String): Option[String] = { if (key.startsWith("spark.")) { - Option(conf.get(key)) + Option(conf.get(key)).orElse(SparkConf.getDeprecatedConfig(key, conf)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7be4d6b212d72..7a9072736b9aa 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -209,7 +209,6 @@ package object config { private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity") - .withAlternative("spark.scheduler.listenerbus.eventqueue.size") .intConf .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") .createWithDefault(10000) @@ -404,7 +403,6 @@ package object config { "affect both shuffle fetch and block manager remote block fetch. 
For users who " + "enabled external shuffle service, this feature can only be worked when external shuffle" + " service is newer than Spark 2.2.") - .withAlternative("spark.reducer.maxReqSizeShuffleToMem") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 0897891ee1758..c771eb4ee3ef5 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -26,6 +26,7 @@ import scala.util.{Random, Try} import com.esotericsoftware.kryo.Kryo +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{JavaSerializer, KryoRegistrator, KryoSerializer} @@ -248,6 +249,12 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst conf.set("spark.kryoserializer.buffer.mb", "1.1") assert(conf.getSizeAsKb("spark.kryoserializer.buffer") === 1100) + + conf.set("spark.history.fs.cleaner.maxAge.seconds", "42") + assert(conf.get(MAX_LOG_AGE_S) === 42L) + + conf.set("spark.scheduler.listenerbus.eventqueue.size", "84") + assert(conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY) === 84) } test("akka deprecated configs") { From 41c6f36018eb086477f21574aacd71616513bd8e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 01:42:05 +0100 Subject: [PATCH 270/712] [SPARK-22549][SQL] Fix 64KB JVM bytecode limit problem with concat_ws ## What changes were proposed in this pull request? This PR changes `concat_ws` code generation to place generated code for expression for arguments into separated methods if these size could be large. This PR resolved the case of `concat_ws` with a lot of argument ## How was this patch tested? Added new test cases into `StringExpressionsSuite` Author: Kazuaki Ishizaki Closes #19777 from kiszk/SPARK-22549. --- .../expressions/stringExpressions.scala | 100 +++++++++++++----- .../expressions/StringExpressionsSuite.scala | 13 +++ 2 files changed, 89 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d5bb7e95769e5..e6f55f476a921 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -137,13 +137,34 @@ case class ConcatWs(children: Seq[Expression]) if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. val evals = children.map(_.genCode(ctx)) - - val inputs = evals.map { eval => - s"${eval.isNull} ? 
(UTF8String) null : ${eval.value}" - }.mkString(", ") - - ev.copy(evals.map(_.code).mkString("\n") + s""" - UTF8String ${ev.value} = UTF8String.concatWs($inputs); + val separator = evals.head + val strings = evals.tail + val numArgs = strings.length + val args = ctx.freshName("args") + + val inputs = strings.zipWithIndex.map { case (eval, index) => + if (eval.isNull != "true") { + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } else { + "" + } + } + val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions(inputs, "valueConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + } else { + inputs.mkString("\n") + } + ev.copy(s""" + UTF8String[] $args = new UTF8String[$numArgs]; + ${separator.code} + $codes + UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args); boolean ${ev.isNull} = ${ev.value} == null; """) } else { @@ -156,32 +177,63 @@ case class ConcatWs(children: Seq[Expression]) child.dataType match { case StringType => ("", // we count all the StringType arguments num at once below. - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};") + if (eval.isNull == "true") { + "" + } else { + s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + }) case _: ArrayType => val size = ctx.freshName("n") - (s""" - if (!${eval.isNull}) { - $varargNum += ${eval.value}.numElements(); - } - """, - s""" - if (!${eval.isNull}) { - final int $size = ${eval.value}.numElements(); - for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; - } + if (eval.isNull == "true") { + ("", "") + } else { + (s""" + if (!${eval.isNull}) { + $varargNum += ${eval.value}.numElements(); + } + """, + s""" + if (!${eval.isNull}) { + final int $size = ${eval.value}.numElements(); + for (int j = 0; j < $size; j ++) { + $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + } + } + """) } - """) } }.unzip - ev.copy(evals.map(_.code).mkString("\n") + - s""" + val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code)) + val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: Nil, + "int", + { body => + s""" + int $varargNum = 0; + $body + return $varargNum; + """ + }, + _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) + val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + "int", + { body => + s""" + $body + return $idxInVararg; + """ + }, + _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) + ev.copy( + s""" + $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxInVararg = 0; - ${varargCount.mkString("\n")} + $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; - ${varargBuild.mkString("\n")} + $varargBuilds UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array); boolean ${ev.isNull} = ${ev.value} == null; """) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index aa9d5a0aa95e3..7ce43066b5aab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -80,6 +80,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") { + val N = 5000 + val sepExpr = Literal.create("#", StringType) + val strings1 = (1 to N).map(x => s"s$x") + val inputsExpr1 = strings1.map(Literal.create(_, StringType)) + checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow) + + val strings2 = (1 to N).map(x => Seq(s"s$x")) + val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType))) + checkEvaluation( + ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow) + } + test("elt") { def testElt(result: String, n: java.lang.Integer, args: String*): Unit = { checkEvaluation( From 9d45e675e27278163241081789b06449ca220e43 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 21 Nov 2017 09:36:37 +0100 Subject: [PATCH 271/712] [SPARK-22541][SQL] Explicitly claim that Python udfs can't be conditionally executed with short-circuit evaluation ## What changes were proposed in this pull request? Besides conditional expressions such as `when` and `if`, users may want to conditionally execute Python UDFs by short-circuit evaluation. We should explicitly note that Python UDFs do not support this kind of conditional execution either. ## How was this patch tested? N/A, just a documentation change. Author: Liang-Chi Hsieh Closes #19787 from viirya/SPARK-22541. --- python/pyspark/sql/functions.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b631e2041706f..4e0faddb1c0df 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2079,12 +2079,9 @@ def udf(f=None, returnType=StringType()): duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. - .. note:: The user-defined functions do not support conditional execution by using them with - SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no - matter the conditions are met or not. So the output is correct if the functions can be - correctly run on all rows without failure. If the functions can cause runtime failure on the - rows that do not satisfy the conditions, the suggested workaround is to incorporate the - condition logic into the functions. + .. note:: The user-defined functions do not support conditional expressions or short curcuiting + in boolean expressions and it ends up with being executed all internally. If the functions + can fail on special rows, the workaround is to incorporate the condition into the functions. :param f: python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object @@ -2194,12 +2191,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined function must be deterministic. - .. note:: The user-defined functions do not support conditional execution by using them with - SQL conditional expressions such as `when` or `if`. The functions still apply on all rows no - matter the conditions are met or not. So the output is correct if the functions can be - correctly run on all rows without failure.
If the functions can cause runtime failure on the - rows that do not satisfy the conditions, the suggested workaround is to incorporate the - condition logic into the functions. + .. note:: The user-defined functions do not support conditional expressions or short curcuiting + in boolean expressions and it ends up with being executed all internally. If the functions + can fail on special rows, the workaround is to incorporate the condition into the functions. """ # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From c9577148069d2215dc79cbf828a378591b4fba5d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 12:16:54 +0100 Subject: [PATCH 272/712] [SPARK-22508][SQL] Fix 64KB JVM bytecode limit problem with GenerateUnsafeRowJoiner.create() ## What changes were proposed in this pull request? This PR changes `GenerateUnsafeRowJoiner.create()` code generation to place generated code for statements to operate bitmap and offset into separated methods if these size could be large. ## How was this patch tested? Added a new test case into `GenerateUnsafeRowJoinerSuite` Author: Kazuaki Ishizaki Closes #19737 from kiszk/SPARK-22508. --- .../codegen/GenerateUnsafeRowJoiner.scala | 31 ++++++++++++++----- .../GenerateUnsafeRowJoinerSuite.scala | 5 +++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 6bc72a0d75c6d..be5f5a73b5d47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform @@ -51,6 +54,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { + val ctx = new CodegenContext val offset = Platform.BYTE_ARRAY_OFFSET val getLong = "Platform.getLong" val putLong = "Platform.putLong" @@ -88,8 +92,14 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})" } } - s"$putLong(buf, ${offset + i * 8}, $bits);" - }.mkString("\n") + s"$putLong(buf, ${offset + i * 8}, $bits);\n" + } + + val copyBitsets = ctx.splitExpressions( + expressions = copyBitset, + funcName = "copyBitsetFunc", + arguments = ("java.lang.Object", "obj1") :: ("long", "offset1") :: + ("java.lang.Object", "obj2") :: ("long", "offset2") :: Nil) // --------------------- copy fixed length portion from row 1 ----------------------- // var cursor = offset + outputBitsetWords * 8 @@ -150,11 +160,14 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 - s""" - |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); - """.stripMargin + s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 
32));\n" } - }.mkString("\n") + } + + val updateOffsets = ctx.splitExpressions( + expressions = updateOffset, + funcName = "copyBitsetFunc", + arguments = ("long", "numBytesVariableRow1") :: Nil) // ------------------------ Finally, put everything together --------------------------- // val codeBody = s""" @@ -166,6 +179,8 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | private byte[] buf = new byte[64]; | private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size}); | + | ${ctx.declareAddedFunctions()} + | | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) { | // row1: ${schema1.size} fields, $bitset1Words words in bitset | // row2: ${schema2.size}, $bitset2Words words in bitset @@ -180,12 +195,12 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | - | $copyBitset + | $copyBitsets | $copyFixedLengthRow1 | $copyFixedLengthRow2 | $copyVariableLengthRow1 | $copyVariableLengthRow2 - | $updateOffset + | $updateOffsets | | out.pointTo(buf, sizeInBytes); | diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index 9f19745cefd20..f203f25ad10d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -66,6 +66,11 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { } } + test("SPARK-22508: GenerateUnsafeRowJoiner.create should not generate codes beyond 64KB") { + val N = 3000 + testConcatOnce(N, N, variable) + } + private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = { for (i <- 0 until 10) { testConcatOnce(numFields1, numFields2, candidateTypes) From 9bdff0bcd83e730aba8dc1253da24a905ba07ae3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 12:19:11 +0100 Subject: [PATCH 273/712] [SPARK-22550][SQL] Fix 64KB JVM bytecode limit problem with elt ## What changes were proposed in this pull request? This PR changes `elt` code generation to place the generated code for the argument expressions into separate methods when their combined size could be large. This PR resolves the case of `elt` with a large number of arguments. ## How was this patch tested? Added new test cases into `StringExpressionsSuite` Author: Kazuaki Ishizaki Closes #19778 from kiszk/SPARK-22550.
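As added context (an editor illustration, not part of the patch), the sketch below shows the shape of query that previously forced codegen to emit one oversized evaluation method for `elt`. The session setup, application name, and column alias are assumptions made for this example, and the non-literal index only serves to keep the optimizer from constant-folding the whole `elt` call:

```python
# Hypothetical reproduction sketch: an elt() call with thousands of string arguments.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("elt-64kb-sketch").getOrCreate()

n = 5000
literal_args = ", ".join("'s{}'".format(i) for i in range(1, n + 1))
# elt(k, s1, ..., sn) returns the k-th string argument; k is derived from the id column
# so the expression is not constant-folded and codegen has to evaluate every argument.
df = spark.range(1).selectExpr("elt(cast(id AS INT) + 1, {}) AS picked".format(literal_args))
# Before this change the generated code evaluating all n arguments lived in a single
# method and could grow past the JVM's 64KB-per-method bytecode limit.
df.show()

spark.stop()
```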
--- .../expressions/codegen/CodeGenerator.scala | 44 ++++++++++------- .../expressions/stringExpressions.scala | 48 +++++++++++++++---- .../expressions/StringExpressionsSuite.scala | 7 +++ 3 files changed, 73 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3dc3f8e4adac0..e02f125a93345 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -790,23 +790,7 @@ class CodegenContext { returnType: String = "void", makeSplitFunction: String => String = identity, foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { - val blocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - var length = 0 - for (code <- expressions) { - // We can't know how many bytecode will be generated, so use the length of source code - // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should - // also not be too small, or it will have many function calls (for wide table), see the - // results in BenchmarkWideTable. - if (length > 1024) { - blocks += blockBuilder.toString() - blockBuilder.clear() - length = 0 - } - blockBuilder.append(code) - length += CodeFormatter.stripExtraNewLinesAndComments(code).length - } - blocks += blockBuilder.toString() + val blocks = buildCodeBlocks(expressions) if (blocks.length == 1) { // inline execution if only one block @@ -841,6 +825,32 @@ class CodegenContext { } } + /** + * Splits the generated code of expressions into multiple sequences of String + * based on a threshold of length of a String + * + * @param expressions the codes to evaluate expressions. + */ + def buildCodeBlocks(expressions: Seq[String]): Seq[String] = { + val blocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + var length = 0 + for (code <- expressions) { + // We can't know how many bytecode will be generated, so use the length of source code + // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should + // also not be too small, or it will have many function calls (for wide table), see the + // results in BenchmarkWideTable. + if (length > 1024) { + blocks += blockBuilder.toString() + blockBuilder.clear() + length = 0 + } + blockBuilder.append(code) + length += CodeFormatter.stripExtraNewLinesAndComments(code).length + } + blocks += blockBuilder.toString() + } + /** * Here we handle all the methods which have been added to the inner classes and * not to the outer class. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e6f55f476a921..360dd845f8d3a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -288,22 +288,52 @@ case class Elt(children: Seq[Expression]) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) val strings = stringExprs.map(_.genCode(ctx)) + val indexVal = ctx.freshName("index") + val stringVal = ctx.freshName("stringVal") val assignStringValue = strings.zipWithIndex.map { case (eval, index) => s""" case ${index + 1}: - ${ev.value} = ${eval.isNull} ? null : ${eval.value}; + ${eval.code} + $stringVal = ${eval.isNull} ? null : ${eval.value}; break; """ - }.mkString("\n") - val indexVal = ctx.freshName("index") - val stringArray = ctx.freshName("strings"); + } - ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s""" - final int $indexVal = ${index.value}; - UTF8String ${ev.value} = null; - switch ($indexVal) { - $assignStringValue + val cases = ctx.buildCodeBlocks(assignStringValue) + val codes = if (cases.length == 1) { + s""" + UTF8String $stringVal = null; + switch ($indexVal) { + ${cases.head} + } + """ + } else { + var prevFunc = "null" + for (c <- cases.reverse) { + val funcName = ctx.freshName("eltFunc") + val funcBody = s""" + private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) { + UTF8String $stringVal = null; + switch ($indexVal) { + $c + default: + return $prevFunc; + } + return $stringVal; + } + """ + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + prevFunc = s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)" } + s"UTF8String $stringVal = $prevFunc;" + } + + ev.copy( + s""" + ${index.code} + final int $indexVal = ${index.value}; + $codes + UTF8String ${ev.value} = $stringVal; final boolean ${ev.isNull} = ${ev.value} == null; """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 7ce43066b5aab..c761394756875 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -116,6 +116,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure) } + test("SPARK-22550: Elt should not generate codes beyond 64KB") { + val N = 10000 + val strings = (1 to N).map(x => s"s$x") + val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, StringType)) + checkEvaluation(Elt(args), s"s$N") + } + test("StringComparison") { val row = create_row("abc", null) val c1 = 'a.string.at(0) From 96e947ed6c945bdbe71c308112308489668f19ac Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 21 Nov 2017 13:48:09 +0100 Subject: [PATCH 274/712] [SPARK-22569][SQL] Clean usage of addMutableState and splitExpressions ## What changes were proposed in this pull request? This PR is to clean the usage of addMutableState and splitExpressions 1. replace hardcoded type string to ctx.JAVA_BOOLEAN etc. 2. 
create a default value of the initCode for ctx.addMutableStats 3. Use named arguments when calling `splitExpressions ` ## How was this patch tested? The existing test cases Author: gatorsmile Closes #19790 from gatorsmile/codeClean. --- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 8 ++-- .../expressions/codegen/CodeGenerator.scala | 22 ++++++--- .../codegen/GenerateMutableProjection.scala | 2 +- .../expressions/complexTypeCreator.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 6 +-- .../expressions/nullExpressions.scala | 10 +++-- .../expressions/objects/objects.scala | 26 +++++------ .../sql/catalyst/expressions/predicates.scala | 6 +-- .../expressions/randomExpressions.scala | 4 +- .../expressions/stringExpressions.scala | 45 +++++++++++-------- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 20 ++++----- .../aggregate/HashMapGenerator.scala | 4 +- .../execution/basicPhysicalOperators.scala | 8 ++-- .../columnar/GenerateColumnAccessor.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 6 +-- .../apache/spark/sql/execution/limit.scala | 4 +- 20 files changed, 104 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 84027b53dca27..821d784a01342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -67,8 +67,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") + ctx.addMutableState(ctx.JAVA_LONG, countTerm) + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 8db7efdbb5dd4..4fa18d6b3209b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -44,7 +44,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm, "") + ctx.addMutableState(ctx.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 
72d5889d2f202..e5a1096bba713 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState("boolean", ev.isNull, "") - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) def updateEval(eval: ExprCode): String = { s""" ${eval.code} @@ -668,8 +668,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState("boolean", ev.isNull, "") - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) def updateEval(eval: ExprCode): String = { s""" ${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e02f125a93345..78617194e47d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -157,7 +157,19 @@ class CodegenContext { val mutableStates: mutable.ArrayBuffer[(String, String, String)] = mutable.ArrayBuffer.empty[(String, String, String)] - def addMutableState(javaType: String, variableName: String, initCode: String): Unit = { + /** + * Add a mutable state as a field to the generated class. c.f. the comments above. + * + * @param javaType Java type of the field. Note that short names can be used for some types, + * e.g. InternalRow, UnsafeRow, UnsafeArrayData, etc. Other types will have to + * specify the fully-qualified Java type name. See the code in doCompile() for + * the list of default imports available. + * Also, generic type arguments are accepted but ignored. + * @param variableName Name of the field. + * @param initCode The statement(s) to put into the init() method to initialize this field. + * If left blank, the field will be default-initialized. + */ + def addMutableState(javaType: String, variableName: String, initCode: String = ""): Unit = { mutableStates += ((javaType, variableName, initCode)) } @@ -191,7 +203,7 @@ class CodegenContext { val initCodes = mutableStates.distinct.map(_._3 + "\n") // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. - splitExpressions(initCodes, "init", Nil) + splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) } /** @@ -769,7 +781,7 @@ class CodegenContext { // Cannot split these expressions because they are not created from a row object. 
return expressions.mkString("\n") } - splitExpressions(expressions, "apply", ("InternalRow", row) :: Nil) + splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", row) :: Nil) } /** @@ -931,7 +943,7 @@ class CodegenContext { dataType: DataType, baseFuncName: String): (String, String, String) = { val globalIsNull = freshName("isNull") - addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;") val globalValue = freshName("value") addMutableState(javaType(dataType), globalValue, s"$globalValue = ${defaultValue(dataType)};") @@ -1038,7 +1050,7 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b5429fade53cf..802e8bdb1ca33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,7 +63,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"$isNull = true;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull, s"$isNull = true;") ctx.addMutableState(ctx.javaType(e.dataType), value, s"$value = ${ctx.defaultValue(e.dataType)};") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4b6574a31424e..2a00d57ee1300 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -120,7 +120,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, "") + ctx.addMutableState("UnsafeArrayData", arrayDataName) val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index eb3c49f5cf30e..9e0786e367911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -277,7 +277,7 @@ abstract class HashExpression[E] extends Expression { } }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value) ev.copy(code = s""" ${ev.value} = $seed; 
$childrenHash""") @@ -616,8 +616,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { s"\n$childHash = 0;" }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") - ctx.addMutableState("int", childHash, s"$childHash = 0;") + ctx.addMutableState(ctx.javaType(dataType), ev.value) + ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;") ev.copy(code = s""" ${ev.value} = $seed; $childrenHash""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 4aeab2c3ad0a8..5eaf3f2202776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.addMutableState("boolean", ev.isNull, "") - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) val evals = children.map { e => val eval = e.genCode(ctx) @@ -385,8 +385,10 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { evals.mkString("\n") } else { - ctx.splitExpressions(evals, "atLeastNNonNulls", - ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, + ctx.splitExpressions( + expressions = evals, + funcName = "atLeastNNonNulls", + arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, returnType = "int", makeSplitFunction = { body => s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f2eee991c9865..006d37f38d6c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -63,14 +63,14 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.freshName("resultIsNull") - ctx.addMutableState("boolean", resultIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, resultIsNull) resultIsNull } else { "false" } val argValues = arguments.map { e => val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") + ctx.addMutableState(ctx.javaType(e.dataType), argValue) argValue } @@ -548,7 +548,7 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue, "") + ctx.addMutableState(elementJavaType, loopValue) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -644,7 +644,7 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState("boolean", loopIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ 
-808,10 +808,10 @@ case class CatalystToExternalMap private( val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] val keyElementJavaType = ctx.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue, "") + ctx.addMutableState(keyElementJavaType, keyLoopValue) val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue, "") + ctx.addMutableState(valueElementJavaType, valueLoopValue) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -844,7 +844,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState("boolean", valueLoopIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { "" @@ -994,8 +994,8 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState(keyElementJavaType, key, "") - ctx.addMutableState(valueElementJavaType, value, "") + ctx.addMutableState(keyElementJavaType, key) + ctx.addMutableState(valueElementJavaType, value) val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => @@ -1031,14 +1031,14 @@ case class ExternalMapToCatalyst private( } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState("boolean", keyIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState("boolean", valueIsNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull) s"$valueIsNull = $value == null;" } else { "" @@ -1106,7 +1106,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, "") + ctx.addMutableState("Object[]", values) val childrenCodes = children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) @@ -1244,7 +1244,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val javaBeanInstance = ctx.freshName("javaBean") val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) - ctx.addMutableState(beanInstanceJavaType, javaBeanInstance, "") + ctx.addMutableState(beanInstanceJavaType, javaBeanInstance) val initialize = setters.map { case (setterMethod, fieldValue) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5d75c6004bfe4..c0084af320689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -236,8 +236,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - 
ctx.addMutableState("boolean", ev.value, "") - ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val valueArg = ctx.freshName("valueArg") val listCode = listGen.map(x => s""" @@ -253,7 +253,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { """) val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil - ctx.splitExpressions(listCode, "valueIn", args) + ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args) } else { listCode.mkString("\n") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 97051769cbf72..b4aefe6cff73e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -79,7 +79,7 @@ case class Rand(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + ctx.addMutableState(className, rngTerm) ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" @@ -114,7 +114,7 @@ case class Randn(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + ctx.addMutableState(className, rngTerm) ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 360dd845f8d3a..1c599af2a01d0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -74,8 +74,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas """ } val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions(inputs, "valueConcat", - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + ctx.splitExpressions( + expressions = inputs, + funcName = "valueConcat", + arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) } else { inputs.mkString("\n") } @@ -155,8 +157,10 @@ case class ConcatWs(children: Seq[Expression]) } } val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions(inputs, "valueConcatWs", - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + ctx.splitExpressions( + expressions = inputs, + funcName = "valueConcatWs", + arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) } else { inputs.mkString("\n") } @@ -205,27 +209,30 @@ case class ConcatWs(children: Seq[Expression]) }.unzip val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code)) - val varargCounts = 
ctx.splitExpressions(varargCount, "varargCountsConcatWs", - ("InternalRow", ctx.INPUT_ROW) :: Nil, - "int", - { body => + val varargCounts = ctx.splitExpressions( + expressions = varargCount, + funcName = "varargCountsConcatWs", + arguments = ("InternalRow", ctx.INPUT_ROW) :: Nil, + returnType = "int", + makeSplitFunction = body => s""" int $varargNum = 0; $body return $varargNum; - """ - }, - _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) - val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs", - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, - "int", - { body => + """, + foldFunctions = _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) + val varargBuilds = ctx.splitExpressions( + expressions = varargBuild, + funcName = "varargBuildsConcatWs", + arguments = + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + returnType = "int", + makeSplitFunction = body => s""" $body return $idxInVararg; - """ - }, - _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) + """, + foldFunctions = _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) ev.copy( s""" $codes @@ -2059,7 +2066,7 @@ case class FormatNumber(x: Expression, d: Expression) val numberFormat = ctx.freshName("numberFormat") val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") + ctx.addMutableState(ctx.JAVA_INT, lastDValue, s"$lastDValue = -100;") ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 1925bad8c3545..a9bfb634fbdea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -76,14 +76,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + ctx.addMutableState(ctx.JAVA_LONG, scanTimeTotalNs, s"$scanTimeTotalNs = 0;") val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.freshName("batch") ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") val idx = ctx.freshName("batchIdx") - ctx.addMutableState("int", idx, s"$idx = 0;") + ctx.addMutableState(ctx.JAVA_INT, idx, s"$idx = 0;") val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(colVars.size)(classOf[ColumnVector].getName)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 21765cdbd94cf..c0e21343ae623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -134,7 +134,7 @@ case class SortExec( override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") - ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, needToSort, 
s"$needToSort = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 51f7c9e22b902..19c793e45a57d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) @@ -186,8 +186,8 @@ case class HashAggregateExec( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) + ctx.addMutableState(ctx.javaType(e.dataType), value) // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -565,7 +565,7 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -596,27 +596,27 @@ case class HashAggregateExec( s"$fastHashMapTerm = new $fastHashMapClassName();") ctx.addMutableState( "java.util.Iterator", - iterTermForFastHashMap, "") + iterTermForFastHashMap) } else { ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName(" + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - iterTermForFastHashMap, "") + iterTermForFastHashMap) } } // create hashMap hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, "") + ctx.addMutableState(hashMapClassName, hashMapTerm) sorterTerm = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm) // Create a name for iterator from HashMap val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm) def generateGenerateCode(): String = { if (isFastHashMapEnabled) { @@ -774,7 +774,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.freshName("fallbackCounter") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < 
${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 90deb20e97244..85b4529501ea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -48,8 +48,8 @@ abstract class HashMapGenerator( initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) + ctx.addMutableState(ctx.javaType(e.dataType), value) val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 3c7daa0a45844..f205bdf3da709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,9 +368,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val numOutput = metricTerm(ctx, "numOutputRows") val initTerm = ctx.freshName("initRange") - ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, initTerm, s"$initTerm = false;") val number = ctx.freshName("number") - ctx.addMutableState("long", number, s"$number = 0L;") + ctx.addMutableState(ctx.JAVA_LONG, number, s"$number = 0L;") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) @@ -391,11 +391,11 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // Once number == batchEnd, it's time to progress to the next batch. val batchEnd = ctx.freshName("batchEnd") - ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;") + ctx.addMutableState(ctx.JAVA_LONG, batchEnd, s"$batchEnd = 0;") // How many values should still be generated by this range operator. val numElementsTodo = ctx.freshName("numElementsTodo") - ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;") + ctx.addMutableState(ctx.JAVA_LONG, numElementsTodo, s"$numElementsTodo = 0L;") // How many values should be generated in the next batch. 
val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ae600c1ffae8e..ff5dd707f0b38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -89,7 +89,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName, "") + ctx.addMutableState(accessorCls, accessorName) val createCode = dt match { case t if ctx.isPrimitiveType(dt) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index cf7885f80d9fe..9c08ec71c1fde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -423,7 +423,7 @@ case class SortMergeJoinExec( private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. val leftRow = ctx.freshName("leftRow") - ctx.addMutableState("InternalRow", leftRow, "") + ctx.addMutableState("InternalRow", leftRow) val rightRow = ctx.freshName("rightRow") ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") @@ -519,10 +519,10 @@ case class SortMergeJoinExec( val value = ctx.freshName("value") val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) // declare it as class member, so we can access the column before or in the loop. 
- ctx.addMutableState(ctx.javaType(a.dataType), value, "") + ctx.addMutableState(ctx.javaType(a.dataType), value) if (a.nullable) { val isNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) val code = s""" |$isNull = $leftRow.isNullAt($i); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 13da4b26a5dcb..a8556f6ba107a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -72,7 +72,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val stopEarly = ctx.freshName("stopEarly") - ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + ctx.addMutableState(ctx.JAVA_BOOLEAN, stopEarly, s"$stopEarly = false;") ctx.addNewFunction("stopEarly", s""" @Override @@ -81,7 +81,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } """, inlineToOuterClass = true) val countTerm = ctx.freshName("count") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") s""" | if ($countTerm < $limit) { | $countTerm += 1; From 5855b5c03e831af997dab9f2023792c9b8d2676c Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 21 Nov 2017 07:25:56 -0600 Subject: [PATCH 275/712] [MINOR][DOC] The left navigation bar should be fixed with respect to scrolling. ## What changes were proposed in this pull request? A minor CSS style change to make Left navigation bar stay fixed with respect to scrolling, it improves usability of the docs. ## How was this patch tested? It was tested on both, firefox and chrome. ### Before ![a2](https://user-images.githubusercontent.com/992952/33004206-6acf9fc0-cde5-11e7-9070-02f26f7899b0.gif) ### After ![a1](https://user-images.githubusercontent.com/992952/33004205-69b27798-cde5-11e7-8002-509b29786b37.gif) Author: Prashant Sharma Closes #19785 from ScrapCodes/doc/css. --- docs/css/main.css | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/css/main.css b/docs/css/main.css index 175e8004fca0e..7f1e99bf67224 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -195,7 +195,7 @@ a.anchorjs-link:hover { text-decoration: none; } margin-top: 0px; width: 210px; float: left; - position: absolute; + position: fixed; } .left-menu { @@ -286,4 +286,4 @@ label[for="nav-trigger"]:hover { margin-left: -215px; } -} \ No newline at end of file +} From 2d868d93987ea1757cc66cdfb534bc49794eb0d0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 21 Nov 2017 10:53:53 -0800 Subject: [PATCH 276/712] [SPARK-22521][ML] VectorIndexerModel support handle unseen categories via handleInvalid: Python API ## What changes were proposed in this pull request? Add python api for VectorIndexerModel support handle unseen categories via handleInvalid. ## How was this patch tested? doctest added. Author: WeichenXu Closes #19753 from WeichenXu123/vector_indexer_invalid_py. 
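For illustration (an editor-added sketch, not part of the patch), the snippet below mirrors the shape of the new doctest: `handleInvalid="skip"` filters out rows whose vector holds a category unseen during fitting, while `"keep"` maps such values to an extra bucket whose index equals the number of known categories for that feature. The data, application name, and `maxCategories` value are assumptions made for this example:

```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.master("local[*]").appName("vector-indexer-handle-invalid").getOrCreate()

train = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
                               (Vectors.dense([0.0, 1.0]),),
                               (Vectors.dense([0.0, 2.0]),)], ["a"])
# Feature 0 is categorical ({-1.0, 0.0}); the 3.0 in that slot below is unseen at transform time.
unseen = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])

indexer = VectorIndexer(maxCategories=2, inputCol="a", outputCol="indexed")

# "skip" silently drops rows that contain unseen categories.
skip_model = indexer.setHandleInvalid("skip").fit(train)
print(skip_model.transform(unseen).count())        # 0

# "keep" assigns unseen categories to the extra bucket (index 2 here).
keep_model = indexer.setHandleInvalid("keep").fit(train)
keep_model.transform(unseen).show(truncate=False)  # indexed ~ [2.0, 1.0]

spark.stop()
```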
--- .../spark/ml/feature/VectorIndexer.scala | 7 +++-- python/pyspark/ml/feature.py | 30 ++++++++++++++----- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 3403ec4259b86..e6ec4e2e36ff0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -47,7 +47,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * Options are: * 'skip': filter out rows with invalid data. * 'error': throw an error. - * 'keep': put invalid data in a special additional bucket, at index numCategories. + * 'keep': put invalid data in a special additional bucket, at index of the number of + * categories of the feature. * Default value: "error" * @group param */ @@ -55,7 +56,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data (unseen labels or NULL values). " + "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " + - "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + "or 'keep' (put invalid data in a special additional bucket, at index of the " + + "number of categories of the feature).", ParamValidators.inArray(VectorIndexer.supportedHandleInvalids)) setDefault(handleInvalid, VectorIndexer.ERROR_INVALID) @@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. * - Specify certain features to not index, either via a parameter or via existing metadata. * - Add warning if a categorical feature has only 1 category. - * - Add option for allowing unknown categories. */ @Since("1.4.0") class VectorIndexer @Since("1.4.0") ( diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 232ae3ef41166..608f2a5715497 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2490,7 +2490,8 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ Class for indexing categorical feature columns in a dataset of `Vector`. @@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja do not recompute. - Specify certain features to not index, either via a parameter or via existing metadata. - Add warning if a categorical feature has only 1 category. - - Add option for allowing unknown categories. 
>>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),), @@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja True >>> loadedModel.categoryMaps == model.categoryMaps True + >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"]) + >>> indexer.getHandleInvalid() + 'error' + >>> model3 = indexer.setHandleInvalid("skip").fit(df) + >>> model3.transform(dfWithInvalid).count() + 0 + >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df) + >>> model4.transform(dfWithInvalid).head().indexed + DenseVector([2.0, 1.0]) .. versionadded:: 1.4.0 """ @@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja "(>= 2). If a feature is found to have > maxCategories values, then " + "it is declared continuous.", typeConverter=TypeConverters.toInt) + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " + + "(unseen labels or NULL values). Options are 'skip' (filter out " + + "rows with invalid data), 'error' (throw an error), or 'keep' (put " + + "invalid data in a special additional bucket, at index of the number " + + "of categories of the feature).", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, maxCategories=20, inputCol=None, outputCol=None): + def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, maxCategories=20, inputCol=None, outputCol=None) + __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error") """ super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) - self._setDefault(maxCategories=20) + self._setDefault(maxCategories=20, handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, maxCategories=20, inputCol=None, outputCol=None): + def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, maxCategories=20, inputCol=None, outputCol=None) + setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this VectorIndexer. """ kwargs = self._input_kwargs From 6d7ebf2f9fbd043813738005a23c57a77eba6f47 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 21 Nov 2017 20:53:38 +0100 Subject: [PATCH 277/712] [SPARK-22165][SQL] Fixes type conflicts between double, long, decimals, dates and timestamps in partition column ## What changes were proposed in this pull request? This PR proposes to add a rule that re-uses `TypeCoercion.findWiderCommonType` when resolving type conflicts in partition values. Currently, this uses numeric precedence-like comparison; therefore, it looks introducing failures for type conflicts between timestamps, dates and decimals, please see: ```scala private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) ... 
literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) ``` The codes below: ```scala val df = Seq((1, "2015-01-01"), (2, "2016-01-01 00:00:00")).toDF("i", "ts") df.write.format("parquet").partitionBy("ts").save("/tmp/foo") spark.read.load("/tmp/foo").printSchema() val df = Seq((1, "1"), (2, "1" * 30)).toDF("i", "decimal") df.write.format("parquet").partitionBy("decimal").save("/tmp/bar") spark.read.load("/tmp/bar").printSchema() ``` produces output as below: **Before** ``` root |-- i: integer (nullable = true) |-- ts: date (nullable = true) root |-- i: integer (nullable = true) |-- decimal: integer (nullable = true) ``` **After** ``` root |-- i: integer (nullable = true) |-- ts: timestamp (nullable = true) root |-- i: integer (nullable = true) |-- decimal: decimal(30,0) (nullable = true) ``` ### Type coercion table: This PR proposes the type conflict resolusion as below: **Before** |InputA \ InputB|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |------------------------|----------|----------|----------|----------|----------|----------|----------|----------| |**`NullType`**|`StringType`|`IntegerType`|`LongType`|`StringType`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`IntegerType`**|`IntegerType`|`IntegerType`|`LongType`|`IntegerType`|`DoubleType`|`IntegerType`|`IntegerType`|`StringType`| |**`LongType`**|`LongType`|`LongType`|`LongType`|`LongType`|`DoubleType`|`LongType`|`LongType`|`StringType`| |**`DecimalType(38,0)`**|`StringType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DecimalType(38,0)`|`DecimalType(38,0)`|`StringType`| |**`DoubleType`**|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`DoubleType`|`StringType`| |**`DateType`**|`StringType`|`IntegerType`|`LongType`|`DateType`|`DoubleType`|`DateType`|`DateType`|`StringType`| |**`TimestampType`**|`StringType`|`IntegerType`|`LongType`|`TimestampType`|`DoubleType`|`TimestampType`|`TimestampType`|`StringType`| |**`StringType`**|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`| **After** |InputA \ InputB|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |------------------------|----------|----------|----------|----------|----------|----------|----------|----------| |**`NullType`**|`NullType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`DateType`|`TimestampType`|`StringType`| |**`IntegerType`**|`IntegerType`|`IntegerType`|`LongType`|`DecimalType(38,0)`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`LongType`**|`LongType`|`LongType`|`LongType`|`DecimalType(38,0)`|`StringType`|`StringType`|`StringType`|`StringType`| |**`DecimalType(38,0)`**|`DecimalType(38,0)`|`DecimalType(38,0)`|`DecimalType(38,0)`|`DecimalType(38,0)`|`StringType`|`StringType`|`StringType`|`StringType`| |**`DoubleType`**|`DoubleType`|`DoubleType`|`StringType`|`StringType`|`DoubleType`|`StringType`|`StringType`|`StringType`| |**`DateType`**|`DateType`|`StringType`|`StringType`|`StringType`|`StringType`|`DateType`|`TimestampType`|`StringType`| |**`TimestampType`**|`TimestampType`|`StringType`|`StringType`|`StringType`|`StringType`|`TimestampType`|`TimestampType`|`StringType`| |**`StringType`**|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`|`StringType`| This was produced by: ```scala test("Print out chart") { val supportedTypes: Seq[DataType] = Seq( NullType, IntegerType, 
LongType, DecimalType(38, 0), DoubleType, DateType, TimestampType, StringType) // Old type conflict resolution: val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) def oldResolveTypeConflicts(dataTypes: Seq[DataType]): DataType = { val topType = dataTypes.maxBy(upCastingOrder.indexOf(_)) if (topType == NullType) StringType else topType } println(s"|InputA \\ InputB|${supportedTypes.map(dt => s"`${dt.toString}`").mkString("|")}|") println(s"|------------------------|${supportedTypes.map(_ => "----------").mkString("|")}|") supportedTypes.foreach { inputA => val types = supportedTypes.map(inputB => oldResolveTypeConflicts(Seq(inputA, inputB))) println(s"|**`$inputA`**|${types.map(dt => s"`${dt.toString}`").mkString("|")}|") } // New type conflict resolution: def newResolveTypeConflicts(dataTypes: Seq[DataType]): DataType = { dataTypes.fold[DataType](NullType)(findWiderTypeForPartitionColumn) } println(s"|InputA \\ InputB|${supportedTypes.map(dt => s"`${dt.toString}`").mkString("|")}|") println(s"|------------------------|${supportedTypes.map(_ => "----------").mkString("|")}|") supportedTypes.foreach { inputA => val types = supportedTypes.map(inputB => newResolveTypeConflicts(Seq(inputA, inputB))) println(s"|**`$inputA`**|${types.map(dt => s"`${dt.toString}`").mkString("|")}|") } } ``` ## How was this patch tested? Unit tests added in `ParquetPartitionDiscoverySuite`. Author: hyukjinkwon Closes #19389 from HyukjinKwon/partition-type-coercion. --- docs/sql-programming-guide.md | 139 ++++++++++++++++++ .../sql/catalyst/analysis/TypeCoercion.scala | 2 +- .../datasources/PartitioningUtils.scala | 60 +++++--- .../ParquetPartitionDiscoverySuite.scala | 57 ++++++- 4 files changed, 235 insertions(+), 23 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 686fcb159d09d..5f9821378b271 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1577,6 +1577,145 @@ options. - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. + - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+  <tr>
+    <th><b>InputA \ InputB</b></th>
+    <th><b>NullType</b></th>
+    <th><b>IntegerType</b></th>
+    <th><b>LongType</b></th>
+    <th><b>DecimalType(38,0)*</b></th>
+    <th><b>DoubleType</b></th>
+    <th><b>DateType</b></th>
+    <th><b>TimestampType</b></th>
+    <th><b>StringType</b></th>
+  </tr>
+  <tr>
+    <td><b>NullType</b></td>
+    <td>NullType</td><td>IntegerType</td><td>LongType</td><td>DecimalType(38,0)</td><td>DoubleType</td><td>DateType</td><td>TimestampType</td><td>StringType</td>
+  </tr>
+  <tr>
+    <td><b>IntegerType</b></td>
+    <td>IntegerType</td><td>IntegerType</td><td>LongType</td><td>DecimalType(38,0)</td><td>DoubleType</td><td>StringType</td><td>StringType</td><td>StringType</td>
+  </tr>
+  <tr>
+    <td><b>LongType</b></td>
+    <td>LongType</td><td>LongType</td><td>LongType</td><td>DecimalType(38,0)</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td>
+  </tr>
+  <tr>
+    <td><b>DecimalType(38,0)*</b></td>
+    <td>DecimalType(38,0)</td><td>DecimalType(38,0)</td><td>DecimalType(38,0)</td><td>DecimalType(38,0)</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td>
+  </tr>
+  <tr>
+    <td><b>DoubleType</b></td>
+    <td>DoubleType</td><td>DoubleType</td><td>StringType</td><td>StringType</td><td>DoubleType</td><td>StringType</td><td>StringType</td><td>StringType</td>
+  </tr>
+  <tr>
+    <td><b>DateType</b></td>
+    <td>DateType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>DateType</td><td>TimestampType</td><td>StringType</td>
+  </tr>
+  <tr>
+    <td><b>TimestampType</b></td>
+    <td>TimestampType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>TimestampType</td><td>TimestampType</td><td>StringType</td>
+  </tr>
+  <tr>
+    <td><b>StringType</b></td>
+    <td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td><td>StringType</td>
+  </tr>
    + + Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. ## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 074eda56199e8..28be955e08a0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -155,7 +155,7 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1c00c9ebb4144..472bf82d3604d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -309,13 +309,8 @@ object PartitioningUtils { } /** - * Resolves possible type conflicts between partitions by up-casting "lower" types. The up- - * casting order is: - * {{{ - * NullType -> - * IntegerType -> LongType -> - * DoubleType -> StringType - * }}} + * Resolves possible type conflicts between partitions by up-casting "lower" types using + * [[findWiderTypeForPartitionColumn]]. */ def resolvePartitions( pathsWithPartitionValues: Seq[(Path, PartitionValues)], @@ -372,11 +367,31 @@ object PartitioningUtils { suspiciousPaths.map("\t" + _).mkString("\n", "\n", "") } + // scalastyle:off line.size.limit /** - * Converts a string to a [[Literal]] with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]] + * Converts a string to a [[Literal]] with automatic type inference. Currently only supports + * [[NullType]], [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]] * [[TimestampType]], and [[StringType]]. 
+ * + * When resolving conflicts, it follows the table below: + * + * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+ + * | InputA \ InputB | NullType | IntegerType | LongType | DecimalType(38,0)* | DoubleType | DateType | TimestampType | StringType | + * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+ + * | NullType | NullType | IntegerType | LongType | DecimalType(38,0) | DoubleType | DateType | TimestampType | StringType | + * | IntegerType | IntegerType | IntegerType | LongType | DecimalType(38,0) | DoubleType | StringType | StringType | StringType | + * | LongType | LongType | LongType | LongType | DecimalType(38,0) | StringType | StringType | StringType | StringType | + * | DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | StringType | StringType | StringType | StringType | + * | DoubleType | DoubleType | DoubleType | StringType | StringType | DoubleType | StringType | StringType | StringType | + * | DateType | DateType | StringType | StringType | StringType | StringType | DateType | TimestampType | StringType | + * | TimestampType | TimestampType | StringType | StringType | StringType | StringType | TimestampType | TimestampType | StringType | + * | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | + * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+ + * Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other + * combinations of scales and precisions because currently we only infer decimal type like + * `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. */ + // scalastyle:on line.size.limit private[datasources] def inferPartitionColumnValue( raw: String, typeInference: Boolean, @@ -427,9 +442,6 @@ object PartitioningUtils { } } - private val upCastingOrder: Seq[DataType] = - Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) - def validatePartitionColumn( schema: StructType, partitionColumns: Seq[String], @@ -468,18 +480,26 @@ object PartitioningUtils { } /** - * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" - * types. + * Given a collection of [[Literal]]s, resolves possible type conflicts by + * [[findWiderTypeForPartitionColumn]]. */ private def resolveTypeConflicts(literals: Seq[Literal], timeZone: TimeZone): Seq[Literal] = { - val desiredType = { - val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) - // Falls back to string if all values of this column are null or empty string - if (topType == NullType) StringType else topType - } + val litTypes = literals.map(_.dataType) + val desiredType = litTypes.reduce(findWiderTypeForPartitionColumn) literals.map { case l @ Literal(_, dataType) => Literal.create(Cast(l, desiredType, Some(timeZone.getID)).eval(), desiredType) } } + + /** + * Type widening rule for partition column types. It is similar to + * [[TypeCoercion.findWiderTypeForTwo]] but the main difference is that here we disallow + * precision loss when widening double/long and decimal, and fall back to string. 
+ */ + private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType + case (DoubleType, LongType) | (LongType, DoubleType) => StringType + case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index f79b92b804c70..d4902641e335f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -249,6 +249,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha true, rootPaths, timeZoneId) + assert(actualSpec.partitionColumns === spec.partitionColumns) + assert(actualSpec.partitions.length === spec.partitions.length) + actualSpec.partitions.zip(spec.partitions).foreach { case (actual, expected) => + assert(actual === expected) + } assert(actualSpec === spec) } @@ -314,7 +319,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha PartitionSpec( StructType(Seq( StructField("a", DoubleType), - StructField("b", StringType))), + StructField("b", NullType))), Seq( Partition(InternalRow(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), Partition(InternalRow(10.5, null), @@ -324,6 +329,32 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha s"hdfs://host:9000/path1", s"hdfs://host:9000/path2"), PartitionSpec.emptySpec) + + // The cases below check the resolution for type conflicts. + val t1 = Timestamp.valueOf("2014-01-01 00:00:00.0").getTime * 1000 + val t2 = Timestamp.valueOf("2014-01-01 00:01:00.0").getTime * 1000 + // Values in column 'a' are inferred as null, date and timestamp each, and timestamp is set + // as a common type. + // Values in column 'b' are inferred as integer, decimal(22, 0) and null, and decimal(22, 0) + // is set as a common type. 
+ check(Seq( + s"hdfs://host:9000/path/a=$defaultPartitionName/b=0", + s"hdfs://host:9000/path/a=2014-01-01/b=${Long.MaxValue}111", + s"hdfs://host:9000/path/a=2014-01-01 00%3A01%3A00.0/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", TimestampType), + StructField("b", DecimalType(22, 0)))), + Seq( + Partition( + InternalRow(null, Decimal(0)), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=0"), + Partition( + InternalRow(t1, Decimal(s"${Long.MaxValue}111")), + s"hdfs://host:9000/path/a=2014-01-01/b=${Long.MaxValue}111"), + Partition( + InternalRow(t2, null), + s"hdfs://host:9000/path/a=2014-01-01 00%3A01%3A00.0/b=$defaultPartitionName")))) } test("parse partitions with type inference disabled") { @@ -395,7 +426,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha PartitionSpec( StructType(Seq( StructField("a", StringType), - StructField("b", StringType))), + StructField("b", NullType))), Seq( Partition(InternalRow(UTF8String.fromString("10"), null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), @@ -1067,4 +1098,26 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha checkAnswer(spark.read.load(path.getAbsolutePath), df) } } + + test("Resolve type conflicts - decimals, dates and timestamps in partition column") { + withTempPath { path => + val df = Seq((1, "2014-01-01"), (2, "2016-01-01"), (3, "2015-01-01 00:01:00")).toDF("i", "ts") + df.write.format("parquet").partitionBy("ts").save(path.getAbsolutePath) + checkAnswer( + spark.read.load(path.getAbsolutePath), + Row(1, Timestamp.valueOf("2014-01-01 00:00:00")) :: + Row(2, Timestamp.valueOf("2016-01-01 00:00:00")) :: + Row(3, Timestamp.valueOf("2015-01-01 00:01:00")) :: Nil) + } + + withTempPath { path => + val df = Seq((1, "1"), (2, "3"), (3, "2" * 30)).toDF("i", "decimal") + df.write.format("parquet").partitionBy("decimal").save(path.getAbsolutePath) + checkAnswer( + spark.read.load(path.getAbsolutePath), + Row(1, BigDecimal("1")) :: + Row(2, BigDecimal("3")) :: + Row(3, BigDecimal("2" * 30)) :: Nil) + } + } } From b96f61b6b262836e6be3f7657a3fe136d58b4dfe Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 21 Nov 2017 20:55:24 +0100 Subject: [PATCH 278/712] [SPARK-22475][SQL] show histogram in DESC COLUMN command ## What changes were proposed in this pull request? Added the histogram representation to the output of the `DESCRIBE EXTENDED table_name column_name` command. ## How was this patch tested? Modified SQL UT and checked output Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Marco Gaido Closes #19774 from mgaido91/SPARK-22475. 
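For reference, a minimal end-to-end sketch of the new output, adapted from the SQL test added in this patch (the table and column names below come from that test; this is illustrative only, not additional test code):

```scala
// Adapted from describe-table-column.sql in this patch; assumes an existing
// table desc_col_table with an int column `key`.
spark.sql("SET spark.sql.statistics.histogram.enabled=true")
spark.sql("SET spark.sql.statistics.histogram.numBins=2")
spark.sql("INSERT INTO desc_col_table VALUES 1, 2, 3, 4")
spark.sql("ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key")
// The result now contains a `histogram` row plus one `bin_<i>` row per bin
// with its lower_bound, upper_bound and distinct_count.
spark.sql("DESC EXTENDED desc_col_table key").show(truncate = false)
```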
--- .../spark/sql/execution/command/tables.scala | 17 +++++ .../inputs/describe-table-column.sql | 10 +++ .../results/describe-table-column.sql.out | 74 +++++++++++++++++-- 3 files changed, 93 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 95f16b0f4baea..c9f6e571ddab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.Histogram import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -689,9 +690,25 @@ case class DescribeColumnCommand( buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + val histDesc = for { + c <- cs + hist <- c.histogram + } yield histogramDescription(hist) + buffer ++= histDesc.getOrElse(Seq(Row("histogram", "NULL"))) } buffer } + + private def histogramDescription(histogram: Histogram): Seq[Row] = { + val header = Row("histogram", + s"height: ${histogram.height}, num_of_bins: ${histogram.bins.length}") + val bins = histogram.bins.zipWithIndex.map { + case (bin, index) => + Row(s"bin_$index", + s"lower_bound: ${bin.lo}, upper_bound: ${bin.hi}, distinct_count: ${bin.ndv}") + } + header +: bins + } } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql index a6ddcd999bf9b..2d180d118da7a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql @@ -34,6 +34,16 @@ DESC FORMATTED desc_complex_col_table col; -- Describe a nested column DESC FORMATTED desc_complex_col_table col.x; +-- Test output for histogram statistics +SET spark.sql.statistics.histogram.enabled=true; +SET spark.sql.statistics.histogram.numBins=2; + +INSERT INTO desc_col_table values 1, 2, 3, 4; + +ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key; + +DESC EXTENDED desc_col_table key; + DROP VIEW desc_col_temp_view; DROP TABLE desc_col_table; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out index 30d0a2dc5a3f7..6ef8af6574e98 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 +-- Number of queries: 23 -- !query 0 @@ -34,6 +34,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 3 @@ -50,6 +51,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 
4 @@ -66,6 +68,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 5 @@ -117,6 +120,7 @@ num_nulls 0 distinct_count 0 avg_col_len 4 max_col_len 4 +histogram NULL -- !query 10 @@ -133,6 +137,7 @@ num_nulls 0 distinct_count 0 avg_col_len 4 max_col_len 4 +histogram NULL -- !query 11 @@ -157,6 +162,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 13 @@ -173,6 +179,7 @@ num_nulls NULL distinct_count NULL avg_col_len NULL max_col_len NULL +histogram NULL -- !query 14 @@ -185,24 +192,75 @@ DESC TABLE COLUMN command does not support nested data types: col.x; -- !query 15 -DROP VIEW desc_col_temp_view +SET spark.sql.statistics.histogram.enabled=true -- !query 15 schema -struct<> +struct -- !query 15 output - +spark.sql.statistics.histogram.enabled true -- !query 16 -DROP TABLE desc_col_table +SET spark.sql.statistics.histogram.numBins=2 -- !query 16 schema -struct<> +struct -- !query 16 output - +spark.sql.statistics.histogram.numBins 2 -- !query 17 -DROP TABLE desc_complex_col_table +INSERT INTO desc_col_table values 1, 2, 3, 4 -- !query 17 schema struct<> -- !query 17 output + + +-- !query 18 +ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +DESC EXTENDED desc_col_table key +-- !query 19 schema +struct +-- !query 19 output +col_name key +data_type int +comment column_comment +min 1 +max 4 +num_nulls 0 +distinct_count 4 +avg_col_len 4 +max_col_len 4 +histogram height: 2.0, num_of_bins: 2 +bin_0 lower_bound: 1.0, upper_bound: 2.0, distinct_count: 2 +bin_1 lower_bound: 2.0, upper_bound: 4.0, distinct_count: 2 + + +-- !query 20 +DROP VIEW desc_col_temp_view +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP TABLE desc_col_table +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP TABLE desc_complex_col_table +-- !query 22 schema +struct<> +-- !query 22 output + From ac10171bea2fc027d6691393b385b3fc0ef3293d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Nov 2017 22:24:43 +0100 Subject: [PATCH 279/712] [SPARK-22500][SQL] Fix 64KB JVM bytecode limit problem with cast ## What changes were proposed in this pull request? This PR changes `cast` code generation to place generated code for expression for fields of a structure into separated methods if these size could be large. ## How was this patch tested? Added new test cases into `CastSuite` Author: Kazuaki Ishizaki Closes #19730 from kiszk/SPARK-22500. 
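As a rough illustration of the failure mode, here is a condensed version of the `CastSuite` regression test added in this patch (it relies on that suite's `cast` and `checkEvaluation` helpers):

```scala
// Casting a struct with many fields used to inline every per-field cast into one
// generated method; with this change the field casts are split into separate
// `castStruct` helper methods via ctx.splitExpressions.
val n = 250
val fromType = new StructType((1 to n).map(i => StructField(s"s$i", DoubleType)).toArray)
val toType = new StructType((1 to n).map(i => StructField(s"i$i", IntegerType)).toArray)
val input = Row.fromSeq((1 to n).map(i => i + 0.5))
checkEvaluation(cast(Literal.create(input, fromType), toType), Row.fromSeq(1 to n))
```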
--- .../spark/sql/catalyst/expressions/Cast.scala | 12 ++++++++++-- .../sql/catalyst/expressions/CastSuite.scala | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bc809f559d586..12baddf1bf7ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1039,13 +1039,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } } """ - }.mkString("\n") + } + val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions( + expressions = fieldsEvalCode, + funcName = "castStruct", + arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil) + } else { + fieldsEvalCode.mkString("\n") + } (c, evPrim, evNull) => s""" final $rowClass $result = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpRow = $c; - $fieldsEvalCode + $fieldsEvalCodes $evPrim = $result; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a7ffa884d2286..84bd8b2f91e4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -827,4 +827,22 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Literal.create(input, from), to), input) } + + test("SPARK-22500: cast for struct should not generate codes beyond 64KB") { + val N = 250 + + val fromInner = new StructType( + (1 to N).map(i => StructField(s"s$i", DoubleType)).toArray) + val toInner = new StructType( + (1 to N).map(i => StructField(s"i$i", IntegerType)).toArray) + val inputInner = Row.fromSeq((1 to N).map(i => i + 0.5)) + val outputInner = Row.fromSeq((1 to N)) + val fromOuter = new StructType( + (1 to N).map(i => StructField(s"s$i", fromInner)).toArray) + val toOuter = new StructType( + (1 to N).map(i => StructField(s"s$i", toInner)).toArray) + val inputOuter = Row.fromSeq((1 to N).map(_ => inputInner)) + val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner)) + checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter) + } } From 881c5c807304a305ef96e805d51afbde097f7f4f Mon Sep 17 00:00:00 2001 From: Jia Li Date: Tue, 21 Nov 2017 17:30:02 -0800 Subject: [PATCH 280/712] [SPARK-22548][SQL] Incorrect nested AND expression pushed down to JDBC data source MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Let’s say I have a nested AND expression shown below and p2 can not be pushed down, (p1 AND p2) OR p3 In current Spark code, during data source filter translation, (p1 AND p2) is returned as p1 only and p2 is simply lost. This issue occurs with JDBC data source and is similar to [SPARK-12218](https://github.com/apache/spark/pull/10362) for Parquet. When we have AND nested below another expression, we should either push both legs or nothing. 
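As a concrete illustration (adapted from the JDBCSuite case added in this patch; the `foobar` table and its columns come from that suite):

```scala
// TRIM(NAME) = 'mary' has no data source Filter translation. Before this fix the
// nested AND was translated as (THEID > 0) OR (NAME = 'fred'), silently dropping
// the TRIM predicate and returning extra rows from the JDBC source.
val df = sql(
  "SELECT * FROM foobar WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')")
// With the fix, an AND is pushed down only when both legs translate; otherwise the
// whole predicate is evaluated by Spark, so only the matching rows are returned.
assert(df.collect.toSet === Set(Row("fred", 1), Row("mary", 2)))
```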
Note that: - The current Spark code will always split conjunctive predicate before it determines if a predicate can be pushed down or not - If I have (p1 AND p2) AND p3, it will be split into p1, p2, p3. There won't be nested AND expression. - The current Spark code logic for OR is OK. It either pushes both legs or nothing. The same translation method is also called by Data Source V2. ## How was this patch tested? Added new unit test cases to JDBCSuite gatorsmile Author: Jia Li Closes #19776 from jliwork/spark-22548. --- .../datasources/DataSourceStrategy.scala | 14 +- .../datasources/DataSourceStrategySuite.scala | 231 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 +- 3 files changed, 250 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 04d6f3f56eb02..400f2e03165b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -497,7 +497,19 @@ object DataSourceStrategy { Some(sources.IsNotNull(a.name)) case expressions.And(left, right) => - (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And) + // See SPARK-12218 for detailed discussion + // It is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0) + // and we do not understand how to convert trim(b) = 'blah'. + // If we only convert a = 2, we will end up with + // (a = 2) OR (c > 0), which will generate wrong results. + // Pushing one leg of AND down is only safe to do at the top level. + // You can see ParquetFilters' createFilter for more details. + for { + leftFilter <- translateFilter(left) + rightFilter <- translateFilter(right) + } yield sources.And(leftFilter, rightFilter) case expressions.Or(left, right) => for { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala new file mode 100644 index 0000000000000..f20aded169e44 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.sources +import org.apache.spark.sql.test.SharedSQLContext + +class DataSourceStrategySuite extends PlanTest with SharedSQLContext { + + test("translate simple expression") { + val attrInt = 'cint.int + val attrStr = 'cstr.string + + testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1))) + testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1))) + + testTranslateFilter(EqualNullSafe(attrStr, Literal(null)), + Some(sources.EqualNullSafe("cstr", null))) + testTranslateFilter(EqualNullSafe(Literal(null), attrStr), + Some(sources.EqualNullSafe("cstr", null))) + + testTranslateFilter(GreaterThan(attrInt, 1), Some(sources.GreaterThan("cint", 1))) + testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint", 1))) + + testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint", 1))) + testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint", 1))) + + testTranslateFilter(GreaterThanOrEqual(attrInt, 1), Some(sources.GreaterThanOrEqual("cint", 1))) + testTranslateFilter(GreaterThanOrEqual(1, attrInt), Some(sources.LessThanOrEqual("cint", 1))) + + testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1))) + testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1))) + + testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + + testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + + testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) + testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) + + // cint > 1 AND cint < 10 + testTranslateFilter(And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10)), + Some(sources.And( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)))) + + // cint >= 8 OR cint <= 2 + testTranslateFilter(Or( + GreaterThanOrEqual(attrInt, 8), + LessThanOrEqual(attrInt, 2)), + Some(sources.Or( + sources.GreaterThanOrEqual("cint", 8), + sources.LessThanOrEqual("cint", 2)))) + + testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)), + Some(sources.Not(sources.GreaterThanOrEqual("cint", 8)))) + + testTranslateFilter(StartsWith(attrStr, "a"), Some(sources.StringStartsWith("cstr", "a"))) + + testTranslateFilter(EndsWith(attrStr, "a"), Some(sources.StringEndsWith("cstr", "a"))) + + testTranslateFilter(Contains(attrStr, "a"), Some(sources.StringContains("cstr", "a"))) + } + + test("translate complex expression") { + val attrInt = 'cint.int + + // ABS(cint) - 2 <= 1 + testTranslateFilter(LessThanOrEqual( + // Expressions are not supported + // Functions such as 'Abs' are not supported + Subtract(Abs(attrInt), 2), 1), None) + + // (cin1 > 1 AND cint < 10) OR (cint > 50 AND cint > 100) + testTranslateFilter(Or( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + GreaterThan(attrInt, 50), + LessThan(attrInt, 100))), + Some(sources.Or( + sources.And( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)), + sources.And( + sources.GreaterThan("cint", 50), + sources.LessThan("cint", 100))))) + + // SPARK-22548 Incorrect nested AND expression pushed down to JDBC data source + // (cint > 1 AND ABS(cint) < 10) OR (cint < 50 AND cint > 100) + 
testTranslateFilter(Or( + And( + GreaterThan(attrInt, 1), + // Functions such as 'Abs' are not supported + LessThan(Abs(attrInt), 10) + ), + And( + GreaterThan(attrInt, 50), + LessThan(attrInt, 100))), None) + + // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100)) + testTranslateFilter(Not(And( + Or( + LessThanOrEqual(attrInt, 1), + // Functions such as 'Abs' are not supported + GreaterThanOrEqual(Abs(attrInt), 10) + ), + Or( + LessThanOrEqual(attrInt, 50), + GreaterThanOrEqual(attrInt, 100)))), None) + + // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10) + testTranslateFilter(Or( + Or( + EqualTo(attrInt, 1), + EqualTo(attrInt, 10) + ), + Or( + GreaterThan(attrInt, 0), + LessThan(attrInt, -10))), + Some(sources.Or( + sources.Or( + sources.EqualTo("cint", 1), + sources.EqualTo("cint", 10)), + sources.Or( + sources.GreaterThan("cint", 0), + sources.LessThan("cint", -10))))) + + // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10) + testTranslateFilter(Or( + Or( + EqualTo(attrInt, 1), + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 10) + ), + Or( + GreaterThan(attrInt, 0), + LessThan(attrInt, -10))), None) + + // In end-to-end testing, conjunctive predicate should has been split + // before reaching DataSourceStrategy.translateFilter. + // This is for UT purpose to test each [[case]]. + // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL) + testTranslateFilter(And( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + EqualTo(attrInt, 6), + IsNotNull(attrInt))), + Some(sources.And( + sources.And( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)), + sources.And( + sources.EqualTo("cint", 6), + sources.IsNotNull("cint"))))) + + // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL) + testTranslateFilter(And( + And( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + And( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), None) + + // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) + testTranslateFilter(And( + Or( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + Or( + EqualTo(attrInt, 6), + IsNotNull(attrInt))), + Some(sources.And( + sources.Or( + sources.GreaterThan("cint", 1), + sources.LessThan("cint", 10)), + sources.Or( + sources.EqualTo("cint", 6), + sources.IsNotNull("cint"))))) + + // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) + testTranslateFilter(And( + Or( + GreaterThan(attrInt, 1), + LessThan(attrInt, 10) + ), + Or( + // Functions such as 'Abs' are not supported + EqualTo(Abs(attrInt), 6), + IsNotNull(attrInt))), None) + } + + /** + * Translate the given Catalyst [[Expression]] into data source [[sources.Filter]] + * then verify against the given [[sources.Filter]]. + */ + def testTranslateFilter(catalystFilter: Expression, result: Option[sources.Filter]): Unit = { + assertResult(result) { + DataSourceStrategy.translateFilter(catalystFilter) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 88a5f618d604d..61571bccdcb51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -294,10 +294,13 @@ class JDBCSuite extends SparkFunSuite // This is a test to reflect discussion in SPARK-12218. // The older versions of spark have this kind of bugs in parquet data source. 
- val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2 AND NAME != 'mary')") - val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") + val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") assert(df1.collect.toSet === Set(Row("mary", 2))) - assert(df2.collect.toSet === Set(Row("mary", 2))) + + // SPARK-22548: Incorrect nested AND expression pushed down to JDBC data source + val df2 = sql("SELECT * FROM foobar " + + "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')") + assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2))) def checkNotPushdown(df: DataFrame): DataFrame = { val parentPlan = df.queryExecution.executedPlan From e0d7665cec1e6954d640f422c79ebba4c273be7d Mon Sep 17 00:00:00 2001 From: vinodkc Date: Tue, 21 Nov 2017 22:31:46 -0800 Subject: [PATCH 281/712] [SPARK-17920][SPARK-19580][SPARK-19878][SQL] Support writing to Hive table which uses Avro schema url 'avro.schema.url' ## What changes were proposed in this pull request? SPARK-19580 Support for avro.schema.url while writing to hive table SPARK-19878 Add hive configuration when initialize hive serde in InsertIntoHiveTable.scala SPARK-17920 HiveWriterContainer passes null configuration to serde.initialize, causing NullPointerException in AvroSerde when using avro.schema.url Support writing to Hive table which uses Avro schema url 'avro.schema.url' For ex: create external table avro_in (a string) stored as avro location '/avro-in/' tblproperties ('avro.schema.url'='/avro-schema/avro.avsc'); create external table avro_out (a string) stored as avro location '/avro-out/' tblproperties ('avro.schema.url'='/avro-schema/avro.avsc'); insert overwrite table avro_out select * from avro_in; // fails with java.lang.NullPointerException WARN AvroSerDe: Encountered exception determining schema. Returning signal schema to indicate problem java.lang.NullPointerException at org.apache.hadoop.fs.FileSystem.getDefaultUri(FileSystem.java:182) at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:174) ## Changes proposed in this fix Currently 'null' value is passed to serializer, which causes NPE during insert operation, instead pass Hadoop configuration object ## How was this patch tested? Added new test case in VersionsSuite Author: vinodkc Closes #19779 from vinodkc/br_Fix_SPARK-17920. 
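In essence the change is a one-line fix; a simplified sketch of the `HiveFileFormat` change below (`jobConf` and `tableDesc` are the fields already available in the writer):

```scala
// The Avro SerDe resolves 'avro.schema.url' through the Hadoop FileSystem API,
// so the Configuration passed to initialize() must not be null.
val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer]
serializer.initialize(jobConf, tableDesc.getProperties)  // previously: initialize(null, ...)
```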
--- .../sql/hive/execution/HiveFileFormat.scala | 4 +- .../spark/sql/hive/client/VersionsSuite.scala | 72 ++++++++++++++++++- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index ac735e8b383f6..4a7cd6901923b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -116,7 +116,7 @@ class HiveOutputWriter( private val serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] - serializer.initialize(null, tableDesc.getProperties) + serializer.initialize(jobConf, tableDesc.getProperties) serializer } @@ -130,7 +130,7 @@ class HiveOutputWriter( private val standardOI = ObjectInspectorUtils .getStandardObjectInspector( - tableDesc.getDeserializer.getObjectInspector, + tableDesc.getDeserializer(jobConf).getObjectInspector, ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9ed39cc80f50d..fbf6877c994a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import java.io.{ByteArrayOutputStream, File, PrintStream} +import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} import java.net.URI import org.apache.hadoop.conf.Configuration @@ -841,6 +841,76 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: SPARK-17920: Insert into/overwrite avro table") { + withTempDir { dir => + val path = dir.getAbsolutePath + val schemaPath = s"""$path${File.separator}avroschemadir""" + + new File(schemaPath).mkdir() + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + val schemaUrl = s"""$schemaPath${File.separator}avroDecimal.avsc""" + val schemaFile = new File(schemaPath, "avroDecimal.avsc") + val writer = new PrintWriter(schemaFile) + writer.write(avroSchema) + writer.close() + + val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") + val srcLocation = new File(url.getFile) + val destTableName = "tab1" + val srcTableName = "tab2" + + withTable(srcTableName, destTableName) { + versionSpark.sql( + s""" + |CREATE EXTERNAL TABLE $srcTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$srcLocation' + |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + """.stripMargin + ) + + versionSpark.sql( + s""" + |CREATE TABLE $destTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 
'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + """.stripMargin + ) + versionSpark.sql( + s"""INSERT OVERWRITE TABLE $destTableName SELECT * FROM $srcTableName""") + val result = versionSpark.table(srcTableName).collect() + assert(versionSpark.table(destTableName).collect() === result) + versionSpark.sql( + s"""INSERT INTO TABLE $destTableName SELECT * FROM $srcTableName""") + assert(versionSpark.table(destTableName).collect().toSeq === result ++ result) + } + } + } // TODO: add more tests. } } From 2c0fe818a624cfdc76c752ec6bfe6a42e5680604 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 22 Nov 2017 09:09:50 +0100 Subject: [PATCH 282/712] [SPARK-22445][SQL][FOLLOW-UP] Respect stream-side child's needCopyResult in BroadcastHashJoin ## What changes were proposed in this pull request? I found #19656 causes some bugs, for example, it changed the result set of `q6` in tpcds (I keep tracking TPCDS results daily [here](https://github.com/maropu/spark-tpcds-datagen/tree/master/reports/tests)): - w/o pr19658 ``` +-----+---+ |state|cnt| +-----+---+ | MA| 10| | AK| 10| | AZ| 11| | ME| 13| | VT| 14| | NV| 15| | NH| 16| | UT| 17| | NJ| 21| | MD| 22| | WY| 25| | NM| 26| | OR| 31| | WA| 36| | ND| 38| | ID| 39| | SC| 45| | WV| 50| | FL| 51| | OK| 53| | MT| 53| | CO| 57| | AR| 58| | NY| 58| | PA| 62| | AL| 63| | LA| 63| | SD| 70| | WI| 80| | null| 81| | MI| 82| | NC| 82| | MS| 83| | CA| 84| | MN| 85| | MO| 88| | IL| 95| | IA|102| | TN|102| | IN|103| | KY|104| | NE|113| | OH|114| | VA|130| | KS|139| | GA|168| | TX|216| +-----+---+ ``` - w/ pr19658 ``` +-----+---+ |state|cnt| +-----+---+ | RI| 14| | AK| 16| | FL| 20| | NJ| 21| | NM| 21| | NV| 22| | MA| 22| | MD| 22| | UT| 22| | AZ| 25| | SC| 28| | AL| 36| | MT| 36| | WA| 39| | ND| 41| | MI| 44| | AR| 45| | OR| 47| | OK| 52| | PA| 53| | LA| 55| | CO| 55| | NY| 64| | WV| 66| | SD| 72| | MS| 73| | NC| 79| | IN| 82| | null| 85| | ID| 88| | MN| 91| | WI| 95| | IL| 96| | MO| 97| | CA|109| | CA|109| | TN|114| | NE|115| | KY|128| | OH|131| | IA|156| | TX|160| | VA|182| | KS|211| | GA|230| +-----+---+ ``` This pr is to keep the original logic of `CodegenContext.copyResult` in `BroadcastHashJoinExec`. ## How was this patch tested? Existing tests Author: Takeshi Yamamuro Closes #19781 from maropu/SPARK-22445-bugfix. --- .../joins/BroadcastHashJoinExec.scala | 15 ++++++----- .../org/apache/spark/sql/JoinSuite.scala | 27 ++++++++++++++++++- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 837b8525fed55..c96ed6ef41016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -76,20 +76,23 @@ case class BroadcastHashJoinExec( streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() } - override def needCopyResult: Boolean = joinType match { + private def multipleOutputForOneInput: Boolean = joinType match { case _: InnerLike | LeftOuter | RightOuter => // For inner and outer joins, one row from the streamed side may produce multiple result rows, - // if the build side has duplicated keys. Then we need to copy the result rows before putting - // them in a buffer, because these result rows share one UnsafeRow instance. 
Note that here - // we wait for the broadcast to be finished, which is a no-op because it's already finished - // when we wait it in `doProduce`. + // if the build side has duplicated keys. Note that here we wait for the broadcast to be + // finished, which is a no-op because it's already finished when we wait it in `doProduce`. !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique // Other joins types(semi, anti, existence) can at most produce one result row for one input - // row from the streamed side, so no need to copy the result rows. + // row from the streamed side. case _ => false } + // If the streaming side needs to copy result, this join plan needs to copy too. Otherwise, + // this join plan only needs to copy result if it may output multiple rows for one input. + override def needCopyResult: Boolean = + streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput + override def doProduce(ctx: CodegenContext): String = { streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 226cc3028b135..771e1186e63ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} -import org.apache.spark.sql.execution.SortExec +import org.apache.spark.sql.execution.{BinaryExecNode, SortExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext { joinQueries.foreach(assertJoinOrdering) } + + test("SPARK-22445 Respect stream-side child's needCopyResult in BroadcastHashJoin") { + val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1") + val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2") + val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3") + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.JOIN_REORDER_ENABLED.key -> "false") { + val df = df1.join(df2, "k").join(functions.broadcast(df3), "k") + val plan = df.queryExecution.sparkPlan + + // Check if `needCopyResult` in `BroadcastHashJoin` is correct when smj->bhj + val joins = new collection.mutable.ArrayBuffer[BinaryExecNode]() + plan.foreachUp { + case j: BroadcastHashJoinExec => joins += j + case j: SortMergeJoinExec => joins += j + case _ => + } + assert(joins.size == 2) + assert(joins(0).isInstanceOf[SortMergeJoinExec]) + assert(joins(1).isInstanceOf[BroadcastHashJoinExec]) + checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) + } + } } From 572af5027e45ca96e0d283a8bf7c84dcf476f9bc Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 22 Nov 2017 13:27:20 +0100 Subject: [PATCH 283/712] [SPARK-20101][SQL][FOLLOW-UP] use correct config name "spark.sql.columnVector.offheap.enabled" ## What changes were proposed in this pull request? This PR addresses [the spelling miss](https://github.com/apache/spark/pull/17436#discussion_r152189670) of the config name `spark.sql.columnVector.offheap.enabled`. We should use `spark.sql.columnVector.offheap.enabled`. ## How was this patch tested? 
Existing tests Author: Kazuaki Ishizaki Closes #19794 from kiszk/SPARK-20101-follow. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8485ed4c887d4..4eda9f337953e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -141,7 +141,7 @@ object SQLConf { .createWithDefault(true) val COLUMN_VECTOR_OFFHEAP_ENABLED = - buildConf("spark.sql.columnVector.offheap.enable") + buildConf("spark.sql.columnVector.offheap.enabled") .internal() .doc("When true, use OffHeapColumnVector in ColumnarBatch.") .booleanConf From 327d25fe1741f62cd84097e94739f82ecb05383a Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Wed, 22 Nov 2017 21:35:47 +0900 Subject: [PATCH 284/712] [SPARK-22572][SPARK SHELL] spark-shell does not re-initialize on :replay ## What changes were proposed in this pull request? Ticket: [SPARK-22572](https://issues.apache.org/jira/browse/SPARK-22572) ## How was this patch tested? Added a new test case to `org.apache.spark.repl.ReplSuite` Author: Mark Petruska Closes #19791 from mpetruska/SPARK-22572. --- .../org/apache/spark/repl/SparkILoop.scala | 75 +++++++++++-------- .../org/apache/spark/repl/SparkILoop.scala | 74 ++++++++++-------- .../org/apache/spark/repl/ReplSuite.scala | 10 +++ 3 files changed, 96 insertions(+), 63 deletions(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index ea279e4f0ebce..3ce7cc7c85f74 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -35,40 +35,45 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + val initializationCommands: Seq[String] = Seq( + """ + @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { + org.apache.spark.repl.Main.sparkSession + } else { + org.apache.spark.repl.Main.createSparkSession() + } + @transient val sc = { + val _sc = spark.sparkContext + if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { + val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) + if (proxyUrl != null) { + println( + s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") + } else { + println(s"Spark Context Web UI is available at Spark Master Public URL") + } + } else { + _sc.uiWebUrl.foreach { + webUrl => println(s"Spark context Web UI available at ${webUrl}") + } + } + println("Spark context available as 'sc' " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + println("Spark session available as 'spark'.") + _sc + } + """, + "import org.apache.spark.SparkContext._", + "import spark.implicits._", + "import spark.sql", + "import org.apache.spark.sql.functions._" + ) + def initializeSpark() { intp.beQuietDuring { - processLine(""" - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if 
(_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """) - processLine("import org.apache.spark.SparkContext._") - processLine("import spark.implicits._") - processLine("import spark.sql") - processLine("import org.apache.spark.sql.functions._") - replayCommandStack = Nil // remove above commands from session history. + savingReplayStack { // remove the commands from session history. + initializationCommands.foreach(processLine) + } } } @@ -107,6 +112,12 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) initializeSpark() echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") } + + override def replay(): Unit = { + initializeSpark() + super.replay() + } + } object SparkILoop { diff --git a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 900edd63cb90e..ffb2e5f5db7e2 100644 --- a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -32,39 +32,45 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + val initializationCommands: Seq[String] = Seq( + """ + @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { + org.apache.spark.repl.Main.sparkSession + } else { + org.apache.spark.repl.Main.createSparkSession() + } + @transient val sc = { + val _sc = spark.sparkContext + if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { + val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) + if (proxyUrl != null) { + println( + s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") + } else { + println(s"Spark Context Web UI is available at Spark Master Public URL") + } + } else { + _sc.uiWebUrl.foreach { + webUrl => println(s"Spark context Web UI available at ${webUrl}") + } + } + println("Spark context available as 'sc' " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + println("Spark session available as 'spark'.") + _sc + } + """, + "import org.apache.spark.SparkContext._", + "import spark.implicits._", + "import spark.sql", + "import org.apache.spark.sql.functions._" + ) + def initializeSpark() { intp.beQuietDuring { - command(""" - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark 
Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """) - command("import org.apache.spark.SparkContext._") - command("import spark.implicits._") - command("import spark.sql") - command("import org.apache.spark.sql.functions._") + savingReplayStack { // remove the commands from session history. + initializationCommands.foreach(command) + } } } @@ -103,6 +109,12 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) initializeSpark() echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") } + + override def replay(): Unit = { + initializeSpark() + super.replay() + } + } object SparkILoop { diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c7ae1940d0297..905b41cdc1594 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -217,4 +217,14 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test(":replay should work correctly") { + val output = runInterpreter("local", + """ + |sc + |:replay + """.stripMargin) + assertDoesNotContain("error: not found: value sc", output) + } + } From 0605ad761438b202ab077a6af342f48cab2825d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 22 Nov 2017 10:05:46 -0800 Subject: [PATCH 285/712] [SPARK-22543][SQL] fix java 64kb compile error for deeply nested expressions ## What changes were proposed in this pull request? A frequently reported issue of Spark is the Java 64kb compile error. This is because Spark generates a very big method and it's usually caused by 3 reasons: 1. a deep expression tree, e.g. a very complex filter condition 2. many individual expressions, e.g. expressions can have many children, operators can have many expressions. 3. a deep query plan tree (with whole stage codegen) This PR focuses on 1. There are already several patches(#15620 #18972 #18641) trying to fix this issue and some of them are already merged. However this is an endless job as every non-leaf expression has this issue. This PR proposes to fix this issue in `Expression.genCode`, to make sure the code for a single expression won't grow too big. According to maropu 's benchmark, no regression is found with TPCDS (thanks maropu !): https://docs.google.com/spreadsheets/d/1K3_7lX05-ZgxDXi9X_GleNnDjcnJIfoSlSCDZcL4gdg/edit?usp=sharing ## How was this patch tested? existing test Author: Wenchen Fan Author: Wenchen Fan Closes #19767 from cloud-fan/codegen. 
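To make the splitting strategy concrete, here is a minimal, self-contained Scala sketch of the idea (illustrative only — the `Expr`, `genCode`, `freshName`, and `splitThreshold` names below are made up and are not Spark's `CodegenContext`/`ExprCode` API): once the Java source generated for a subtree grows past a threshold, it is moved into its own private method and replaced by a call, so no single generated method approaches the JVM's 64KB-per-method limit.

```scala
// Minimal, self-contained sketch of the splitting idea (NOT Spark's CodegenContext API):
// generate Java source for a nested expression tree and, whenever a subtree's generated
// code exceeds a size threshold, move it into its own private helper method so that no
// single generated method approaches the JVM's 64KB-per-method bytecode limit.
object SplitCodegenSketch {
  sealed trait Expr
  final case class Lit(v: Int) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr

  // Mirrors the 1024-character heuristic used by reduceCodeSize in this patch.
  private val splitThreshold = 1024
  private val helperMethods = scala.collection.mutable.ArrayBuffer.empty[String]
  private var counter = 0
  private def freshName(prefix: String): String = { counter += 1; s"${prefix}_${counter}" }

  /** Returns (javaCode, resultVariableName) for evaluating `e`. */
  def genCode(e: Expr): (String, String) = e match {
    case Lit(v) =>
      val out = freshName("value")
      (s"int $out = $v;", out)
    case Add(l, r) =>
      val (lCode, lVal) = genCode(l)
      val (rCode, rVal) = genCode(r)
      val out = freshName("value")
      val code = s"$lCode\n$rCode\nint $out = $lVal + $rVal;"
      if (code.length <= splitThreshold) {
        (code, out)
      } else {
        // Too large: wrap the whole subtree in a private helper and call it instead,
        // keeping the caller's method body small.
        val func = freshName("evalExpr")
        helperMethods += s"private static int $func() {\n$code\nreturn $out;\n}"
        (s"int $out = $func();", out)
      }
  }

  def main(args: Array[String]): Unit = {
    // A deep left-leaning Add chain, i.e. cause 1 in the description above.
    val deepExpr = (1 to 500).foldLeft[Expr](Lit(0))((acc, i) => Add(acc, Lit(i)))
    val (body, result) = genCode(deepExpr)
    val javaSource =
      s"""public final class GeneratedEvaluator {
         |  public static int eval() {
         |    $body
         |    return $result;
         |  }
         |${helperMethods.mkString("\n")}
         |}""".stripMargin
    println(s"Total source: ${javaSource.length} chars split across ${helperMethods.size + 1} methods; " +
      s"largest single method body: ${(helperMethods.map(_.length) :+ body.length).max} chars")
  }
}
```

Spark's actual fix in `Expression.genCode`/`reduceCodeSize` follows the same shape, but routes the null flag through `ctx.addMutableState`, registers the helper via `ctx.addNewFunction`, and only splits when the expression is evaluated against `ctx.INPUT_ROW` (no whole-stage-codegen `currentVars`).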
--- .../sql/catalyst/expressions/Expression.scala | 40 ++++++++- .../expressions/codegen/CodeGenerator.scala | 33 +------- .../expressions/conditionalExpressions.scala | 60 ++++---------- .../expressions/namedExpressions.scala | 4 +- .../sql/catalyst/expressions/predicates.scala | 82 +------------------ .../expressions/CodeGenerationSuite.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 + 7 files changed, 62 insertions(+), 163 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a3b722a47d688..743782a6453e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,16 +104,48 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val ve = doGenCode(ctx, ExprCode("", isNull, value)) - if (ve.code.nonEmpty) { + val eval = doGenCode(ctx, ExprCode("", isNull, value)) + reduceCodeSize(ctx, eval) + if (eval.code.nonEmpty) { // Add `this` in the comment. - ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim) + eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) } else { - ve + eval } } } + private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { + // TODO: support whole stage codegen too + if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { + val globalIsNull = ctx.freshName("globalIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) + val localIsNull = eval.isNull + eval.isNull = globalIsNull + s"$globalIsNull = $localIsNull;" + } else { + "" + } + + val javaType = ctx.javaType(dataType) + val newValue = ctx.freshName("value") + + val funcName = ctx.freshName(nodeName) + val funcFullName = ctx.addNewFunction(funcName, + s""" + |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + | $setIsNull + | return ${eval.value}; + |} + """.stripMargin) + + eval.value = newValue + eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + } + } + /** * Returns Java source code that can be compiled to evaluate this expression. * The default behavior is to call the eval method of the expression. Concrete expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 78617194e47d5..9df8a8d6f6609 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -930,36 +930,6 @@ class CodegenContext { } } - /** - * Wrap the generated code of expression, which was created from a row object in INPUT_ROW, - * by a function. ev.isNull and ev.value are passed by global variables - * - * @param ev the code to evaluate expressions. - * @param dataType the data type of ev.value. - * @param baseFuncName the split function name base. 
- */ - def createAndAddFunction( - ev: ExprCode, - dataType: DataType, - baseFuncName: String): (String, String, String) = { - val globalIsNull = freshName("isNull") - addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;") - val globalValue = freshName("value") - addMutableState(javaType(dataType), globalValue, - s"$globalValue = ${defaultValue(dataType)};") - val funcName = freshName(baseFuncName) - val funcBody = - s""" - |private void $funcName(InternalRow ${INPUT_ROW}) { - | ${ev.code.trim} - | $globalIsNull = ${ev.isNull}; - | $globalValue = ${ev.value}; - |} - """.stripMargin - val fullFuncName = addNewFunction(funcName, funcBody) - (fullFuncName, globalIsNull, globalValue) - } - /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. @@ -1065,7 +1035,8 @@ class CodegenContext { * elimination will be performed. Subexpression elimination assumes that the code for each * expression will be combined in the `expressions` order. */ - def generateExpressions(expressions: Seq[Expression], + def generateExpressions( + expressions: Seq[Expression], doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { if (doSubexpressionElimination) subexpressionElimination(expressions) expressions.map(e => e.genCode(this)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c41a10c7b0f87..6195be3a258c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -64,52 +64,22 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val trueEval = trueValue.genCode(ctx) val falseEval = falseValue.genCode(ctx) - // place generated code of condition, true value and false value in separate methods if - // their code combined is large - val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length - val generatedCode = if (combinedLength > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - - val (condFuncName, condGlobalIsNull, condGlobalValue) = - ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr") - val (trueFuncName, trueGlobalIsNull, trueGlobalValue) = - ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr") - val (falseFuncName, falseGlobalIsNull, falseGlobalValue) = - ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr") + val code = s""" - $condFuncName(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!$condGlobalIsNull && $condGlobalValue) { - $trueFuncName(${ctx.INPUT_ROW}); - ${ev.isNull} = $trueGlobalIsNull; - ${ev.value} = $trueGlobalValue; - } else { - $falseFuncName(${ctx.INPUT_ROW}); - ${ev.isNull} = $falseGlobalIsNull; - ${ev.value} = $falseGlobalValue; - } - """ - } - else { - s""" - ${condEval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.value}) { - ${trueEval.code} - ${ev.isNull} = ${trueEval.isNull}; - ${ev.value} = ${trueEval.value}; - } else 
{ - ${falseEval.code} - ${ev.isNull} = ${falseEval.isNull}; - ${ev.value} = ${falseEval.value}; - } - """ - } - - ev.copy(code = generatedCode) + |${condEval.code} + |boolean ${ev.isNull} = false; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |if (!${condEval.isNull} && ${condEval.value}) { + | ${trueEval.code} + | ${ev.isNull} = ${trueEval.isNull}; + | ${ev.value} = ${trueEval.value}; + |} else { + | ${falseEval.code} + | ${ev.isNull} = ${falseEval.isNull}; + | ${ev.value} = ${falseEval.value}; + |} + """.stripMargin + ev.copy(code = code) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index e518e73cba549..8df870468c2ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -140,7 +140,9 @@ case class Alias(child: Expression, name: String)( /** Just a simple passthrough for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + throw new IllegalStateException("Alias.doGenCode should not be called.") + } override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c0084af320689..eb7475354b104 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -378,46 +378,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with val eval2 = right.genCode(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. 
- - // place generated code of eval1 and eval2 in separate methods if their code combined is large - val combinedLength = eval1.code.length + eval2.code.length - if (combinedLength > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - - val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = - ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") - val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = - ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") - if (!left.nullable && !right.nullable) { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.value} = false; - if (${eval1GlobalValue}) { - $eval2FuncName(${ctx.INPUT_ROW}); - ${ev.value} = ${eval2GlobalValue}; - } - """ - ev.copy(code = generatedCode, isNull = "false") - } else { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = false; - boolean ${ev.value} = false; - if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) { - } else { - $eval2FuncName(${ctx.INPUT_ROW}); - if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) { - } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { - ${ev.value} = true; - } else { - ${ev.isNull} = true; - } - } - """ - ev.copy(code = generatedCode) - } - } else if (!left.nullable && !right.nullable) { + if (!left.nullable && !right.nullable) { ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = false; @@ -480,46 +441,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P val eval2 = right.genCode(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. - - // place generated code of eval1 and eval2 in separate methods if their code combined is large - val combinedLength = eval1.code.length + eval2.code.length - if (combinedLength > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - - val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) = - ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr") - val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) = - ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr") - if (!left.nullable && !right.nullable) { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.value} = true; - if (!${eval1GlobalValue}) { - $eval2FuncName(${ctx.INPUT_ROW}); - ${ev.value} = ${eval2GlobalValue}; - } - """ - ev.copy(code = generatedCode, isNull = "false") - } else { - val generatedCode = s""" - $eval1FuncName(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = false; - boolean ${ev.value} = true; - if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) { - } else { - $eval2FuncName(${ctx.INPUT_ROW}); - if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) { - } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) { - ${ev.value} = false; - } else { - ${ev.isNull} = true; - } - } - """ - ev.copy(code = generatedCode) - } - } else if (!left.nullable && !right.nullable) { + if (!left.nullable && !right.nullable) { ev.isNull = "false" ev.copy(code = s""" ${eval1.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8f6289f00571c..6e33087b4c6c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -97,7 +97,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual(0) == cases) } - test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") { + test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") { var strExpr: Expression = Literal("abc") for (_ <- 1 to 150) { strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8") @@ -342,7 +342,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { projection(row) } - test("SPARK-21720: split large predications into blocks due to JVM code size limit") { + test("SPARK-22543: split large predicates into blocks due to JVM code size limit") { val length = 600 val input = new GenericInternalRow(length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 19c793e45a57d..dc8aecf185a96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -179,6 +179,8 @@ case class HashAggregateExec( private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") + // The generated function doesn't have input row in the code context. + ctx.INPUT_ROW = null // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) From 1edb3175d8358c2f6bfc84a0d958342bd5337a62 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 22 Nov 2017 15:45:45 -0800 Subject: [PATCH 286/712] [SPARK-21866][ML][PYSPARK] Adding spark image reader ## What changes were proposed in this pull request? Adding spark image reader, an implementation of schema for representing images in spark DataFrames The code is taken from the spark package located here: (https://github.com/Microsoft/spark-images) Please see the JIRA for more information (https://issues.apache.org/jira/browse/SPARK-21866) Please see mailing list for SPIP vote and approval information: (http://apache-spark-developers-list.1001551.n3.nabble.com/VOTE-SPIP-SPARK-21866-Image-support-in-Apache-Spark-td22510.html) # Background and motivation As Apache Spark is being used more and more in the industry, some new use cases are emerging for different data formats beyond the traditional SQL types or the numerical types (vectors and matrices). Deep Learning applications commonly deal with image processing. A number of projects add some Deep Learning capabilities to Spark (see list below), but they struggle to communicate with each other or with MLlib pipelines because there is no standard way to represent an image in Spark DataFrames. We propose to federate efforts for representing images in Spark by defining a representation that caters to the most common needs of users and library developers. This SPIP proposes a specification to represent images in Spark DataFrames and Datasets (based on existing industrial standards), and an interface for loading sources of images. It is not meant to be a full-fledged image processing library, but rather the core description that other libraries and users can rely on. 
Several packages already offer various processing facilities for transforming images or doing more complex operations, and each has various design tradeoffs that make them better as standalone solutions. This project is a joint collaboration between Microsoft and Databricks, which have been testing this design in two open source packages: MMLSpark and Deep Learning Pipelines. The proposed image format is an in-memory, decompressed representation that targets low-level applications. It is significantly more liberal in memory usage than compressed image representations such as JPEG, PNG, etc., but it allows easy communication with popular image processing libraries and has no decoding overhead. ## How was this patch tested? Unit tests in scala ImageSchemaSuite, unit tests in python Author: Ilya Matiach Author: hyukjinkwon Closes #19439 from imatiach-msft/ilmat/spark-images. --- .../images/kittens/29.5.a_b_EGDP022204.jpg | Bin 0 -> 27295 bytes data/mllib/images/kittens/54893.jpg | Bin 0 -> 35914 bytes data/mllib/images/kittens/DP153539.jpg | Bin 0 -> 26354 bytes data/mllib/images/kittens/DP802813.jpg | Bin 0 -> 30432 bytes data/mllib/images/kittens/not-image.txt | 1 + data/mllib/images/license.txt | 13 + data/mllib/images/multi-channel/BGRA.png | Bin 0 -> 683 bytes .../images/multi-channel/chr30.4.184.jpg | Bin 0 -> 59472 bytes data/mllib/images/multi-channel/grayscale.jpg | Bin 0 -> 36728 bytes dev/sparktestsupport/modules.py | 1 + .../apache/spark/ml/image/HadoopUtils.scala | 116 ++++++++ .../apache/spark/ml/image/ImageSchema.scala | 257 ++++++++++++++++++ .../spark/ml/image/ImageSchemaSuite.scala | 108 ++++++++ python/docs/pyspark.ml.rst | 8 + python/pyspark/ml/image.py | 198 ++++++++++++++ python/pyspark/ml/tests.py | 19 ++ 16 files changed, 721 insertions(+) create mode 100644 data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg create mode 100644 data/mllib/images/kittens/54893.jpg create mode 100644 data/mllib/images/kittens/DP153539.jpg create mode 100644 data/mllib/images/kittens/DP802813.jpg create mode 100644 data/mllib/images/kittens/not-image.txt create mode 100644 data/mllib/images/license.txt create mode 100644 data/mllib/images/multi-channel/BGRA.png create mode 100644 data/mllib/images/multi-channel/chr30.4.184.jpg create mode 100644 data/mllib/images/multi-channel/grayscale.jpg create mode 100644 mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala create mode 100644 python/pyspark/ml/image.py diff --git a/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg b/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg new file mode 100644 index 0000000000000000000000000000000000000000..435e7dfd6a459822d7f8feaeec2b1f3c5b2a3f76 GIT binary patch literal 27295 zcmeGCcUV-*vM>&>VaQRUAZY{yCFh)xBoZV?5rzSVFat9~&Zrp3ARwrqf*_JqG9p28 z5{Z(t1Q8?%f}s4?pxeFAzUSQcp8G!U_k4exhFPn+s=B(my1F{8qmiQ-;KT(jT`d3t zfdHxCA8<4aaB2D@odG~!AK(W7fE1+B0Z zpO6qwBv4`_0G|K>8-Wo(rR{^oIrySs8V;UbST+JPfJN8S!3BwOft^QSkuE5hi5CKn zbnrl8acl%w0COJcg2Xv^!0=oMd6=my0;Z3^VbES^4-k)n>HA<2KAteNGfWMG6NDKc z{9smSjJpEN6^HYZ7ZLIE^Aq+&;5?7ngwYrm5hThPjq!BAA<-xi90KEs6>>m13Bl1Q zCnO#RDn$g4ipvR0IS9fW*$7Yo$A4cb8-YK>TEhVc^F!iXVeUvA7KTQ_oE-29JHZ^$ z{xGmUm;(yw2{yyi$I~-F5GE}xITxNRAqrC$))1!BbP4dnAstYHu!{~j1k4D5L86^N zx)BcTf$(%d!hfNtG|*mN2vC-X2Rc9jrXvg!6NAA$VP_>_QX*o~FeFO&oC3^37$zqy zd5o6;8U7?(1oN%0yrc96S-GnwF*@iVe^Ni~tkB5HJ8>AnXNT0W^RDVG{rc 
[... remaining GIT binary patch data for data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg omitted ...]

literal 0
HcmV?d00001

diff --git a/data/mllib/images/kittens/54893.jpg b/data/mllib/images/kittens/54893.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..825630cc40288dd7a64b1dd76681be305c959c5a
GIT binary patch
literal 35914
[... GIT binary patch data for data/mllib/images/kittens/54893.jpg omitted ...]
zfJ%#ODeA)Zb6S|THMK7tyNQAHdz0vD`$Cb5;lBk0pFm?&4q--wJn~ePmzf5C>ND2;ijdnOt8m0; z0_@I^ca`I1wyJwtVS!~t&4XL#KN*MH9Kgh-Th*^F%7@mTCukpvX80Ws#}A3}nCxp@ zuSl>@R(iVT1Xw_X`5KM!>FiOjK+lC_jD}nxCT{5U7L#v{`^{`D4dvFC_b18sh=IL01X5qBN7NIIsbekXRn5uM zZX9oeeG;w~r>Z%kP!4GTDC*DK+Z`{yyfJ=oByn^@nX^LR^4*!=RZ%KV??q53mqo^& zZrZZhR%9PXr$}JoRG@fc!~xqVs?JgbRz2^XQL?R9SgS`*YM68-4CxB1w>?ezp$!eF3K|d% z6qVEX2d44sMpa?w-vy)&bJkw0en;ZjqZ_x~N*=2Udrx>=fX&s$V97!~QV=8RaKL{x}AlO|p zb{1@%udIfkr;?z{Tz(@XwVOB@zBD75ZhSUl8+9>uodnyhec}*PTemCI0g|)%dAvZ$ zem>6A(++}PB*A0honY9+6%7*usnq}%kF4UD(%3a0c5aBXy5xBBxE$g&ILYS>WU_wN zj&nM1sl~_6h0plgtaWl)<7SyI+;v&8~2j_8WK$o*wTfeG5@uLRzY&2nNl1 zscX~S&BuK6e-Zl-02+c!oY%!Z)^|EyTzN;Xr_?xxf(Ny?U_pS$461=`vf+{3xWf`d zGj8tbbUFPurwX8N{b_w0{Bq~%XXh^H&H9nJ@fyF-b$9On>Da1_PH0E8z16xQ;{5ZS zXN$@Og&s}H_L4>V^lcmJd9>rC*vr@0g&vBqA0k98!maUK&zno~1~SQt@z1kxL*{ip z5W=M(CJDBkn>XUsuU>84@Vp_P_R*uDU&Q)?O)#*HwOpWm8491DPo6VV^8SQA(=AGV z)H2;gMgKklC>L`+pVxfqfGN{xq27FqTjCp35`fMu>@%0OrfsQ`{LfXvK$ljH7NZtR zj_l(h{vq(=vr`T4d{sHgOU9Ae56jne1e*|gRH~s=cw19Kv*UK`1orkN(dH{DG0fag zGeKTJ(0MexP&M<`7D9EX-W)L(h_=et6@A|GC(D~v&iQ+Cu2@J!$owwx1n>=;4Dj6V z_HCFD+M$OOd~nUDAr{)f?+CzMf@ItMcQjQcPbd2k$5=#YsOdk_3uS4)SMEn|1>)z% zrlY+D-%X(b;2nt^;IOEzCH#Dn=pWK=>N;YijwRWPbgHa3k4sAQI{N;=u&czmRWlQu z^&G-S2Wb4nohszf<)aHmWNBKyw0js2G9{QSq?SWKaEF<)r3V{6?X@H_gkO}=yo;rE zU|gTtXRsiR+;;J*osM_CT}BJ+sU;N1Cim1414>|Qz|j+Dlh*^(0L$~b{H?K(u%EH_$KQ8 zx&oSk8E$c3$RXNrAv6X}#a(2y6bXP2fj<+Ql!e8z9zPc&an4Sf%c# zpiv@ngAuB~!ZHiY46PQeS0)JzX7S-UlQTI7Tht_Z$VVfx6i^P>Ijib`WN?+*R%P#y z9%3tU_be{*Rc_gMQ^J)HLcuP@WBDd^5^!vvyNQjAh~WI5I}@*D zgLL}uZzxM(DD8Bq07>M_5eL`bWSf%A$DXr2aED}dE+(-#R&*zFB1SG=`SK&R3+<;R zFos+an`nHg<(C^g8dH^W{3LdHD6!08Ii%#M>7uvS7$*qfI>th8UeJYT#S6i8Wc6iWtWB4r@%9ivO zq9)Spq0gE4v@3Eom(4*iLY_-g1TtG!P(sTK3tJS*SVMVr;L|z`6?R49;8%zJ(mths zf9_2D{Pv5!>3&&X7FM&1L{=7yQk_R_IU|-_5+2vC%L-jgEG3x6kfTu|&_RDT;{a(t zXj8P(^v9o>QE;IwW934}-iP>vh27du=pON-k$?D7KPZp+;PO6_XWYvVKM5(Jog;Tc8 zM1|rf?>jHuy!q^(i*Q6jGWRRx4t_Q?>_gVR$?P|Oj^(7{G@2fc`&G?(+}AgYmAJkp zI~vP*oI3Dsbm2jfw$O|`UMMUYoHdJS^a-Xce&1E@kz#lMGD)&k&>4_^6G z@NTXF$^r!ikX7y`qXHyKpLlxx)L1R98<1zg$xp?#^#W%jmKT}^8cZD~1oUsu+6C-W ziNn;vek}hIj}D(yxWka3m2%$18Mn&LGrsraQ-10RBh-P}kQVGCZ=Wr`55Ii1bUfHQ z*pln<`DZB$g-H7gDE^r!>N+_inqKPnIOn%v{KIAO^L5@c zvp#2x*TRB!3g@d37n`6jre0oeL}4g91#@+zUY$9Dp(c@P#`CH11W<)&#%!rmrN6}J z)YCrR@ev3nFyVA$_2{!fFIro6iO$KJ_}>P{gqtL6(qXOD`1N&boo0m7)V#&=h!t5#_nIfA0L`Q7!m9 zUg($`2yhWzPW`|tP^l?NYrrY|hlQ{{tPBVo#y>Ixgw}L3XC)L^#^X7EN#J>e_R|o)L7>S+F4Hk zWs^gns?Fp}H#E~1k=60VE(;>fXg8a^y1vI)JBHMoHHig*uZ7>a6J`wMFjQN#=7#lY zSGv{4@h{uO;WZOkVXFb*upI*MeuRbx<6XsRSj!9_jTvB1w??&AZLG%Q`sSx{&eP2u zH50q3B&2cH&B-k5MpX>@=&KulL~g$fJI5+sOCJb4>fF>Amp8ig9nc!MnqvavAwZ*m z-b5ip@KjEL_R+k|K&`B{*xY-uAiAdZfy(?VP+2mwggJOs$2y6w8(Y64w7rf1y1RYp zRru=wbJuH1{i2Nva-sRSAi&O zKd=1T9&7)8)rJ?EV8UPPQM>&syWIq*2G(hj_+e!=K z`kNC7CKoiK4q-{LFRQJP)!Ik6Nd-v9j#-wpOG~UmCwX zvRR>idl(g^bxPp-baP^Ntc+5<8A*8uB77*@F&kzUAtP1s+#NdZmHo4*8wc5olWKkA z`Ao$DS*DGFnr{)kDIwDqh`aXvP|rz(Z2fBYKETh+F>^rom=x(uq9Ku}&?w*%JS^0u z@6N;>kn{Ez{fuj>^%MN`mNSGIJRGgUF@*e_;5dpaEuGad|KkOn)*5FOiERx&yz05) zQx=yWBGBsbF5K&E`otOQC?k)+UA|X18usVjJG-9}8Kx-qnKAwjDJY6dj~YMNR1-7y z{jw|W?mMj_@@Z3(Ep|DS50HPJEACV>Qd%m&WcFXDA8*d%(|vkzl4KKeoYHr8fha1m zH#bl|0?=M+^~R@r{hVGAq4U%=u$el;9?pr#@_C~&b~P@ow96xUFLf%V9*%-0>UuY7 zfeJuR9*>Nwu2)rszmx5@U5`y8k5}7^Qy0>>CI=wqp@Q5+G>ngNUaVca1(Of*B5YQr{1|nxJVo2Qfsr3gNMXR&*=fKQC?R-#FjgClZ1@)_#-8c&rfG5C}k>vgWx}jTkz8>aP73st^o^ zF3q%R;_`gH{m7(+*Ar~9Ah+Wg`&&>A-atF{SFG|c) z5u-Lg0%K&2SP||J&$|9rXp)Jn^t;JqpOAS@X4&ZdTI#=V)7Q!1q{Lv+i~$v83n)Kf zs(oH52)0Y38m`J#f|@?$rtG^9RPWdY1rb6Kl&@wNx{2w+Cx9m87;@~;EMes1MHqJ< 
zic+`02xCIJmSu}{AQ+(lxBJ=uy}OXt^L1DFS7^^-zT5q4a^iugx8U%Zp&WhLQbrF! zT^aw_qElg&qwVHVPK2_=A%nHl^zdS@vh((-Rb(I2r3WuDOAGlu;?AlLwwZ~+0s{N_ zxRp$!wIkH(f9i##b;AbdN&2mN@}idNAybf$s`uyvAKEbDEKbfA2J7IXxT)RSym7t+ z^rS^z`|sH&<))>1g*%82T#l)X@)LPS0#STx#y){-EFFv#q$e-f@jOHp#l7 zF=FI!b%vGEhzXRW>8b6|Z#GS6C}Fp-H<`gqUGs2c^eLvXl~(ezBU^smb{qb(y=}wl zOlx1z#v-|IYHKD}^qE=;Jm8Q#+DZ>D&(hd?6wc&APIspFR2{rp(=fxBJRl*{1~&lA z**8l(_-PWP!FD}$s*WedCv+QWOZ<}3-&eIJy^fC4XgC}((ztcclT>0P0r!Dj3!$z* ztUP|lHzg>~_uAa{J?rUgN&t$b`BsSdT z!!AlCE*n7=9(nctv@VwGRk=ySYsZ}ZjtRHp^p&#+y8MK%B2!d++B{-L68T^Zij#{` zzsJeZ;8IxHzcTH2`=C$c3H$ZUSUk{MnJq$A7Fc)ma55L;D|dc-3l$q1i}GLHrIfxA zaTQ7h8HTCrdF)n4JACv=s9&Fuy@Z%+oZaGLAq3LL?cO|loM^-J*|`&Z8NP+}u>`H7 zJov8Y&1Gqg;fKbd1B-GFH*EcQM2rGKU{tNVp|YA?Wwp`Scl#?VlgLjwvC2-LW__`w zKy{<=O~b&@oT2ol9`z3kp6Q-7rb_w^`eF+{dH|8OFp+}+r{$kGao?XY0}T(~_9<&l zlcQtD;X9NTcNtHg2REJSx6Ds_ymNboY3AZ4T$0=Ks=0J|#bduoHcr6lpSK=i=oUO! zmWueK{q>rZ4L(9=T=$xLN=_@JW&7pW_YEb-fY5#$~m1aJB47d-uya1GgGbX^#y zuq`gEK#bJ!8^ear|NZBQw;fwxQKah`e3qnm><*Cs+%*d3`EdTPjT69!`;2ver!G5r z4^=ZQE@0t=Iy5?V__V=wKlGQyq2s5S+~VGsg_|7S3TR!glbvUY3M5}|5Y|&;iE;b| zLRc+%sG5sz{ zU@+i!kGP(gALb8WlMpr#p?Z7Arn=DTZ65p29D-%X%Z$&9^K`gknZwNd{F(u|OcoqDmwO6tBE8>G+1Xo^5wcQw1mDoC7_h_lxVv zw>7Z;UyfSyP#(O)etgo)TRZ;ZU(bvDJStB{bJS^RK_I*8JwkAn$MBb3{a@p~_5SU9 zvt`YfprAA;yVl7C>_Z_xF0^19{h&N~E*O`ekf`)?;%sdJi znMxr1>)!~w^_tXwWzdCxiQaPpxK^T+Kg#^bO8SJ}Dbtz|5ur#gqrt_lLN6ANxy&p$ zjF~dLXvR#O4)U5P9o+C^s5j(l3E%ylBoDq`j`!j=V$_Mw55wPndD`=Ko=%D#n^bp4 z4IXFJDOs3%;?`d791iEDI@S3$2QIDZ-rD2xn%^LR14JBQH|HZAYOawKQhQhJH0-Nm zAY?JLD(vlZgiW$MVoFvIuX!B$`6t%1f8~q-aFS};`@sO=LRzZ%R4#NB#?-Q%0IV@? zlDS-hK)^+!rH6yRZvVvOdo@I!2xXIDdahF-sg-(~EFh>8YHI#yIQW5l4(T^T#BZ#Q zx^ByS8y!4032j~9axALn66ythbZ8Pcx_9|DIQ2hQs!+t6cFJ1YW=kvz{`Q4V&(PPzW^{GGgDASL) z@KW>x7qTld;%;TbLj0C*@VXAx)$7);q)d;7SAZ2D2p)R4#ZoBNbme)~qnb||!_oWt z`@4rfg%a*o7RqGb)}|}iTye|#qUhL)Dg@buP;cv#7W|UC$O~c=2p-@dymfQRZFJk9 zCe3skNf3y17(R#2iYv1xK4^j7_kDv94Y=%8{_nPnh zWU1V+<0Q{_uEMnAy2A;;q&Z?x`*fRzZp6Uzi1<0fKiOR?dI`}^@1@XwD0CDc&xOGb zXj9yOJ`u4u^~++!OJLq7bTfKbqV#%xzC)#RtHmdZkGEvP5~5zPo&fD8jEZp`^L#1^ z*@AnS)r>frAatg@a_CE@nZcmKB?76Czhh~Pe;^~kZ1@)W4ONHhr(T?lY$d>O=Dttf zF_+&&QRcyIkFTCH(x(=LBrRmy0$}1}!05l{ANe~=i~gQ;mHVTmwe=klSh>Z|rS*c_ zi<{2?o2$+L_-5<7ttXWSw%vw)Oxu{l9397HZ8E)$EBz(=DGW@eP!omqX0aOml;~OIGrx5d5n8eDdV1&Jp zuZ+~B`Pr!V4c0rLq_^tI9c&0?0rVCDGKIW}Tir@|hgYvO-&q!HHL~n~bGNcFjuBG( zH|w(!4&>Cy|4YAzb#ay@l>i|LL$uf8)1-=iXZ92lN+F?Gnmv z7vHB^LEZR3$z<;g{6X!iS`i$jmCLd ziyS|$EQkk_+agtdg_!0dWtvi>mxEXJEKlLf*G$YP>MD^C><8J7F)#Me04AJ%9Dm)> z?JUhS&H+syPUgKQYn=uZ-WM zZe*Ao7!4{)ctvIzux-F{O831Gt}tiol3 zI+x-x(`4DI)$Ri9{>WM|OP`6Nb5G$HY%{%ue5B`$4sO@fXU-XQDt%uFm@NMma>nwE zmc%3qV{Nv>xa$jw#EiyKbW&6w?sH7L9L}<9gEu~6@WMMkj&KiB5vBZm$V1`wFv?r!nC{UF7N4qkqU7+@^|7xmQwGrM%@vgL{yqH2)Pd1HW^MnGVjZ{X`~|h=vEW3kY2RT91rHhCQ7e z=55+e!l?V9E+#b9gYC|bGd51IzPf+@yIx16E)7pz*hTi7A$9z1$G;XHp1SK+6jftV@$v7k%M}z0CR@Q?OqZ5B5xkNs@k zw?l-r=IuBx&$#|0<54M3I^S_0b6c@}vLOoJ??=owCKy1b_dnFhC|^vsr9Q7J5>|0t zbe&K5pu-l8d0O5xOTJ!(@I#p%-+Pkyy;83ocX83_bw&AXR{Np#B+f2cHa^Z8on^ea4Ip>R}#i?SJV~BA5pMQ-*b3-BiD}$3y z%-Xg;PXw6R@tEnrqZ&F-$?B}Ngz{K^F>O{_<5l8GzyqMsV_5Rxub>Yp5@- zpj+taw&95!g+<>`M@)LqbL8VoVTgr)Tdg}RAkQ@ap{PqKf8yp)EqPfcfcIb`1N1r1^4f*xI2Ep7O$vRJCJD2j{Y6HL13ocpRu^Rb5nm2VcJp9|Ky9MNWl7P{EMaO_l za&ZLZnW5flHP7gK(o^UIFQX%C2ceWNSg?M9&JAwK891LG(u0${i65J|W&GSVeP7@; zln(SRpyc~8{Ni(5iWD#Nes602G55eljqJ51tE@ryv%+Q8{QUIwXL6VGyScf#WD?E6 zFC7>nkrWMgD1#2QZZlExbsL3n?ahHEr4=_GJwrL-#A_kWL(JgIjtl&r*FNi}FWw}~ zr8%qW_8a}tef4Yq4*kNny(h%?XtnK6YmBahv6xG-0iqJMK3DQ}jd5Cc%`~XayrJgD z;zwj`Ipv#Mvgoa%`z5nU?K{FtAZz5DSLm1d3PNn4asCQC;|^sY_l98m&8%3hjTkAQv|6}@^% 
zlend+P@_$;))%t>`esO;_@qkdA0`J8h?_WW8|((ku%?RmgF- z^z4l-&c6EN_M?g@U6AO(`W9W`_BWmd67^aUdTCpvJ!yoceiuKz;`3$@UiLM90S@Xn zI2T0)N%2oHoQZ0lkR1xrc(UxS*YsZ^MOs@%Qw}7T%rBw}19k7Lo8G$Ze7?pJCJWFV zc!w9DK&S*J0zwMZ)O=EEJhDteTK44zX*@;MJ{3?~}V3ZL z@Bt=s;amz{jf=#j@3N@qyi&p3N;!jM+XG*|=uwhP5NCLM;al-~y442_57-%Q*~Ui-Y!XLt1rWTt9KN8`^Wo&QuC&zgLedj=FYx{XMYJRCaL_T80(~$yh>sO)wE(D9=m?? zcE0t`M{1K#6f5HFw(8ZQ@)8xsef_BDFaepl_secFKTNS&tBNieIHgr5$}p=4y6c6t z)g$4LQiadtc;7+*7YQ2iXXv_LxxIRP3x%L-cE4}003ZD27H$Cw-`rkuw5@o#U;D6D z<&8DikkY4oXP#ZqUIGF@Fw3R7&s0YLdC+GFn5sWzvfoj1PfcDTr4jID58Wws>dx?FkEf& z)u+FMX@)NDtY2VkrUkqS2?3;$N)$$fLPfU++$Y?h_+863bKN5cO|!-l%UIDq>f08* zDTe(+e7V?{8f3jS3g+;`#6)YYt5ttdUNU;UGqKrCk5EM&L`B< ziNY$}W7k|h^@2SoAIJ7^X#Iu4Lq~moj{kIj@1renJfls7IT!@%r?}!4$q=DE@(xq9 zT$q<|KfK=A7!K!jkr<}`i59Yt+=T60@h{B`{d29#)Amp2O|~Q|y1Lcn_pd8saF^t> z1JYBB{RR8uwT2GjFnuw5i$=>74Cdci)$+Thw@KN3_!_u@P>II3)mbJRe2ZeNehSct z$~)@UM(V^pms-tEjZ~FO4=y~1T)Rt!uK=C{`x_?FR~GS8Og5c~a@9;c z&!(B0NfmW|i}V;t+~^zwfgl7;?gi7o*OZMiIpHOf?U3bjyhgCj$x?@p;MSBr@wctN z%wEenPBKD|p#+ZQeu1AMbD=`JxKPt`A4hAQKHAeFwU{J1!?eWeugmp_&KJXi{_pra zx5Z{s_D$@d_{h$rB}C2M82!Q@(h^1>JWOQKW?EF2L!W$nn`Q7=i^=AZ!6}<+`yRJ7 z1E*Mk{3I?Tza@;}^zI&v&U0hMCUUU#wuU2|I2Y85V)?OvRb$zqyBFQ-ak%+NG#|Rv z_R6S6O1BzzIqc7uuAPp*3O~f*DP!JVZ{UlRuhUENqMLp zb};oIO7rTG+XsJIWX+~1{R)>lh9z87qW^2SRwpZVfR#%}nemyJbG4Iv{{tTjlkj@k=X|FWMwRbpuBNnb^R^aLD*P!|b?AwNQ1(aXpjDUyu57cDSX z+ArBq(-xO+)#r-&d|As%DAdma?cr-xglvg!()7oau|cB=A!{loD{5cQExh$uhz7&` zS%v*ZFx#n;nK*1x;JL5fSptz6YD{M9tw*p~xb@Sc$mcN}-@;&r#n;AZ9@*02?dP){ zDLDX_HZpOH?{}~^d1zI`v(^7&Q#2`QeV1OGR;6#ggbhp9b@B`o>=AR#_c9bJz+3MY zEshFGr$Vlqsm*TiFx6$VPz-VS#&*ZmgzOTI_BBfM;qD=9w^+l5S1?OQAMEzc;dzaN z6rT>euQpJEDgU$n==ZONGc4h5u@ditFwTtR8*;rrk>GovG(s>L&q=@z<81OJB>` z15+i`BQ(D^O43z?>c5%%VJ@Lg!Nc^Z+;?x$@Qx?7LCBY^ltF3(Kt0G8P^_H5gcDuG2Tca!Fk5>~7XE~nLb$m{F zEq(iy>Bt+ayTjd<`bw4aF+q{6>Z&d%)-^poLTZq338LTeiO(8#UG6!2-jkIHQCcC? 
z4^AUawe_xKWN5OFMK?tZ5sP9S)95_)uX*^MbDBECWm;TUBdd{h(63#XmecJoyL+{1 zocQH8*;K*XS^V=?k`l^aV3kgrO+M@9ne)ai<&ACwGR>mTUJD3LWqZr2{sCVm=6!$p zU}+F~fPj@zOjDjF1)Sd6__^MF`U9v{S%^NqCd(HAkS+3e@h`BMIzLo%*1k1G#7@Vy z1-xhwmsE%1&QO9d{w|>d$(VhKb8xhldhn9P(9**1(coXXQyHD#d=%DHdGNKWG42zD zX{(#IO{o19>9f?rW@AgPq(IWQOZ564QV5y8{Wx1s=@VsbD@1LOuX*)1%M(YQ4jj;W z7cWtV^%xS^I`GT|JD|g!^ThqOu%f{CAw-Gj>16#&G5>0GqDCqV%dQr>unJ6QuK{WM z`3joTs)rkTY5~mJM)kF^sHfvu(I&L z$xt8rCfZe0d~AW%hz^GE$pHIrrP};n{vQ=wbzD!=@sP?rtPTPr8Q0fC)(FDBt`3yU*S3?z!hY@r?(36Bq6qmma+m z3Lx%{wC^w9kGHwdXtG7eF1l|5-9%yP)Bo1^6a0E*pm>^Cz`kQ?_!a)hpE+(QK`Hsmw@xVxtA$Q#t2wTC?&m5d?Y4NE)SfczsfFHl<&mVDFa#&fIA zUvHY594mW>bql^7oCLrin&I*uvl2Zl33ZxFuu#KQOKS_i6I`3J&HW|^hDh6LV&uxv z4vzHfnEB$axUI*z)@dOv{QZ(2{{piwc5GqXr>J0;rGFBvOR)+&f#^x7NI_(6a%q%c z)1rj^yuzdQoh{2KzR#j_GB_QRZU7V4uR4CayMp$?EuZ+6~VvTx%^lYY?^Bs(q>)LBJB7nZiEDfy{Vd>>l3^>auhtTUHMeIL?u+B5i7qSaeK0=pJ*m(6p;Z^MxlOc#uu!Ti z?}JhCSnF)`$gaxTr50DLxPnEXnu@iV~^zo)56gqLR$S)^}dEotKj*a z*ohbAi8jm7X(GUJA=H=tyo5S?VyP}1u3ni?hR-O37)hkQ^3WA_f9q^q>1;T}W0w9% zEG1?@zR?MQ{p8PrUF~#lGu$F-^3U?JLKn;bk%hqxWH$AGr&BE}wsh(LY;cVGf|lTp zYp{ktAVn8=*h7o$oH-rc+?CwYy{G$FasSOwphA6yg>cEgi#Io%=YK0@+j;RkAFaM_ zapOP&+SksB8Ep`_@Oa!9XUq~*>mD||;8w05OB^E^U(6(M~8<8~nWO$vFTK2|bJvAG}kh2nnYyZeTx~eW0N#X~`IQpu;en{cay>=-| z-bZ=J&us!Ov+kqK>iKOF_m3|+%CcMnkk)>B3-e!pi6{iM2 zbvq=)xbh!oS#b2M3-g(ha9KNeuO8-jTFrrXH%*Uv!lN6f7f}z!U7q>`hO5xrWAiR% zG__A^lHr+KST~!^dSqob=RP{}kifaK4!AkYHVTSd|B*QwV+Yz`oIk?$5sGw{#}_X0 zG2V@Tb7!;z?$ic*+i;v6Y+ec_W|^m8bj3ACE4RIllltiEvKVa-q#1#kEL%Jq8uk>w z5@Lyr+^A(kf%I_JRN9KV+le<=H)A8)IiwJ$-CMWRp(?NTTcKkNufN*8z5eMdGC!f| zj$YtRzfXsBe_ew|@D4_yp8YJbg!&T#P;ps36U-GcMBj)qKvp{hE6uQmIL`AMrVNN} z4hNnQ!NVoOjlHqV{rHkB0*LaOT-eUp#4X39KC|oYnVR~;fvB8kytJ#WRvn=|nisAT zP0-sc4^7j8`ls2Y#)iE%0j+Q~;*sA1yNvBegBE2l5t*y6is!+%`FzB&d&*y-<6b=F zi_z1s-@@C?RAp_=1Mv$omFJZ5w?LCO4!m{Uo#O|~d8QW{9*9_)bUI)}ek00XlCvH& z8a&M3ihXx-$Tdb9wXsompw(tX#)~1)Q?-v*TRYynThR+?*1Won$+yKVT?L*SZXr{o zdYjC2z3QbPV935}_K3*rnwcGaFL$PJqE@eZi2A&&n^!nUIjOPkJ0Fp7n(CcM$I^H8 zu~2{2*S<$f>rdKGbNtg?5zp%^Hlv@p0?u28AzYAwjIanvGl!e=`(j8y2oKFzEp6?!V^sS6*8a{0-mXTa(55uW(a|CJ>orgLL~)lmWc1%7d9U(;riaAspF~Yn@xYXWBKxihFpc-FT1+~sX3Cizy`!|lPA~2E=sUE{KR=hj`@~$Zn0~EOTk6=OnaK|$r2MM8{5q^V% z%B(;^1TG2wzh(N32mf)-)R&7^r;kpvWD4*f=IANb3!xQ@4+%Q(<|B4^ds|N{U=0bE z(iyp$OdVLQdNOZ}!d?%kRp_r?X42XD z6xL?fVuvJ=DN*kb^wI%+zHAE|p%Q_MxB6g#ZI8sYrF3hv}<|2gB_93#q}) z+(2^ZW!Zp@O}6wjxfh?q1HVnX8{f|26cr`oAb~%x&@{Mz>FKK$M>WWTE`nXMN{;n% zob|)Ch;0MN5cJCEv5B7p%)<5=MEp;6qFPzQXR%PJqdtWQT-(^D&IokgoaM?$EuLMJ zD9aOiW2XUPjoH+E^z{xSG;LTc0a25agNKennqLG@N5hZJ=bcehNtPuC!T=mmf{V|yP zc7QBVVbZNi@-aV^oncpWsd#h5_|YQ8FM35*+TWUMlt#S7mKn5!goU1=wT8w@lS^};BTMm@$_j~!jUjLy zOYtY3meS&G_3G0Wt37w_1S{<+Htqs;sahz91X|LLZ=TVn%TV#BR+WGmZ2-k<%hy>1 z6g3P*C|_%`hPzk+hM#Wn=LOF_8?OFPuI!DWa%x@JG99zGk_|68m6M|S6o|aUuQL1# z6zn%g%j>4KwO6Zl86xMZ?P=a82ArIc30_WPUk+Z9a8@xHe3dhiiK@3BQK;8H-98IL z$jF<9s~1qj-kc4$8qB%Ox^^84F~_N09&gRqqym7Rp!2(BOr*l-{Pa2nBTf|^-HNvXEeNL~0D{{5@{ zqR44W92?{9^Ip8|6weyF=Z%YhWVsFRDFAC1dx zV=EM9H=qh$^MhlxlrbIcVCj%H%PY86PT>ZZbFxvE#8X!6e4+PgCBt~>6GClTp)|eV z)wW7RyW%Cdr*M978w-OM+h%c|9$&Sh|wxq>=#DjuySd@Wl4Mg1vk>IM5ni$5v4s-&-nvFiZgNtwfADLVJ9gQx{A2#Ic z39R6;qY}(+T1)y_F|B%teBdvRsqJh10XzCzU1sxVzwY_%37lZf*JMEqH)X&Oa}wz4 zmb1s&RSy$}OLR{OjZ!DW0wHsfYh`Lm_8ztIT{=)xRJ*IMPcoT54@ zWVkj;b!E=%7`;OF>%7MB7ik~+@*3RLt`S+!LTvO6@a9YD=BlLf>;XDmx#Jy^Wn=l^ z8y&JsRmQ^2)p3hW0qLyVMem53jDr_>kjqdWkE?(7YA4Vr4c{tZauku*No!%;GelLW4(Mx zJy_Z+VtV(#wNWN^>03=}4>Yj=36xQTu^J+C`l$K9`37iJA+dgQY_1H0q1F3^%gj_! 
zHig@%OQ$v@CsUi$8{#0_g<)8hk}|T{;w}HVI)>}@SUhFJ;-Zrb>(MNo<&-RM99L9x z{fgtvyHu-^{Z%f`>0W@TG9TGZ!8}U9*g#gR6G35jnZpOQ5iOn)^}C+psxL~%9_ft@ z`W9QySv9M^8EBZ!J_QFvVU!W`z!80$vTnKcx(~s|U$!1SFwR$xG;GXRYHrB0d(`eH zxJ{}J4_r(H(}YSyd(5WM1U30CLalQY%fWP;X8jzxh0lZ;(knQM`N*&C$k$4Y2g5Jl z(M$$a$4gDEYE>~A_Q=t|wy!JD=@t@8}z_qxn^4k{-vxN*E=4=N#^HW7reM1Uu$cw~=cE|ij8O?$wQPTWkoyQ$d>v^4> z?jV5#C=KQeT&t=49?x}fI|1^8`%z4FoUz#qs?vK*!?fW0eyNSVPll4Q)|X;><{C}q zgEI0iuHojykK3}Gz_c#Xg`jEAV4nZUJ4>e$@%O5<-N_a*9CBU|T5Bb=7!IW1rS0jp z7Ss{)rQJr)|$h%3wj4LxZ6Bi zV{K;sz4b6BDS3d7oKmdPH)>UXU7LX}x{G(GV0GKmrE2o<04Rb&+uebY0Afb@quHXV zJpXGA0t1<_4E@^9mOg`{&>p%`9ZWV#DCR!V~8F z%06JeAhVz19wHF_w7=q=@T-SUytb&kz1k9WAskJJhV_h`+xQ_N$xIXPpa%FD0v#7< zjcu@=onuqCm`yk&Mdd076h4u1h>;icYM}OGC$R3Il4JQ3v|yiRn%FDaSv!;_Y<0ZV zN_br(<@ZkN{5a~5mW!r~@W^7Rmya-_zQ%Fipt&9MC;|uQxj5{vs!@hReJy?IRW zsluH=JC2HBknLQFS$)F|1;^(L;$*&w#a(yQGzb*vFS~yBYdH;{UuO=Vzd)|pN1pe$ zfH3EZxIK1atTRjL2hF+$#_sA}D|H$K$}-L14v|Y><ocS0z{ z9sl0(<0EaHh{OD~JT~raw=CecaIRc+iR$v8erYEMj%rABSUH?){MlcO6k zKvA)^1&q@caO=C-U)cYZ;9X{*uA7D<49gpbY>lv2tDsDPS^?nE6Z@H*7hDC^Fs@LQ zMr&!n;1f@+Nj%K$|J~W(j1gBo5MiduM1MIW99|Tu{3LgOn!#Xv(@TKtwV3KVuBe?V z+}A=S8`Oh$nL#>cwFUh$>!{F=3WyJ6*LwnrfZR1wrm(s{HyE488Xmh1DwXY_V~kBn zk^Exu0wt2d@=Vh{Nk5U zf{f1%52V#9`?!<5Ydxo@B{rlMU;BbHX>L|***x{9s^g)0+(q>yoz>uZ8AYBr;Gw5C zu5TTRYBpgy3F>)(bhQ_vy^amWBXUp|=lb*{tjF)0x0n>RaZ`I+zq1u9yLpmrS>6nK zMFKOF7Ri6n&R6NrF@C0-_Yl1XVmVwVjnr?~axkFmH+$ddDJh3HP_zaIJS%miP(?Yw zSPqOpU;+8$Pn&Zb0e@V%<0?{SGj(Oi11`Wez>qF$XF}8q1XF@>Kpw9eubu`=AJzag zFc%x0F7ngE!%#`PHdfSlem`~_OSO9BrZGPw+rj!tWh+bAxD&>9vVbHQS~@@IuX6_u?n$n0Q{p*k_1Iku2II^v^P8${ z&dufP2~NBX?5lOJ$-{OKg4*x@{75kJt+MmlZ2!MKMy`xAK8ALF;F5q9!h*&Qu4=Gi5);^%WutD@=Idf`M{R;S6W(*J5?aFu${@hmf%P&@0 z`-H6y8(szxYJcotQkI<Ru>K(T_ z$m5J1s|cpUjzyE+SoyhY3vTJhZhJh3Hz81$-NEueU0MLtr6apLdnwi9$`5pzFKNAS zT)3$}&aWnEI`2jnsh<{3HJmlj%xrZmt{A3Q&cpONA1(!#2Xgh~(qm5GDwcO;OvB0< zZZAl@e(!|}245N-&|NS+91;QB#_2!9@f@pp4nYEKiO(;#cwt_Z-BD{5#`AoIs$I?wf@F3`!nqEGV5o4#E02XSH3Pc z9dXHa^2nW71cZ2060qc4CoVDyYbPLw(;1sv9r09n_LbOkKvlD;IsQu#k*7J5mpqfY z+Cxr>$g;QNDZnkQhrcX}V?McYpAO>y|1S5}sHENa>S|Tz^|{pq?_0F<7G+CYMoc5&QofPv;n>v` zE^E}nW#AvDkqom1_rLL;z6ku`8SU@-Lhl*ogiQV&UsfYx*)m`X%g)i4U^Vdjl~-$F z**bkGD$nkO0Cyg-Mo*5IUkSv>tKY2r{^k5b!Z2(dt9$ACvU-iTib}R`1t`A(VN;1# zBMaQ(j- z)U)YMtD~WsTL*MP=4hAxV<8whw^cCCu>*Ce0JZ=Ck}^6Fo>Ob8lv`>a8P4ZSWH_5cfy*9NmsHP)5NzT@u@T6SHsV-~b=}A-?1{-nCra$% z?3G;Pxk#QvUPhgcvCM+^O1zphtnbPOlV9Z2NvVaF9zcAxQ&c|GDUq5Ft3RP!N!T>)hkPr3Q|lxsITwZY9V(pe534>RBmAPnMvqsIPT0mu`rF0E z2YSM>K<8otE$J>S8yx&W`-8LX((UN4fIXX8xmUGMFQJaaWp0qvpszBfWz=po|M8Ih zwP$c)08}p+CuWK;i!z*wj&m1f*#@MzILcR*^wW4Oe}iKErrujIv+(WNAt2_h8)~vu zH;`)2xc2hx;&OwVLf*8t$e&$051ki8Bd{R;5;926gjr*2+`$P&vNX?6M=DVLg`0y10nZEhnval))r4w64!erH=##!R^j19+>9M z3bu5UCuA&?hg_nvgavlLT(%rz1P&L zLNFD#8-flJ7%kt@?|b?80Jm==q|oqUmzTk> z_h!v3nh$!JR>@~$`g*+m39MUIVy6io=z8bhI2`z?OS~AGawe2nz8{zJJ1-UOR~tD9 zted(xDV_9~!7e{Yl(Hc^ZSqwH;D{=Rb`I}XNBm{(U(Vq(cFR7M66(9eZo}Tfg#+4m zZh`b-f8bm$hrx3$*uk-Bf>Z0kn9EWKw3!o*W4zw-mKW)=E&qv>iGbJ`A_or2lw*#i zP8)cIS1+FJVJsu>TgIA(P+;Kuf6%xDrY|ssRK{AC>n)=Uzf#YraJv>yT4^kCTAB; zH#EhSY{9h)LB|;Y@O+F3f$5E`(-?=>@O#8Und18UTmgxJEnbVybwZ+ZeF|O! 
zfAUq3uO_ooq)#ruQ+7#F$Q)e~k2XnTYgXNZ$gD9RgTm1MISOyFj^oYV4#tOrtf_42 zpVk*HGS2wGi(k+jC`e|foM2otvCEj>Z}q|PrHeNFl$VI*g$^yv-*csn>XPF3C{<6t z^8c|)-;ioZeEIxLCc^y{`##9{&kdYc2WR@Q+=MJ+E_CH#PbcU}NJLrLL0JLw)eXMA zg81!GNj^7k>K1$0Fx^G<2528+4)5T(k{FN>TEF5#zNHGA{5!|0=pTPRO8f*9Ar8b* z`iY4QbyCKz?8*}Ny7}_ZHK9wN6*YiUM;vV}kGi@J%(V9U6|&PQ3LSQ@d482GC3awKEz zHnNDko=7nT^uCW;^J;*#lnXL^EWL1QeWwLHc8eZP)m}mcd@32F70>aTch@c0|3HMU zoHzhpF<3)-!FX0usbon3?VP`VW(qhOv0q@EaV+5e*0%ru)RQ1J2e?FcE2W+q`J=zE z>BSfR03)?@+_tL<{zT7o5ahKeohHqsl-i%ysfAWwc@T-XTUUsY;h{WI zgyT35E7jMMh9_*}Au_}8pRMCIrw?u7%-9C^PQr~ehF8K5>0t{eA!uTYDKHi@H9AUo zkK(7D^ZED5tuN@Z6gt@SA!jqKKCMrXnq4BRRPxUVgv1p*S4O)T2#EG%>XuXZ7+}@( zCNm~ppzl`0CDCaj=~CUCSYL>Hp;9$aUW7k#p}6$ua%Jr8xu@lOQN?{_LohqnNFGK; z*#Xt*g5LtD4FIe@^Eqr*JFS7{fPsT2K;jI4T%7PFX;~pwV*ktNwIrOUy<*kxY$xDe zx|L63ke}~|vcK-kRnwIjhVfhOE%0#2GN}Dv(RzC-=1$X*=}jyE9bQ#Sl8~z!gcW)+ z8UCz9MaWTE8_?m0^SxDid!WR7ii8O)zTIU6aW$ecNJw250Cc4^+jSEcP*T`aq4lG{ zSR5qH5f|SbT|PekwH?TCAbmZbw=o2?);^Nf(mHXAfo}ZBVSZtiWqpj7ZnE_w@CZ&? zzd~j?Z@WFLOZNYegm{@%)PB?fMiXrGE#V#aFYY7HargZ{rzS+$UdPH_k4k6eo9k-aNeLqNv{M`IT9RDCLH3S z`*jo*55GilVe77_qD$P8ELTm)O^WUs-(^82*Ru!Zlz5L$ST3i0NG+*7vq>-KeXS<5 z0Z*-w_N6nxaAvDA7?0wA%HO+{vc)4kp7A1eQ2#)zhZfe&peJsc>3s*x>GBpjP{tmy zrOK%IU&7&9j~T>9%PqwX99h4EfT=QaKfmWp;W57z?VPlYLVxs9Wz+ys!(G+mcgVBtU3YNL8(P#m(#W*Cy{04R)4u7(v z%3w(WqZq4`9CxH>1FnXArKL;w^^Yu0esDU6r}RIT@6nbWqsu#h7YovWDU{%948DFg zqAE#~V$pl|kRLf)i_aV(Pcv{X*9U)Fz9ijux@gC5gZ&gldSS5H7j|2oNAxaEAqgI|L7q;Oml2eoeprD`tAkPi(v;koF=;vSy z0H~@0*Z=?k=JOmO0QGr<^4$KXJS_oa0WVNc{(JtHprWJxCoj>^P|-17VqpB&F|n|* zFfm_YVqm<&ef0_(=UFhYUgP27y#84UXY%w5K!AxN0vJL? zVF0`!KtUxydFltyKI@A9jOag5|4UF_prWC_d{*fd_VWTL{ULEv2 z4?rV8C#2_y>*(s~8yFf{S=-p!**iFTefIY8^#l2bM?^+N$Hc~g)4rxdGBUHW z3yX?NO3QwfS2Q#>L7Q7z+uD2k`UeJwhDS!Hre|j7<`)*>8=G6(JG*=P2Z!eum;bJ= zZxFY4|KUObp#DFw{x`D!4=#ddTrZx_01e|mTqrMmpF1i68ah4iOG0T)3=0n;2EI_t zH!`W;>wB;m`L)i7Ej=e+kpKlYn9l!$_CJyR{{t5Ge}(LS1N(n*!2vj^D9?w7N&t`o zJV>517uA>};ExJ|)kdp>tEYbtY%C{hgzB`mTmC>$F)$KZFl6PXz0oL3&sk!C*1i^x z;0i2P-h|HpcL{ZoeVhC_;d9CF_ahXtAQ8GJz)Ne_XlhZbZXnS-NS$03h@ z(gU>G*@_WeKYs{nWNfg8@V64Vn;U!Go;I=2m#j9+#T~d&?a1oXR12Z1B>9bk2JTu| zJk6aQuBt6;B?@tL><$n!()e|?g=RGkbuwNWtgG1>f~6u4FZ!!0Z+z}?^y2Ye7vIZy z>1Df?P=DoZhE}RJsQ%s6s9FSW(S6=S6!dB`KUH z@#fNmyAWZBAMd|AQ1u(#t0Y>UzUE=c9VIwCnQY2#I@A_}u>0wnmzk7{e@ZHfx&lm9 zI1@gcR&4PKNDEI;9L0 zttgTD)K|HfeN(D14)3{@K9okX_4_k496$beFk4aBL9^82px_s9cb( z7ynz_PW0it!O>_?2W8@qdkBX1-Cd9)**~XYH}TWsYR8Xy%M3nWeo%TOU&%6xC+rn` zVwV$X>EKdaTBCZEM!t7GLDP%(t!**7)GtI2YU_YhtTYQgt6L#%VB0H1*5icc^C_-d z=X-Hvn|&u#r1GCDE|kMDgfP3KXuSXzyrowVr|ndAYCb*nh1Fh_Zh3yXEKn+_>|r;k zwO?vWyT<-YEIL=h{1S!Lwe?nMj~#yx(N#3fkoUU4*ENPxX9l=LICKx==8&o53Rds* z#5zu=dM(|##^|5)?prG@mryhRTy`1F4=yV~=r-D|Kjd4tXdC+MWmvFzmO6$++dFfm zFtLCibP7)ZsGy>ck~JT>eqgkZd$5TMwa}NBV!jlArK>UMD4WS#06 zs>4Hw+>WwE9Y*oEw61oG-bs#uCV%CSRIP&rE!FTNKL0y9{b)yQt&(4c2LvK=G(P#% zx20J7TW^(xX#cCOR5<^GGc0{%cuy`a|Kp`t+zRiz<6IsX5d1tlcC$F`SAwJD67RPuKm&I%uR&GQ zX3DgJd8l5T!OnZ##>*hKL;0PuCxBT~!h4q&+M z0-j{*3Opew-IVY@Khok0>}cCUUaEcUSvtL)ZYZZS?eoEQ0R~OBu+S`0-6tD~d>q&3 z`AaPvED%M-<#GE2D7eht6W2U%?4!Y&?dmIcBH{VnfLSWm5^uFLzbFIEt^`rd=S>ftbEM0miFZ|%cxr6cbIFUldS(=~(FK?9D z*yEe4MiF5Gluc3S>BCm*!x|V9iSG?wefHq~NGb(gMfSqFYHlnDkY#W`65Gjt~U!QiWb{NY;dcRiD4d0GCW#mEA+$*Yo@`RH|1M=M9k4k>Bj^A`&8epO1$8*~~FTxnl&abzLTV?0aIxU(BPD9Uxg) zy2AP)Bykx8iu;i@$>*F&vOiW~+P~I`_F&7gYfx30=Z-n8Uhl zk-YOjqowfaW#7G@pGn*+A{Tf>_XH>9Ev;3Z@Irp3O%wAQ#7a~8Q+-iaFS|WfO^$jz z$M}A#7ZytkbY&lwd;)CGt3@P6OXPkBYETr4Rcb86#vJn9G$=-MQKU9lFSyndp^5^y^U3K; zaUd?QCuRhslHp%T6Zwr7k#zG{rIFvnRHoDhENo18Jn8C8i&C(b@kc^?kfw+Sp`Sr# 
zCvz~ZBnj>TIJDDOqdh8{-2oEtHX-l}Aj|c-ksdy`{|#ME?{;ah0Dl)c>((jl>;1+z zJ0r5ww1S9r=JO_dFXb}?Ac zCR5waWZ<&%kvDGP-HJLe=;27a&GkIv8;)2)KOAG*No3i53||+4BS-to>ZtS>Mr_ zrVE}u-bniQ0dhk)u9Ia&uQv4}b0OsoDmypX@HutB=8V5g0s#HSdjiobs2gURj^e_ zE-qCHyROaAcw1KwgJot*gJrhZqSIQyD41T27i+)FiI|PCu`CixoT_CtU;hg03GN;C zM0wbZT582R*~4(OJE}!RG;eaFmoCk=-gCzWRz+T-$=@>&rU(UJN@a>Lz6)9)*Z5pw zIIYK9{!y_&!y))>(+oM8I%x(Ey`nffTO$ul#YJoJHxD0QO~m0lf?TWl@Z%xGeB{2Q z=w%ZkWL}=iRi6O=LLO-4nVO``DuIbZD(*~qxKcK7U;EmT!015e%*5fw$YtUl*vNa2 zQ!bH%u>fvhmdXupx8lurZijOwTPVIgsRw*GM9qpgrAEO#5W3 zCl8Z(4-=t&DMtz5sfn4@;!PJ})QysP0(_?0to)sDaM8?v{J{46EIblklw1at+kOjM zIhu!>G}F84U@8lF6=-H#r<{3Rk(Hw`AK9C4U_7i2=AS=`sh1h;1yGb~pG}7hC|7Wc zbxk;hNyWTuJNP#OoAIC2^7BDACv==*4?eH1M$stRF}*PuqQ2eB&+{W=Pf{B%N^s&v zR9@}!9}-t zTU#)oYmENdX#wb*|M7)MJ@~Ie`iG%Z&Lu(mbJcP_cduJ3?ml+~A5RYFD4+?wRBySc zG$zCQn&hvqqq5yA-Av+&fJs#TeiZN<`K-Xjc|zaT79|+AiDDb29811qPeC?`;*&nx zI~X^zn?)mWkXs-#TSEdbaOK>- z=a*Q%5d7I0WkSorTqg0=J~bEmb0$NC6_p~pu{lyZ@k*TZ`uw4SV<^e_-y&C4peoah zrZVuFtJ;x;XcTRl0nj3&tlNc!<4j^a@DsCR7e5&r;dXx9l)UJ3_|BVCy0m%9E$^fz ztJGM~u%>VHvJnU-PMyVhG)Wt6Sv9c%1ede>m3^hHCc0yS^3h)frvVz7QAgM(dO+S~LB4EmCGZV=K?@;Y8J7j*d#s z3}S@*t+~-6*%WkcfmDwk3=}o+ITZW<&e$^CIcJIDDr{|SFX)9!&y-D#T%J^)aAZ4d zG$%$=iGn1JaO8JM;n>5_Uye@zB5p28+Wk7kvK1LfHy)I!9nODU>7ePnYCVEpq)@$` z;p)fq`tOP5it=Ex!&ye(-EI-E*+0ED>J+>WM@5s_)FK_-As4Eo?orq<`1=-s^~-Ry zi}I```g3dcGV$4NdFh06jWw7(Dp9!%(n-osqb`Cqgo!Ka=QpO2=Sz*JJRQtAREg+9 zs5-Y092ygj1jE~$M*i)4k?e|SQ;js8SSl-pJEr(L83($C?XC_g3xePN^Giws=6lDQ zVZ>@i35Q;A2%U!AYAN(tiy@-2DxgsKLxls0y{kB4R|Rx^;kpXD1PI%_w|tLY%9&8l z!bYuj5hQt`{%>&lgxhRC)P1r_PEO_^sRH-->&Q`z?3W~Nz-UzBY)a0tv-GkQTtiQP_gk=6>? zBk0h#+9cbUxe`lMo2V&*%(`e663`74$C6NyY_ST-mJdzvJshuf3~DGYObOEb{a5fB zsRV7ECUMmj7sLRJPNBjgK4LmY}?xtt6j4n`&EgnXiPkf0F8=iGtuaB z6nT57do6_3FaC!#lXK?e!p`uAj{^^h-~HxoiP5HH8m%)@1Dkl* z#qIV$)(|yVPj^#LZR{~L&q_Y)q{^G~m4u0!OG1Hlvv9U?@fV@q!|Mixd(+NepD+i4 zKhhG~uTep{@emN@AW5hA%C*<~b8~l4IHvPvhukA9cK{6>OU4>y;1Y3>ICdSg>Y&`j z5rlp1fcRPN=Vc0BKJb^!DS3#Z76h-_fSSf=n-*evn*s>M>{m891{H~2z^a4H)=PxR z?5(V@vQHj07jv@fKT0I2)qH3>n@6Ovv87X*Nllu zKeIP=hTreCWJwx-X8!1TfH#e+4Y1WGu$@So4HP2^os8A%mYSXbt`{m6&6^H{Jo1tj zxNW3_$o{Zce%@aE^W)NHJUbFt4U)b(#Nj=!2h|w6nB2bH{GjkLAU{E!e9N7>P7tUu zzQ%SDkTkpb1i%#6^l~Y3aO*zCu>S*uj8MZ1s7Y!L}c5?iZp6j`q>*e|6La$psfS7uEA|z=!I69*J#UZrtHeRDs!8ZU^nd&ht$Zk@B<; z{s(-nyXpg~juNOnr1M^!urxFmfr(e-6z_`kz*cph+cIEcO|Z#ZT3#dM(nGx@?;#-~ z`w1Z3cn>5^GdEdf>QDvJyOxA~8IwdGkU~&-_nv(kdqn{a@IqX}(GxNwZ-gI18?2hTw4U)_uo;UnkWnsPMH0_G} z=WB+{s(liJ0=%47SJc7v!=b4?Xs=HlA$+6XdKQjHjQm)qd|Q27wb>B=(P>`*;j$f* z**TaezLRcg!`&dHfsB%RdnD?R6ZB1X*V#1fj{cV318;gD9NE=5h zy~6@t(kIB#j=`tsdP)4oTtZgx$o-NZ2{Nl$Qk?7gM9tJ^jJ`$5>q};(5=e`ac$;?~ zX(Gn@?n2`Kbp(-WG}O&ixiKeERywt6Z3=3TL*(Spgn+z1fL45Xj;?xt#(c|>>j$Lfpk|*1K^Y$ITcD`d9JJ z%%eW%94j9G<#PSol_E~=zt?lPCKsLz&_a0$CV}7}ee=4k9fMXTu`Juyx?@Qb*i_`NG>Dz7gjaBe-wX2Bwx*wDtZ90Ds^< z&Vg70G&;@{0c?Xq^idMZ{X_Ee;^x%XSvFkBR)G=H4xATjBV?Agkt&5`o<+NKndx!u zJDGJX`Lngo|GF_&l)arB411eXOypeEM=;Yr#MGB-$O7w!cEW_&3BcdxscgK;GCpV{ zOa+8QZNs`uyyR9HL$lI$3>X!48Ug*te9I9Qu_WqD7cPb2GwvAR*<*>XP%U$Y8%sjr zw76163m;Gj0<-hrD6p}hM=z#KC3urgph%go#8=j{5ICmFV$ouMEU5j3neO*@XHu~S zRN$%?PK^|Zfr?NlClT`;BZl&0OKT_oVa`U_OcXQL_bj`F*RhGy0&!s~;rwo&UaoV( zZImt5u@M*_fpPoyg+WAh~=Atv&xOM+?UqPf z`Yh>NEOUXxY_IW>?6^%fsIRvZH+11=2mT_6MNzqVP|vjBrHKI|O(w!6WeE}M%5{Z# zf@~EpaqzLQG3Veo5a8hpy`L0p;_G{Xo5(RzqdDXS&Dm;+3lvqH-U)b3(&DTpQUGNu_oLp$l-M%SG6U0p44~V7taR7h$m*L z&Q2ydB}@(=xXnSfyo#rj$WIJ-_DfYIodN$@lEn?m(RBRHh3qQ??kn`2AR!jnQ*a!<5ev0s$*J6(~ajJHcv zDl!ajGU<8HLdd?nl(`H&7};9NMmz!5IU0?fR+Yt_Gv;qdh3YGt*8G=mov}UdW0rn) zS^f!+93j4%k2taJJ{$Pj2nu$tf3bwcrnsFBL~Xi|W_&jYn($ty+20%BD!FwJ^7!{7 
zruV}5E*r(hvucgoDm!r#P3Z3xOWJh7fvHxInwHdSL25;*@voJM>L<4k?8k)s?5ku3dh6Q`y6&1I2X#dGHcyGWz1;ZhABLna{ zifo!CbOq1kF_VHfq{6g}>UEHlu+~>R(!*EACslq!WstJk8ujhqbRR5V+UsbE;Icct zuQtnBBz>n6FilN(uimAv^h`CvGl`0zOm}Z zE(-*l6E2NW3C9<&kR@bci%&`mPW7d;ZM;}Yqti1To}V4l_8!ks`dW^#^Jje{RA zHp%pJV7VUBGUGGUwue2Q#C%Dj5=y98Ve-|SQ>S~i`^+F~b?MsV+|x*&D&wv>kvX2? zbfklDI0F9lkhlnR)1I={)uP@8%tY=ybl5;U&m!kaUVr!S@F-1g91$4DWP{txsNs?E zjSzQ&yQWj-2ED?FsFvn#^XeXPVOVF+{%fblRzMJHp9YK&QNEidZ>l-&Ioz1#T)p9* zV6|N7vFl3bEXPq5x0G{rv{`#>y5WXkF&|uu4#9FxYU(05Ta6KD-TL5AD=y7>HP8+n zZC(L8YKn*0zr7E<>S{N*Q9GAq*;em$VH_asCi*-f)eBcW`M&9bTg`dQxltF=WZe0Z z8)Dj)B*nJ-cD9(wWP}cPmxQ-qDjC6w_r>{>{+8N~X6Yp3!fxOe#*NZmMR)GG{i|D>?1nYc^;7{R!3`qX-)vb*8vd0E@ zB|>zlz#HU1D8L7jAslmO2<>c(=fGxMk`_r9j@f8CZ`r)*D}ZBf(DuQw1LdZXrt2U{ z>H}N&kn-fe{@JD=y*E&0FFtfliIRSIxY=N>#|wFipvG82kzNG!riZq1)!{CEATlxW zIc6>^M{JuZ&@Y{8>>LhRTB(fIn(TtWh#-X0!YrITHtT;d=hT`SS}JQ@7?r1(nPEG{Q-;#%HZKEpf?cw8Sy?!-Bb=)| zj7VdLJuTW5H|A6isgyf)0^#mOyi!1A=Q^wY<40a>2|BBNxk?^(dy zjmvnsgY|>^A1Z*D-#%0`6pdTfe)ohx?C{P)^cxigvPdY$VvD*z$F07NIF*9T)iv-e z6XO?)-SE(H*s=9oXIhbdD-XRzzR?tN6 zAfxKp`!TwJ1DVPOQCCE$pT8RV=oB!hEcU}NPmnqC1J7x*!Fc0G0uV5cmuMuvgpRxU z?GvDpGrjGQot@Es^+uvnZRVMXTV2deo>b?HbisJC@yD8kUj8ZMtoI5k`D0W5?fR(wTTyV`wiRiUN%cCsp^bH3l!{to5Kn1 zkEeyesHWxbcVxemFcAXd7T;E0D8*U%9gn&RE;7z?kza$nmU8@_0KTpSeT+r)u<@l_ zNZ^Gsn$1Vuy&VVA4#~WS>}tEVxz7b`4kYANi`zJV$9Vmwq@%6!wQRLhIzm&{-4upa zXvo>FCmJ6Nd)`y}(r)`c0gm+dT6*1{6Nz?*nzBCivN)Ejn|cmQY+7Wf&H7$g^^7L; z^oxLbzF(bH)}^uQ=adz9R$vq{+3_JElK^oA`M;WK5|f*CYv=1b zLY5C{roqh%>+Vt$LJF*mc$!1-rlFeJVQtv$)xRXw`{s$K@%YJUX`cW+7UCig-RtTv zmzr8l-TEnlQ{7zan+mUYmGgd3>r_(LjrshjxMkSgh@xOJ40+Bjt$#B^Z=L_=c6=1D zZ4X_M8&kvY?B?vRWoT5UQ~e>*ldpjzK9-e9V})5D|E!(3?oHGVia`PyiT<@we-PCXC# zJ^k-<=j^{)JV&2vA3X;SL75d{-Zm!SXtr7Hk2^qqCrPMkx5eqEnIdj2W$hYUriRb!CXQPMh=Fz(} z4DGG$9nhuEHVt;o9fp&mCzmwr?3tO=9g~x^D7SmTWJvzWU=}3ZAMc~YI~5OvZC^jd zKi*rg;4-S3(D|rlG8<%Bs2BSSH7u)Faw38Yq&BSOzDEYd!>7gWe zk-EuW6y|fpHKPgCZ*C;(9B(jP8WA7VYfb4E`=(m5Y;5M3wz%MMM_t`NPsCAgJb~LZ z+AA2tG+o6@%SDn}y`T^_#T;i7;M1aPt}y$zK;xV{`d?0^li<97v8vGiv0t#^o3Zjk zq4V`hH;2+NgD9nm>oI+oWOlh+spvrNVnm2EcuT21sqgR5=%nlC4tBfNmV7ZMLn~<9 z0b_eAjk>690X8wI9-dZla0u!ll%tdfF->!D5gu+IUMVw~FM_3b8b z0T)+0MS*@xt59o=3rC?APdt`!(C1TO^cqWhR-{E+dk7&2jyus$VJ(`t?JtT(h?PB1 zbzJ=+4wV~nwy>0WQa zR!n4tBW_NprS2~YB`So=1vQQOwVkwkL+s6?uOGz_4pu@LFALnN{S;cR%}Rn?txRPG zm%KnO;+~)ot1KXAl*~XV%P2mF+of>$sj0X974}#x^US~I)`rK4HgjfU^3LXn0U!%5 zM@oKC<69i1l<1=%Y?XomGqKrh$}vvSHMN;5R?eo(?KJg7i+kN|j|dP>y*SwLAzO8V z6l{?aMP^C+%lV%VqWru=J&Ib0)7aoVvMVR#bB>U@Hwa|q#yR$8P-<&uJM#qiMdJk~ znvTuMg3nEoo$BOqmG%v!v;e0AW zO1}}>3{@=9nby;^^Q3_NCd+!+56>x~uUldO%jBD>`aO1+o0=X~JMA7q7(}+^!b|xfc(9uNRGyw37Oe zXGQ$!u`!0*qVVA2*;Jnw=J0n)94f3k_Y;S0`7SYgmU39Se1Dn&)eg)5AnzP&g{J(G z{qFn+gCQAvT0HdfYRD<#CxB%7zQW~b7uju~ba`NBQ2xS&AB*>|gIi6tKLpzScV%}I z6ApH{&F^V!R;6gV%YG8~-cXfTh0kad#2SSMj^^udsQJM%zVISwe?z}xcZgs9mSAT$ zg}m!xGJX5|fWu`m>79r}^o1F;+*QH#x#&ji>%oamGfN++pjW?a<1r-NG$JskJG$lH zj~8y_z4x<`F5xt}^D4f4sh6Zk;zY*yYlF@iaxR=~>Xj0ts$1i44({saAxa*48y~Fg z^MO%gC55C_D|3FoSGz|a5m!Y{KDD@2K0@>6i>G#2_rR76(WErXFC=5Zhhhi2Nqj!d zGoC&$Htv>?9!KawQ^>;oKuVR?t=EAH7+C`GquE6N%Kk#Aq2@y_4}p+rk;=(gNNfw- zEuxDb0h!mCqko*5t~u^hwkTDG{G8baPa7w(JH2(chA1a01?Z*i;AQI3lX-{n!#EvR zp3OMWd-{>MBlWR2nB@r&5r2IFIoTX^0{w#l5L&wqzQKQi;qNqOijOOlph%I?4whc^jKEE8VCk1Q(( ztHkwZ|KOb-hXe>&HU4*Del5^Lhj!B%lRDG7EEP;+MCahT74hbsPp&w|Sb-ZjcoA&cvY5gStVTB_?xhc&1=gxEkj?nNOb)#>J^W3`L9i}2aH^<{^fi>XKwQ^ zu#VD;3sMt;XwcJH6d4%7$KBjth#6Tq=%bF&H>UN`u$;MU!%M34%?Z7+ZR&qU>DY#$ z8*t|P=+PrvH4OW8-hS3e#FV9&BmB7}7*$Rqc3$w_pi_4NhE8>e+K_KyxNu8e$9V;g6f1h##rVa9x?lR@WXH^^4BsIex6p>!W3cseClGb#TG) 
zlKpZ4U&-;?d0SuEzfYa5BTo`(w*Nr<0Fgv{U(jG1TejjjPsknj*|5gQMf7lUg_C~P z8vLjezK5oWy{eJ|i~OvcjH2*qjSG?7OGA#}zgP2H`gNThZU9#-bAXp-l{uV!=(fu` z<*Ih8{rXqx00*)){s~}j6UtL85>ni42}`E1i;=RGdU>c%dacoCWk9*$5w+IFnp#uF1N}PBXg#zw*+a@6 zyu$KUnowNtKEaaoGrXq$&nN_ZlU(rl3Q1+2?7x=8SYY0Ewi>=}D#^ZLtT}zmb^01I z_s0m^;El{=d~Azsh-Zz;pilzuN4{^gM!|g{XyUZ%1!1D|@l8w0MVs7YCM#n>RVMI4 zbJHH|cTnE`U-6#<6h38H4xX;EdXA&K30eP2aTy5Y$hT*34sRpznFHB-V5wUCB@RM+9dx)mQ%}K!L04r85@Bp`%lV)~} ze1;P36r~|o$8@qpXs1XL6(O}br(lG|;hGORA)1KjV@yybyzSq>I|0jU4s$lW?HSR# zc>@DuJ()708vs6z(?uyUH;uuoS$oB<{`X; z^NF1I-}5vEg`<7+v0)H`^cLK`PoCdnsP1<2HiAfS)#Wb#4O=(Qx_bZE^=UGsu-^(Z zXJryouzbUX%?=R%L;7xfuF;=T`zKU_zpp(c`rkd7o4Ca0fMKTuZiFKB*~PIX!D15+ zFW+&0zUSzhjN2V<@=u1v=CcJ07uOQ9>!Zsu)l_Z)3}o5*>Brj!oAR-FUpO?sCa%2Z zKQ)gw(rbH#q(7f4_qIYG*q)m^$+(_mQ=)H;R)y%UN(Q2XbX>I5UbZHLgGstP4!jZCrH?`}qcalYRTym43+cem*=k9f;LJTXpX3za5fXC>r!aEZx5C!3$8OZPdFkZM$jwF( zy=fO4UGc5Y8SDtVF(l}^I`LgtNygQ)$)C;HoGRMn?Zdrr5bAR{jXmW}S{(Df*-OyN ztHET<7rlO07?VGc%s=WJaP_h*5dmNFxAsM{6?cWZiF-YUz=+@NfL4C<$L$-Zx!@&- ztk*y?GP+QvUDiYXNysT!w6A$zn=>K*Ku)SWZM8W{?DXjv_pBb)v-Iv9Z ztM)2B6Ryl|OWlWY$d-1*sV(`#RO2mZH9Bxu8I%O`xK2jCi9D`iMyPX+^NTu}2~Hs; ze-ji(F}HR#r(##+Em)w{^OcRC!91Fx6&8b^sd-ravSLx7;B-kCy~lN#(1n7PoV7fd zqyGZsqLKr&+}(}_3e2mP;_8~^mYSq)f?u06+#rdPf9b7ZJYxAxn?+H^j{p&-oA108 z1dgRXwHg}=emk*AXuc&cY|-J1SrAhIuuX>F37(7fH#j`9&KpT+k7QFGm9&ICV)MIF z16}G}36;g%hA4HFz-_v(e2^XYj7SQnx@lR?0z`Mi9%duOhG6-!w z3fFaDi)dKst94%(Gp&HRG1BHwE!2sdR6*e0@CsLE>WVWP#IXrVPG8~rl1&Vb}J69VGQd7+7Kn`6XZEg9=;gt1g#C-N13Tvz!{nlY>)(>sm zZiR@r%940$rPJOg0Kt_xaXfn@ABHmZ%hj{>w2SSP{==;uMa$%Hup?eL(W)Qi;DtRY z5w`hX33-JUzJ*&*Q%A>)71-d9i2ZMxyXQbD<|Qh1ZR8&0IK+G|>U@ed-{YKLZMsIB z*X@U&tm5=ld~sdU?$jzw<8cgP&LHwBlIA0pPXp(-^z~d7Gtw_~`h;)DrSQ7wj*R5P z9w!Bt{qsvo2W$)f43s$X?VktL>e+%y2>afsoPwN_7?LnFt`DuFUVn^SSTD%~%6o1V5w(7U0ILG+M3jD!p zN3!(s*_+4qeWTCILG$7Db&6c-@(N?7nK17dTRM6~xD-QT$l}^oX9&&}H;H{0E94!L zee)=@Hch{NU~0b|ni;ksxhUXeLL9plB?dBccrieoSE?Lq=NH_1w$?9D@V@mK_b}vf zaCJ!;9FX#AMfuGiRE7qecMu7^WqOs3c)U)`T39k6qDNm7Ja5v{G%JYrS_M8)1^>BI zvXQMYC8vAuG~~)OgUSH}+3dmVyTAV4uB{A)nPVeW*SQngX~TDHJ6fi=zMX$;(gyM~ zbu9gyt!sE){1uv^F(Y6c@j;{hA?Z-=Ee)}Y6^l0^3$d%!LZ|W>fz?B6Q;=0JYtY7R zNt7ExiRrUk^Gwz;ml>}+X#uZITl1&m5nfRe_^9=1m$;vh6pr8YulDqvyxo%0g(J!< zda7K^;+D2B(u?sdH?B=5-!*9S59d$r5*UNQ1g&$`TTw&(q!6KQ>o^o2Qlbxubb9H` z3B}#^Sq5w=_p%IiGDEpXv=P~~@klWcdheQKG8MTj zm>wEI(F8@*svIEm@kza?>S7vQ4MTL`w%1O?iOTw(tT|OGNM3PSZ_H zauXxy31IAChSPr5Js3i0ub&Z*oBq2VcktUK#ywYci6c>t@ejs_EiCfR_%Z^%lbEW+ zzVE4z7Oe=o>Sr?%8ufd~`!rSnMmXSe<8xN22r4&TT$H1M3yd{XWe3*`vMB0E<}u=K zFebHaFRk6xPd?X0(Z-&K^r&hpQUp$~nJo2G+_MPlB}49RLAy*VA0G>gdhFJn=8tVq#Oct5VCO^dhy_hl#ZG^S+<*MgbrJo$NH*NHX8~Y zslr<8Lsm`J15lE}6TtppwFFwp>C__ivH71+Vu zOskc9^`%`i`*$M70V$_rA;Z&PntCYsaHP}t)6EAF0;7_~oO*#F^r8p>x~9o=zeT!~ z41;UhP^tCmM{0e9u|>I}o4@TOl+tOEWF~BHoXBWN)z#8c!-TL`PVNQ68^Uq=Q2G7t zV`d6C9=0h+TLB*2f)+EvR?){13nb{MOsS3f`X%XuP2d(iX2<*`>=B#T&1>@~*Ib|S z+I?jryGqjPyQ1t(5!c#u`k1C1Q!{R+w9lcYbN1qc`XGD@x7C5$yfhO(@Twu)x4t|kDCH1xddUGEKY#y_lpWVdOgVtxmHJbCb+M7D->$4XW zUqo8pt8RP9l`&cuDr6bF6d97xrEHECck#-8C>7T)7aqVPcgl|7IS___&6wvgGCtJg z1{?R@eGumOy+9~`zeV=70>fa|hjpV35Bo@TFIFIgLYZkNH-pFzrV{u2Fd#jQhy#CZ zMS;n3NegGlbs%@WLV;)D#dGPDeGs}g_baFD#M!BdYGZ`l;nA@hn`usMl65;aMIW`GFkh)rQeMpkQ#x75tFaTC9p!$cdv3V@6)|l8$@ByS-r=1_}XQ> z8XK>X!y=qC7N2)>=veh!$^gN)N;<8t_6R%w`8yrtI^hLusdtoBZ~L>IzyMXs>&&gIexM?2RIfg!?&z-9&%i_tUVp>NGc}A|c5tzKwQ?*00_>HG0!vE^EyK zdf;faG372dG!y6Yn}M`=#{yCfOYlRAlI^;hoU-4i@ma2>F2`1*VY=~3LX9|>pE?77 zM8+IipK})Us-rth9+6gi;qaE!ckvFmQ8m)Uk+P*!1}>p2QLwpibV!SIk)8$RL7+jK z_wj}l7nIy}3-q3wbG-|&?aZDHGn>?HAV6lczMysByxYN?c0U+^00jxhbp5#tP-krT 
zkmcY~m`HfYvqEeplttdkPSw!LUe|ub?QOC*r!KqGy=1j@m3I?U;~n%;;h(l6ydYZq zJ$R|kMUA-`%4;hV`rUrLyCqwgts3oA)D1AUhq(DwkZ49Y!-z=T4ARz91lA`>xZ4vk z!q?cP%N;W=1CvSajpGe%W=e3WHVw9td_4bT>pJ6Tk0p^HT|iZQ@C#4V}pT&VVp5I#T9MKEnK+w}7dd{x>mgsRe` zRO}AWag6dDmSh#sHa-E) zRFdmW9{W}?y6Ri*$>tc^W|+UEn30?9z8_s(;)|gFj^atuuqa)O`zx*X)SF zA#n`_ytCwc*Sr2I0pysfrC<6}-Oa%`!0YB`Iaq}%go`zq#?tL1EqCiukhqVJ#Ofya zC$5OnAgVZdAj|nlo7;|n<=o)F?2;=J9@83Km>Nn#a^`DUZ>mYPGo3+o5+_EdV`DKM zvS%7YXquMf5!&ic6T}<)ben55ik4^vrx8QgFYpONuF&x!#*+bHK~okmRMqvZ6nH) z%_NsqgvtXc+ajkt_JZ+*x^IO%8{wO)ZBoX0()-2<3thz=@wLsh>smoFjH&Y3H#sM03<%uHYvtQ7 zjQWRyyi=-LTUds;x(JY@=r>;9c{*el!AX;QcmzHv0PBEA;Ul^l?bGwg%??=ai!%8*U8Ecw>yR z$QhPQDkEd${{VxUwu|C_+C#whHyRzrm*Gzo_=4Fry*EaJICRVCWrpD`Ls_JS& zqOys%5T-e2Tq<86Kd9PPmErG)I&3y}kZAfHt<=CpZyLvJmojZs!l1_XDu|!}P7y~y z03!}(%)+F(S0XnixQ*>u_xW)WrHWy8jE)CxdRM`J@KOH&jsF1fk@#E0nhm|we{R+^ z+rJOnMkX%}rr^TrHEo)yyXygQ^_Y5G0ot+{Px zU0U3~_A@*wJaD{7NO>9=71;q|7HCidG|?mcW&1LCdi(a1{h+)}u3cYWYI15;O=F{_ z=^Ol+Z}&|z3@l(xaRSKVH8IB=5yXr{4TV}-`UM#2bLmYl;a!)CygeU-?CtH>PSNzc zIbO;nAHJGP*rB!OuglaguAC6xcJ_@hFYuI&*X19aGy4u{&r-@T^_EXxx#^N%? z!m$0?NE;iZX7(P zD^l>k#D5QHv%{&|>RN4tYa0N`G?95xBr`x&nmIQtPO&nbt+h{QCS$6*Hs(ujBMad@ z&)7AS8r&;F#U?Jw6(kY7tGI?a&LkLP$`Ei#I6uCh@J@MSxc!PgAIh=$es-K>lPR-% zG_H4n{ni*GIUz_HIIrjb0Knh5)NP|pvNFOgztNlbW>}fO)$}_E$Xs9?X9JQypwHM2 z!}}-pCiq<*{{T2m;k!G$G$0v@p$XCqs6He3FiN?hSTi_HP zyS`FLJ(+!RMIH-z6UQ3j6%88c+Ixvha>UJk3gaDusn31>+?9+{wxQJH{AT$0oi@_J zzp|{*@%`EF4;+8CKPr_k%1AaZOq>JBA5LpK#mvv~dgp4YmDBB_&Usix_9&ZuQg!*07bb(W=w$KNa%KxgU8f< zc7MS{w6mu8^Wn{`Hy2(V)h5=-vxd=D?ipiyJJtUHT70&e(%v~*Reazr-ODnp##E2= zds>2fs}_#XqTDddCBbO?#rc#tP)^1pY1#{6jyb^}oL}%$e~9*P`$qgC`&Fhikag`| z@uE{5cZ9Gw^X%_bP zGovP+bi!tnDV9 zrdZ9r?VXjnEQLZyiw)z)lYz1`o&yny@jLd5n)~A1(fD@9`(oC{OK%$6M+9PD?SZRV zzM*Y+L`cD1Z)Hhjj%}WENeMfN1Z^no{eN2^UlTna_M!NvrTC9f(Pp@s;?GZlJxf$u zhKHrr_{{Ut$fST9DpN}@$HHNF zG{)36ji-H=2(m1)$t~8%(%d@lHuDjHvl_kQZvps&;^%=aw0k{U!oCr=w|!5@)9x-L zH#&X1x-Hg_ikABs)uD#sDS{W=ix4iaARv3k?0*KY;7=G`=&P++_*+oau5WeUAKu$T zx#7JwPQJ!G8*r8+Gu}2b(!;0oS4dS1IJ_GS2-YD^Oeh;?1 zT}|beR@S^ZBQEgKvjw<#q*!iKEr@)?m^^t@5w?NNDzGT+}vDF@*K)ML&aV^_=|3IeLq~fdAu*;ts2kD*Dr5o zd%F+xTj~72HI`94Z#BCz#c+~EBs0Uce7OvfMwhrKyIA^L;;-#f`!wp2YF7UM4g6E# zo6i|)c0$TKEelba!a3oF<|#^vVF&DY5j1{$S1N#`z>fRkU8l}`ZTnk%aqz$F%VXlL zYgq8@``_@FH&FeO^%eBni}ol^gHWZDEcpk#UzRG|O{qG~zdZ-brt^StN~&@Aqbd;N$V*#J0Llzo=en z)@^I3Jd+}(bqRd9Y+eyA*wMU`e|Qn&iXh*;w@)SHD%Dm^J1(J4nzh;I{uKCeuXyXk z`h|~(*ym_VEW+N*>K9gw?H}&c2NG>mj^TiLUS|m!c{y*tKW9&bns0#``sr31 zNLTH#=r=|Y945>yw3sC45^qbjkbnT%?a{YM=AR0`X6S^zCjQL4w_Rgg(iSQALbg)d zT!^EF(8h}2=*6Q9p*67G%R18m)jdgINf#@Vq zanB>T@{?_wN;y3g{`AmqgfKtXoak?KGJZ)IbvyW{1%9Vg<@N!@MnuO zt#89}T}LJ$wU2eoGGV5%To)0sT<%5rwmRh)E)=#hX?WYlx?hE_^c^|ly1bLb8n5kp@0;)5W{yQbbY0+kRpX=6A*)0L|h{0pQEaIM!V~R@1`zh2-PT`%TISqcSl& zRx60?meNtaLlm>yJdHHYg}+PuSK&QV;)lik2Tr>d5=(0x?XIORS-!<~whSPgD9Ws; z(_Jp$eq^^$Fvv*w@BRuCrbjQ2ykU7~1XXe7|5_O5ShS=i%zwN3zb}NrpRg&OE5cMfNH|Twr5vFa>@bABa~zGWeCOYm<5Q zw=+C5S$^+-*!M^C6x=Y3gJT3Fg=T(saNoDT@J_8EY(Hdg2gogBwOOp%MYoCfMQ&Po zq%ok{r8XBSwUjcluqD3xD$;JM^eIIuN3=@q6ekK>CuzqWNC%In@}?88s)h%GN%!OY zdsKw9Tff6O{{T3u-e2#=0K)X^_?N8qhZW%!4!S>c0Q>2?=6=WZpQ^*u41z<30nM*@nJz~uZ-1ZU!{!+Bsc zX%={JIXRy2b?L_6^O~!0%W@<}W%XKJ^|!`e*!}G!3%M%2@8%(Bm2Z06f-FMJZ@x=vNY1-HSJ7k#?CwNO=PwZT?4( zUe(t|b{;wboCDUfY_#a2nVczEAFkIt{YUxl_-I+jwlxQF?azPn(z z&eqaJ_7~H1`(Kfm+FGpeBuQ@$_7SW`<&|UtMTS+J%kvNHr=r-{e%jh(nxw-`@Xf+% zy8fK8k#};|cM;2FvUS3^iEM859mUGa99Hj=fns<1%R%th{3hDc=(?ttadm4Yw3hm% zl=muhzqQ=kM|bC3U1m0d%-CARGE7nM@^JtCYXvtJ>Eu&yrO1t12g*JaJv2SS}?kBHKkAtRr_QOTx0l 
zy6V^BHKVl9CDKUOQEK)#a$eej4iT0JA-9QURw@x@P*eo3b@+SxPF`q#w-1T_CTd)W~Ih)Fe-%ucHf+~nOU6PAKkyp<%D4>2P;II}p* zx%s>C2wP3@3gXwn?F^bc8nynRXF3*+;&r&axQ5y@a*8E8tVO8{yKX0ylz#EUdxygB z+JnXVwy|fS+()D8b7~g$*EzUW65kY^q-o)b(agmnH#3A=R$b8ChCWPgOp|yY_8GbO zTRPrrnr!#pBh$=lCXH%kj`AD%W`^DyRymGg`{zIrAag92UPvlZ37hmY;qSxm1b@Lf z{{Up&b5i(;Wu)7^_zuErt5v>R*;%dh+X!x@AG>RNK&mb`M5^&$U9yoQ$TV_{mn!Hh z+f$$TANy!{mru|d$KsM+9TL?ewXpGgsS-zNb}g;d8t(cWA%c5}WGd6!Lh*@_1I;fZ z+&&(D+Mf%4EqLcz(0niBF9*%2LuDQLSqn)G#k4HWWQy%)3nIn?l1i@K>$@dnQ}Vxq z{yna#;av*XTJWEZJV_>skyvS;5WFd-U0h9Y01LU}Xguc#!ISNOxR~#AA{P(l_SU}; zJO}ZY;d$1)L2-5B4RiZ4?)&}`?Ki^e2&M9dr|Q>?6;@*rJMUt zP*k|_gm)KEo!X_%^!F1^t_l5LXKZAX%ihuuo3|(|NG!wT8;%dk>`8YfzohTk ziW`Ofp8O8iO}gSu0?DMvzy+aJJAc{vzFva_U=D)nMX9e(@)LK4hh1{JLar6~9<^3M zgS3FW=O(nS*K}*od}lwHtodTZ>__mZDsniiB&;Q{?+Du?c=q~lEquYJ#UGi@=3h2a z0X;AXBzGH|rg125iRlnivPnq*62 z)ImsH&9*gcA7P%Nz8C$id>f?rgX4ais_Xg+>ROfeh4m|0ts=TH$nuN31bFR+@Cn~@ zJTo`SVo2UcD)(Sl+=5N9$T{ToJuBtU+KwF#%j2bnz5f6S1)YzHJOiWL-~FEI(&3`8 zx704<`%V0QTyAiheWs!$bNjIKmLxMOrj?^~6lG?9b9_42E`AOCVfcTjYcgBwQfeL} zn)6pxfAo9(KUIfPS1K58SlZlAD6}9tBz)nW0sHSy_-kt}mEsL9);n9z_(!$EIHvyqt-sA0RNlw3-`{CJ3Vs53x8fWg zAMoC(s%awev@3rUUC(x!c!D%}<)6-HmU&hft~Dl+N7}BVG7ZK_3s=tn0JA^FKNoo7 zPls^po-nu8w0{)cSn5!VZ92x<(*AbwU51Nxmk~*B>5*%D9I553Z{{>>zC|FlFWbB0 zqkK92hxH#0*ruUxr|aMFl4^Q;OgBiDx=)8~F72Xs>-){oMowFXWh9=SCf+Ll0ES=j ze~0`#B!^0~_*>$b^vy%<>>f*$k6-ZKp#;Pz*fR!>_c;tz#nJ^_0g~@RV;v8yz7BrT zx<8Gy3v2V?e+B7&EAc0X^myzweJaW*n@hdaEc_npaT+mw@wMNJnSjVpmSR7f&XBa; z9QeQB{{W9)7-iI-!JZet(L7NekoRe;*uiyhQ&QKjboK?4Y>a-v9kfArY&dp4Vab1p zehz$ce-G-b95>H>rfE73nX}7ub&xQ4fyi+H`R{qeF#1_6M@D7vWfpx0+ zmgZeX(&1y&URdIjXOiM9nA$c2?Kc21Q5r?b+-ja4lj7CpgC?P^Sol@7Ej{9UdF1;f z7FP1K|seokHBC&mnnaR7lg$KWSKXZ;W3Gb1nr#cZJuN@ zjj#K}2X-(@klf^4!4CrXTJgIawn4}li|%i#z=?*$!M)(MobO>R#LorSLYAx zf#Y`Z4d$ERyQI~B;UduD(zN|l+*$0omRYyQ48SU+lPvc3?=c==Xw`zbKm(7|4+vSc zhlMn|c@|}o3piqNG^R-d2*4rS{pQb8+~n4D>eQUywCFm~gkZS~Q;fSGyN<@H+()#G ze;01Q=dD=1kLAeYcXL%{xL8PG{(4nP*VM*42ZPyH#@e6H8Dz1y1D~0Hc|LMF5i-p4k5YKD8f;ShvL4w=)GAHJ0D*Zr(t&1eBzi%y-$51%y`0o8HKHY&Q+p*)`uO!Qj{m0Yan(3ypBO((Ug-9E^g;7qaY4-BU(Cd5kpq?#z& zAdMVsyOe-NMpTAVgT_GsgV>Dqu6XTo{EmTcepPl~zwHLxoAI1Lw0Q4a##^bj!UPS_#?&8qX|08h+4PFN!rU6>8twnqIkm4}_qPUR^fuNhI(= zq(ce4hDf7R8?i~Cw0DtHbc)gJ%q)ZVr&aKEp0nZDw4?T|KTKHG0S23ME}p9rs8Ja5 z83>Fo$`B9$IL0f0@&5qqEAaQ?mCTcCi{lMK;(KXqgN>#z$S^jSTp9Pl+01&8u16*x%dudhSyT%=S>*c)Q{Pgp80p*$}~E=u@iv zyyb;{q1}JMD}D=KL#W4fqF;|2l`xdtZOnz1?Ud4EK#cm4Y(qMzyM3B zf59xi3u{*Q*7q7$jU>0!t_|(vGJHbu42g9)p5oxj`>&2)&bqpncU|`qwsP!FoX)|` z8~*@b=6+@PAK}&i0K&hH9vjncgGsJV4Oitgg#-rM3OoYzxBb1T8|AH+eVmKXCPkhj}OERuZcN%>d;05xAn{{VtW z{1nmF-aB1C$55r5CfW_2Ki4Dj;x`L(<&Bp>?GvKN0pK!%Dn0Fn>O}mx*8UjWcq8Ga zrKZ~7UrT#u;Y~u`*5(UX);qh^@aKt5v|d(Oh&O%@vy~JA#zxzK$_d=~qwsd)!`~bJ zB_@xkU4LflJ~Y)d8(Vqpnj@sFUO4c?K^a$Jc-6c?6E56**6a@nn*Bl3zu=eu00?Ya zHd{X)D?uYnjy@uIg$JR!2c7 zUCrn3AS}|Z3lWDSk;N`qw{>FGNLtr_hxwn2f3$Cg^=pq5-a@*3_IfRxmvY%jb86)! 
zj`v*B^+@f&AQg@)PZz{Qgp$jVgwH;;;4g-LF#Vc572tht*TVXQarlEo(OXiwx|UrQ zJG%0}cX!%HiM5N!c7>l%@lK&M z%eENp#w7DY?hwNadx}j9wSLkgoM$tt8Cu z*eICX#{m_7=l~U5-@1|%CkX4U{zvC${1ZpuCyFnBY3ttxZua{7DQM{16@ z+pXdfE@O#V07Dy}m4cyl3oU+}vokpB^W~0DJX97M9hQk`w$?UwmKOPDSZ$$5V^(5W zh*barS&2MyNj>Ud%^3dj=L7!$uTdDas})L|+YFsnRBhb{p{%Q+zEjRf zTm0g)S|S`W?jJ!@F^=$blSANz8gZSHi`XrFKIH$^Gl4{Tz1i1*z^9vyWb5p8~Q zth!y;Sp9_7ptnzOImb?iouzWhR(1oB@ne?wwQ!+6RnV-m{CBw202%h{_|=$G&27p) z?_k|deLbq5j8a4Kftz>k{{XP;gA9y3zdU5{%*Zk^e`6|j6a`gkQXa}Kg z2Fy$H_QB$%j$w_xzdUs{TsnpdI}YEiIUSi9cm=;&tJqd2k;eVsr}_HSo7zAd_(1%H zS5U)@7CGnGasD*#n6LvZ06prItS(YnSm)*Yy-hkPHw9mqbmF6DXOpK+I`B<87}?tx zz|Vdt6J>^DB4-J2hk3h+LJ{{UK))-sEg5j8aySXO=n$J^CR7E-5m{sq~H)^f{c|VE1Zh0IW z5AmrlCHY6omSK_!>>(w{A0&TlwEKUFe?vGR2vy&*3!r_O+OH$0OjVkbc;u*hx zXOp)W>^TRF_s<10!%Oy$h>SQJ*3Hxb~vC?c7$m%J{Mcm*WhL zy>glbkY#1ysk+oL{J+j~T2LrR*#;F>V`Qoey!UD%WgV+4==}tESXyiqde-nCre_pwv}#sk7||~)uk*?I0mUfm3_sR`@Z0M zRj6Jx9rN4iO-rEcOtN9e&A0G0=;r+rjQjIc1MC}A<(vF!cqH48?zeC;j(Dq#gO{;8LId65!!`La`oC;a+;pM_;2gz!VizY!#ps@&r0pmQEk92 zh0h|L0(pCT=RE-DHCbdgL@J=b#|OPxf%{>;o}AT2Rsr}e-;Td~txF7|B?r*;^yyKT z+SvJV+b6gA{&d+7nI`S6o`Cw*a{+Z`P^XQ)l=Xf-U%lHuN~3QPL|%C9&NI@b5-d;O z+nUizk97$lKSgBb~#J{qfi9QQZ8=<+cc8oF0a;l(j7&FqXD&m-mf=z#qHD zKlsX=vO5w{gm8)03`&=$Zf0Af{-38lP05+skz^5(bGBzTdKeDTXP zJNBG^0fI2Yx0sxB!BY)^?s*5sFNiSsZ}y67X{0#+0EAk`eVI7I-D-^4IKtsmKRXPA zg$L$fw%$Fr)F#^w*x8sO+G?HMr@|12QhTM1@lgRFWQ_p_I#ySwo<83D1 zO|oC<`G~LXjl!&o#GHJ&?s|SbYi=(d=+}VxI;Dd(ylNc%p-4s}cG7df0OuZrR$6#B zPnLbD9FT2^BUge=vBsgez|Y~+H~?{)t3HXT_;56uY}Pu2(&KraF$>De!sUQ${JTKT za#Z#pEps)vTIR6$i=d|4hSs#NBW_kfa+_UvISrnD@Nx9*avvRNvZDV0r^3poKO!;D zIN;+Qk55BcY2(JQx!tLy&4VMTTZy9{Tek!dK_x~=KQ=btl5hzdR$9a|+s7!lwVFGC z959h~xFBtAI)lI-xgDsCh09`V4;^UoTmJ7+hG5@t42I7<73w+@$g284Jo;A~aua@v}{9qlUoM4`Vk~5y1_U5l! z>Y%G*KQfmptU2k^*SAibb5hT!2OC+L268|go}E2t&@5126j|QcsgB|~BVnERZKNK4 zQU}ePbQv`C@kOJ_wPq~kg03QV3{OsW_sBluK9sAfNfqMT{g&oPe4{YB1}`NJV2aQqs&pB22W3K^XXGU+aHJI@<=@@*1BAi zCH!q-5nUHiWpGpxV?SQJotPNG>UbicZ-`b;DM1#jSyv<`op5&YK<8#TBOZkG&l%e{ zQX+iYTbz)&<0sRvKHSu+xNScyXTDD&w2q=OWASaceV=F4T1YazySkGc*zdvz4{ z@r*GIqTf<13CI?5^TvAMXF12NcWh}ZEoUl*`ekr2!+G>H4C&jI#Lvu{a);YK3~L4 zZNqE}M3t11Syr$eoA# zEV3>}d6;L6W0$eXjdbGA`8X|Ckj`A z*N{i8C){e8I(~oFtLp2xl~;4kv>%0<^|iW3cW>d}F+m=6z3l6|k&KTgImSkEaB_Lh z1+nn=#F5;2k*<=;;|#BVAzm0@f2!z4@1e)u=N^F{@YI99`ShdMrJ@N{cQ|c3!u}>l zXzq0V9zYOgXz$xHI0N`hTyvk|?gcjA!rmdDOxre@XqO5@3}5b6An}o~d;b9Suq(e_ zzks1X)SsvIs&u@;RA-+J55!oJf3fYt@-VlV-PrTAdJdWXS)L6$#8KJhbSOm4)lQ+n z`GCgnD3Su__YY%G@0SKZgB!sGDv_chep{{Y9jf4V=NIe*rl z_eb)fMcZ)ZIk@~UstF4<-Ir0db~Y4xe?QKZJTa=fFp}#}^72aUg|~D#&rJ39>0J`_ z{{TMJ#Qra>S1*{kj#A^nn&r%_My;pDxCIyNuHcY!&RlSD_*9GF?OQtzWO|y=+@HdF0xtr}8rjDu?_b(N!jM&_A8B34h z!Rl~$#~Cs~;9Xq?*>$ToUoQu3fkrLQ&7Uy(j>o6Hb?Nv203Y$FNBYL2m!NYta?$uN zSXonC1he7zDG~Wx`ycAN+sx@TCZlATK!_Y6oepVoPzXB*LU70&2W4CbLZltCyTnKd z19Bil(jncWpyUn>Bj-Q}qpVdn#}$U7H3%W8>ye~H4us$kp}f(RKVDoHLU4#s-ss97 zFRlwAI7E;PA#xyOltZ$sn~a7W1R)jr#NKls7jS2a)^sV$Qv>wK?q4* zkPM;Z4k4qG`e|eoI0!;;$UrjKAA&=Ae{*-00XYzIghMWNY^V}hND73Ea>%BR3RXxy zzXu2z;SlmhnmZxH+98mUP7;JzIs`IOj)VIl?hb*BGBR+8l|x*}sIG$V145)7qC!Sd z$5s&;%=WrFq>Bud&KDvJLM)Sd78wWU4MtfblDg{u-VR=i3?UO75;!7lW&%Q-9ReA7 zWZ)2Ihm0XZt#F9BLu| zM=0)ngb*CkK!(w-Ukby&R2xQ<89J1u^qlDqwVhFKv2xM&Qd`4NO^$8NZ#m^%E RqoDu*002ovPDHLkV1nTs5oQ1Y literal 0 HcmV?d00001 diff --git a/data/mllib/images/multi-channel/chr30.4.184.jpg b/data/mllib/images/multi-channel/chr30.4.184.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7068b97deb344269811fd6000dc254d5005c1444 GIT binary patch literal 59472 
zcmeFZ2UrwK(=a;AlB0kinU$#IoP$Ufi4r6WOIpH`vw)(4B0;i}a}JU-7?2VL(Sv65i+aSt!UItzJQFhKBfPjBLuNU>0F3RvP&KD=i2f&Xu@h-CGEz?|3lK(Xi z-M+|UU7W!ak8_La^uJG z(~2}8762~#?}C=%JPTC>Pe1fI&jJv2@E8ICY`_h026zE3fCJzI4{ksRyk77 z@W8mpD9ZeaH-@E?y{p4tQ0Vqf2ulRg9O-HY*1?s%bIx7$0te>tOk7>;<;{`iP9`qq zW*`F4^mH))NsRr2_#AUyqG)Fc126w?>koBxvelHms|hys)41ox|0_n>+0y(k7<3a` z7fln(zu@rU=3qT@4;Mve1x@AKruO!>=L|7_q5ndRWo7T=C1s1S{ADSEAIdBILW4Ci z%*-uJTy0%I5X;Tn$>raWYyU$3l^n;^Qrh0u-sw-qCH%qvru^>$unM4Jk36?EmW#cE zs;i5$`JWvJ+ZJ@B|ANIewRZs>>c8M|z%I7>SMrNdb@{xZ3s;3)WSK%c9f(n76prW8*pktyzae2W)LMZB0B5pJ(brWJb$A>)Vm%~#_ zZqhT5XuLJ$<#Y1yk;Y(*xU)I61$X8rO)GuGBonD=Ci^bkMc_)W*0lMzvd8cHTxYiB zn4>bDlt(xB&+cSad>EM9)wXa8h{>vK8Jyn}l-IFz4~)&OY8_hGCj?MXKxt7gWWqp4 zKV+2s} zUAu}!J`V;`Q!R=#3y5dFcdwaH)O2$Zswgf##VZ{0Ol+LlaVt3kl$TT6Y!5?EMqhlm zrz!WQmS>TJy$AyX?)U8)`|H?C$|~yugP;AD_$}slAB>Dk`zVoTXV8$ZGrlc2RkHcA zVrX*)On>%*_-b>m8LU3neG>XkDY+u;-GuADZH^yL>uEq(_X!1xSrR4lH^mR?(=fGgeRS4{To4t+?y8byRkW#>??09-M2eq({xLP}*5Hvu&l!MqC^|mPg~_KT4KfS>r)6L@Llxx6CLUH>L?whl6p!XBBV`5xi*f!$nbqRu|>xF z`}+IX$kqs*Z%YQ;fWXASVdbN>0|Fc0=OLkq_eWAhsxzo0zFKGJ8!=X)> z>SSi+@fGE9O_SNUcU95KE<1B&)Stt&)nx5Kw_?|0I#n!^Yp3?=i< zfV0mlULDCtT9XJsh{$b`qId7<$U;Tm#@9p1RK2tgX}5|9y|>k68a};-oUf~PLe7B4 zhrT&F1`)9G%ADh+y)%GsZ#i#P4@$+hHxp9i{gB=5tFe9l)h_#ru+ygR{E74{ z$+s8>T4@K~w_GLhfagzWQ9iMEmJNMP?tpp*Jh?Y|sC4~gBc<{%GrDGAXY3y3^Ji84 zHhk!B1n;O!jHNbZ`5g;p6O51O#=-HQPYXQY)JF~+9j z^<+p7!uTm<&tR>~)AE>p;!yRwN$~R6?38j9@YVB=euxsE9^~IoJ_=qz zmNtE7R1?W^DF|}nPPWhrwQ*dxN&1u`(4;@oqN{y!nzcW1yfDvZyJBmzZo%bxa}zHf z_kjyaR`t<|c)YFlmPxS5TjHiAkA1{!J3R-Tz;W{3tdGy-vnZ=d>}jmcyYcWa6!uzY zpDZM6`j#ZvNVG;RwoVm2wGoZ2RW0*7sX7CY9EbTMdq@N+e z8lKxPt61=!R@~DZl~hn}N_RFILMU&ik6e+hqPJ|kKTRm294lEkS@QVswtz^l_i5;$ z_Qdka@^vm}DqWb6O+Wwi=&AMCLH3#%dPs{Q{~3@0j+36H#uAN*QO=qIPUpL=gj`rO zgJ;0JevE8^d9bt5K9}y=Ya&5ARFs1=KwiJ$4A>!E7{q?IT*mf*yd5gd;7(Bm8l&;q-(MF88A8M8Ahu6z?DaIXL;%7eW6tPLK$%m zyNbRiy_;i7KJGrLqjK?WJ0*g4r*>f_-CvGtqjh>a^WvRx+|}VDpFVLQ=Y<9jH!?^M zqmjiE#wRE6Ni&PxypF@@B2RV;f#nXpGvHQP)6$V&;>4lx;bTg#A@z65ZeGz3AJz{U zF)yG46N;;4V+*>kdEIj@mtM4o&Z`KYmeQu4+)d-qacjNUUqZ1Z{x3G&S!aClk zlX}RZh%v1n-;cBC5ai#jrP`;)KpYSBoM;{Pb1~DhZz+JUKB!p7pC&#YY9YaURPpJR zxXMYn;a4aO?Suyt2mY}tEfy*_%pX7O-0MsHIDcEcJ-7v~@}wB+BQZL&POs#GFYU+^ ze(j?U^p%9RfT~lxjEcdUx+cNXLIcY|->CT*ugtQR`S9X*f_4U4?hk)`IZ)@JW_hm2nz0B!|Xt9D`Ewrz*a^}6_vj4K? 
zHCW^2SG)&UqzWZfk+w zDsoC+m+gy@8)IjHzfssjkB&EX!xbUTWH~P()9&08qiW7nHSf?4tIIKvv?o}R`1UTz zJ%u*=FHi2Zm4+t6m)trdE6RneWFLQfdoZe7vu+)uZNs|6-kvpByBf4W>2av~aOc%Q za?`q21Mk3VS~(?IVO;N5^(Z*DD3ms!-#CkhX)Ik92@uJ@OKCC9MhCr3U=x>`^VFtb zMs(^l_Y9~p7c6-Qqv1LmF`RBV0+EHKXbzt<4P9avj^Or1YcuMuAKF*NK|O6 z7$R->Ng^BNVTt)qI}b~IE%o?K^|h4So%L0UUVIAqy3C*2Y2Q$#yAvyY%5ylcqtxD< zO_S9y^t!I!k)g__@(d8%k(AsX@C~-JS#4YisqA~Y(yD|EH=zPOI?RBl5LI+A9N zzmjw}ez-B)-r7p!5oZ6P+yr}yzL#pq<lcu?ozqXO?I<<)T$ zP^(@6uiL|4i;p+D?{p`$5}7LUx%SCj`_@PQ6q3`^f|9LcV~KEpBDm2eS5nq)pU<7fEdA)@%C{*6_;4d&-74?8^=YUL4={avU3Rc+SHX zw>oAp|A?`;B3NR-q7HPNKK%);l7voPM@qEE=+Btt6gG~tuX&c@3)|H6!~%5iH`c3T zpA1G0oIOv^fN_lv85mdA@WxC&QQ%$vDv`36$V7lR84y3xyCItXntIS7gqqLp%NHkL z{44V_Nhe7qWLM_Xz5-F{Et>(J6)b1D$K=Ujt~S}??{lXyFeNg`PV=K)2CiH^`F;lU z!rQ;89+e&^pY`>h0TezQtLcGJZ6sP+pBl#A`^g>P(hl69`}o?c;6`ZyRNvq|ib8I- z{Yj19a$^Ng_P~vr)Ky1ss$M4P_;vfeSl;-cW8WtOMhnPMB_$(1v_c=d(NjDAiaTwJ zufn7(*YMVjsv#TdXFyV0BdZI3Pw|$nrpst(^-6@KHZkvBW%6@uIR5xF@jpXI(`-{UQ)-EkuM{cCcXCQRl1PC4@#AZvF^ zTzRSZ+q9kdI?%8HC+juZP66DVgJw!?x|R*~Mr{kn%M|iyYD8Hrf<;cKqYoz#KHfSv z$Da(!7SJSaIB^NSXcKd~vdS-A|BE=ilBDvQS zR;N`x+WMlHueecFGW9G6x?X>i-e@+Vdbm6y_04LBq|!&b9aZ?`hW`|_xAh+F&;$r1(bE z=J&TY*_Vy3D|dW+6+b-%wR(_nzxf_Ny3B4>r2Cy_JO-mye1S(I2Ww;LM{-5oK3(W^ zS>m(F=mFdvRpz$lCed76c;YF1exEUAHo~-}6$T#mN+nNVEN1z%;mzutthxjqZ>)OW z+1p32O<-$%z);<*f09^f>%^(rbUo1qB{*2As*MuqJjOSVlWW}^wT67?az*9uK?`r_tg`Pt8R*)@59@PuSUvLIE^c{18{DezoJ`xGB4i0Z z=cOAG7#`tHpA`FUs2z%m8)sG-f=>y@pPKgRcUV~^_bhY6R(#6L&VUg=E{}XzcfhACUd{{4Ke<9(Ca>JeIf7?j*`)rVXKnmQ54=|!TB{iw znw9}osy#=PCvl@Y+PtS7_&a_`)dl&Hrp3b`_{~1aG~Ooxjx%e*apEtTVwGFvSs;hrGCVo$?Qy^@Bm(b4c;!Tov2^J zA*ok^>A;eY4o<<0ZglXdxcGWuW6H<^e~NE(;h5p#mt}8BAIH@*ApQ&}Z&CtZBdZQW zcJ_P%EpsK~_x$?(X2exLA0~Hkls4VzE0RRmJ}k2z-fNv-+wqB)oIZI5d-%0sNwcRS z#EPmFTPh`1IP=pORKWiRVbQ$r${7%oQ|Ffat?0|3K=T>E+P8nLZdE@l|40MgG(smS zB2Ms0dHeWIw#Ed%SkQ_5m4Lc&*!yyJniTN)s>?lhO`d@H!F^y}o43i#PV@{wI2KJ* zc1vzb%u!03NqyZrscRT*9BoD%H#~d)3{}47h&rogJKeyS_)-Fn;g^w>oX;G*yT+qh z-D|X8yj4eh2llGJV)PsloaBwxEp6!e?a@k{0os1CdlN^8iaYy~mZy(*nBr+}T>XR^ zB9|QP^cAbiPh)pXx}fb-W$W9|C=1-l^QGnbP4gVB_B%~gD;{USJ6|V5-cx&3#hoho z9ElH-=7)Y6yG`-?$X&nWFNb!gvBxaWVN`0<{5eI9Sq3ij5e0{?pXJg|dC$IF%zIwF znDz5KpRGKf0VD&Y05iZIFa^v37@!7Nfv^jBIA1^xAWjyr1wE*_B(+S)rv~0DuA@A>TNZ}yD&9AkOue}EBkr9m<@*g5ucN_HMcWI zx;Vplz(5_q!^I^i!p$YZ%@5<|7U2~V;R53)SN-FDNsIE>Kk+vl{FLseRG^MenSa7y zxZ(vkDG#=b%EiOO^#b}!3zVoAsK4QyaL5lD2nrMkNBPaOJ^~etAf0yufYKa+_8Z?B zf$@_+Xu}B1pD;)mf%Ovx!$J|*zv=rU0xtMJ0Qroi{E~;8iz@~4lOAj#1?86|Awwz9 z->`c64|>RX*QNiE6Z8Q9$oN(F97g#=h79N*JTp*#$pAITK>LYT0_oBJME^?{kZH!B zbjkRo3+M?M*ngtO{R0M@K2PU1IzMF2uNS?>#V^7w4C4|8`%Z+1|3^>$3l5}31&?!8 z&Q&|Ff9`4L*TkP!@B#q89qRuqaM3n+0zC*EcfV{3j>JDk4z&D_kpr#&^T+|?A3t@2 zHlzIRS4xJTOn=0Se~tk#ZZZl;1DNRO;1oS31_lNe7A7|CWjtIQ99(iDVuH)m6tpzd z6jW3&1`cKz9UDCr)iuHEY+T%Ye0;Rb!s0?aVjR4DJm*XxSYQ|h2bT;FkBsLk)m5Iq zoqohZJb)A^yD}73A*el)%0)}@mD^0Ef#Y-GHI7` z%*1bIaFo6()&QAN$7*-=s$pII5=B4lZMZo1 zi4D1!>-6=H55+5|9lwppKV0f}&nB-d`XGYM(HUn zY(vXI@ce|ReYHnNMQW;0qbc4V>?x{fg6g_h8RQ;&9Xr@K<=a?mP)k0*GTdERsM`MIoHa6#G%>K`$8z){+hD*!b6LlOo~pCiPh{ov z$|tc#_7|VnA}o4uBTfCDi4690t`W9|rmF}c_-d(7Gd}VQD{;+Gd1)ayra}T1jJ!At ztxuUeUZ_SBC{=eU;?+;CO65bVKgKwR>(iF2dtVST*7V#2H(yAltGlF?6Kf52-X}jiC_s*{)RO0q3 z7pugIbzYbD@xMiw@bWz1$$ai<={QMygY_nkF%JW_g2yPn!ZP^qv@#f#$&4mglW)UwOA@3O%tAq=Pd0_rzra^_2icZS7E zR$HgaDe|eFGIjCXc+-|m+E>e@rgeqw`AsG$l;mjch)zk7XVwSGON<3|)H~CXxE{Xj5}%rS=g{ck z?X=_0J8SE*&PTl;KYdYMPYJH0S%5Nct>{kACPayneQsjQI!y8!jE#FNs3Vn_Q8r?( zy*5wE_R;Y(P1A<<_%jp!?Sjzw6zzlP_N!q?(IG>(+_%H?>7E~3VlpQUY1JpLRvw8J zshCxXQYIZccpeT(I&O|Hk21C1WpjLI&|2*`A)VY?eKjPCmTh2Y+KXESf79(a$Wo!P 
zOwqJfCc=z?pe#c3Malyn-Y!|QPabv(M|O!{Q^kA_nbf}%vF6J0em8Dg%nqrHl##UU z%l7FRmyi4sDw(KE8(Gcuirnvu$zU?u^w*xj{Jyw-GXtW82MRCrZ9M^|m7A%Su4t8y9zeR)oFuCGE2fkvJ*Fu#=Vi9C$cHDdwRrgu7~g^wvxb(< zWv=hSs%u4}N5@~dRgYd0Pp6iYXbN!{*$>j{KI4}RU+nLd1Y20Z0|a;;Se+8p&t3*ux$CJsj_b)~2o9@8#pGH+@>5_!hZn?kmD=`yWxR8TL; ziuBZr*|hPT?$|dYVlThZtbg7kRPj1I{!*|q5xAZ6DW=Myu`jLBk1%L4J1qIQKFD+~ z;=5T77cW^qXqe{9!_4;b_Z=_cCap1%pJ<*ike@^&@W_LlvZa&%4+-W4FD)D7&jPl#3C|6Pp?j@|xmWXje zsp=b-d3?HVTh{jqd*{MKM}(W5uu`b5MEb5)-C5^89K5xxBM`ghR7Uh=c_Cp2H5eYZ zI~}R4Foo=`6_0=COCP3LqRGQNLO<1Tat6>4%!@h?J_{I^ujc>uJ?It*BdNEhc{gf1 z)g?ho@lvmTwEkQTba`ZSgr2o^IaL}@O5#Uvvpyrc^bLRS$!!_^hJfQ{g`OcQ=<%{3 zx}fO8g}w$2ZM{f$CwiWe1AsMnrTr#OO%dAp}ok)_EpQ1V)1E_uP_`y;tkfgaCdXJnYEDtWcx5o`1PJX`8k{v)rew+d4-sw^^t1&wao zc{}u$<(LY+5n6j_c;yNFTg3nsUQfvzl1vQR0$EvIgRVU5;NoGw=R>E=pGj|OzU_eS zJE%(L6sb2kxXaSUwtRMc(;~l1@9P$!e9>G82)X2sXp$T3WrYuyw=bD_e= z9(klyI(J?_K`lCbOas65fJdg4F#|6tp_sN4PmFk_r*I(v?^< zF$(?mn7&KmP04fG7NG<&Cb*_*%}8uPss|#cDD}?-Rd}-YEp6>Y6=}4lcU=3|%|jX$ zIdpZogH7mBPn=hl49uHtt7uGY@2AX=&%*h4$Sl3oH{fjrh0M$9FHsq50~QAa5_85! zIx0t`w`pq?p*yp_B*SYg_Y2+{@qAvksIHG+J$$mw+FE}#jQvSBYH+fV>R@GReS?Ct zozs;3wE}5QYGM9`flMsyEh=)}I@AVi_ZrfNN64P3jxDw!NWQ zQNxkhi^NMomp6RG;KGb|NU@#Jzsz)$*zWYM?d@h|CLdS3-6b7GES#EB(^hH>qo6jyhi`t2?hL!UkIdZ%`adCp`9R4GQKt zSrxxke0-f?E0RH{r>%Fj^Hz4kJyPM*oApN@;u`Ap*uUo1p}143!C#OChk*}+8+P?F zX`x=oVHeMV0^I#vGQSxzO<1rup-~vFjbpY!1`%KxlHMNdA}+;%6c>%FH}vdtv_WR{ zWMq@RdS9SJ!M^G=D$4dInzNtEO(OkR)-}%f8GYTyaTb~HAEOuptM_s~sB6KysaLhI zgW~7BX%N%G5dy^B48! z!q8-sC4D1WHD(i*Tm#}z9&gjCTytW~`|{eF-B7vIo2hJ?JIL!U!&(Ke;GKR#ERSbX zX6PjihA7=`FI8p>7WmYf#S&dB3)x&DpK+wVOR)JtIZfIT=Ff%L z2%N`mttpreOC$0;2R~?;wVACv8Vc*JilSxLNqkS$8|l1X>u*p?jE+M#?iM@?AFQaC znEW1rYj!_a_>f}o$w|5R?ff$U-{?s(B`dGqT9~+${gRRKv?0FXO-jpQw$uT$r%~fEq1Br=^ z6}676?%pYI97vcmFKMJnrmN$&*!&b+8yws7Ok=X--e|j;D#jc$*-UTTq0H9t=5u$l zt1nO+$#i2Ty%sxa-fBnB7v_?%RJ$L+b?4jU;Z`gED|@s3NQE z@O4x@r_k1=a$NUC0^aaZ1vb4N{cFGQhbAu^GLl{$VlPC0{5+rTzi#KIXjeS-FtIYE zImhucis*brghFT`Zlv6sJhXQX=}_`J6Up&ZuKDrWrF4S&i4zZz^rsyW{cW^n+4iM= zSjUdZEw!)S;FEZVilg0Mr;}`m^yxN@+$S`;W;o0i_~ak}4YArRHtF_323g6Wsvi=5 zfAZ0T6n|FL_w^qBfHyjD2E0+W4?Jb#9D7%_Ik(6=jjThYj_UKbzu?O|y(`n@7M`3v zPO>maQ9sg&aH!n+V85`Y&7tT7h;z3Z4CU%Yk7aZjuq)GFb{Ts8vWRcg9pF>a66cHq-dE zQoW4QC~RJ)%lM*pv!~)uQ zubCSu1=~qZ>0KGzDC zo!mE9+h1&~Ui*ZIkCHvjtwYW`a#4Z(0DIzE(6S;tE^bM7L>%S0o#C!<7{&6X(Io=Y zr+VHWuG0_zxMs<|9`0^celVe(p0E* ze)6U8MJzg7tX6uI0yHV48oqoJQ+=+Ryr;(E4Lb&1$Q&6}_?uS#X6r^mV#L-mX<;lk zs}i|b2{c0dB)kg|kjZ9w(pT)RDf>j+%Xc4#9lZ2l^Z`vX$yW1bm~H?M!8d5)G->Cb2Q_69UlygplO!Oh1@=V_*os&Z(?g?@X+f#!5Owl?>^P)hYZ2o#?v$nD@yO@{HB;oN8Ftb=ucV-Uw7}9?D7;J3wvET zWmw!?9#O7-Nj%pPeK>gr95zvGC(@B9*=Q0z&CtypzdRF-#OOnPnB3LC@@c!ei9Q}Z zmd!Y~_&!O!etxZhIxeff4?XR)AB$N}!CQ^nB*J9$JH#eQ~u!0T!@_N`9|S!G$uxhv=%nw~xKu}|Je z7z23=LSw#2-WH5X_wvH)E=QJzZ^}#gf}6cIG3G*9vx)V%GI9z#t@@C@196ynl+5Id zTp|Nebh6h{9t9Jx8Q^Ha3)kRLTLSs)c?O8B8-?_FPM^Y}t1PMDDZp%4BP-9IvUZ1* z2hBH5&W@3rZJBM8Xf3?SMxOHx1rjvagi#fBHS+_x8Y;KMaqY`|0tSh`8NGUGuVrSK zUV@8*FM%9?6ts5C?R_Nytz3}U;FiB4H|YZV1mH%cHJ!>5Uat5)6YHqe;8}iG;oRbW zr&rMqChxATW;h~lXL(|a^;;c!j0-vK4aFj|$?X2MXfO3;(Icqt8yc;=t#)4GnXF6=$ z8*QFEQ?!q?XT%~W7Z0(=<-7io-JLsze&pt@ulTJWhEG4@ssN4#V%jYGn_tnzUNd4% z5;^vLJI-6lSbjC7ueom4xtzp4IAcIuOCP**GTi-*FUunxwVOo_!`(qL(X*mxS&_cS z|9;R(c&rC(fZ++O{@KE_QWB??Zdxu>$B7~ljix6RuntQ7@%DLccdnZ#ObSz~$iNzj z=z&13XbCOinDVRLp0`?_G(B-$9L}K7RBt9b?){qEegC*dSYj7k&V=ck=hV9_@VbS7 zwvJ6+)j`}k#%Odto^{1#Cc~Be?WDY2){Jhg{oxdaoMfc-V_%G3xLqk_&8xsJvl`#T z#wC(ITt=Z4I2Js!=dMCMZE9B9bM!$5Hb_>N0_KPklv+0^{q_(82ND z8792fIoR)`nS&i}6g`(?l@8Ui6zFokySbCZoz&%0f2qSX%-$j5&6G@nrImPCk*buA 
z3oRg>w(MY(mB|nxrwA(&M9F{`t2lqxa;>f@FBt|HV3yWGgs*+HpU zUJ@!=qwLYi7W_gT_G~h4;g(fOlq@~hdmdz~MD+XCsX}!N3z69ZB_tY*kmjpHOL-y( z;_K`8gCRn?`QaZX?O2o!7V2x`_9oZ1A9X2qOVy#$yjsyG^zc(C`8LumTY(_2v5<nG;Nu!s;yMJ~lBC?b#BGuPnbN>66Kp=XbKVac)PG$PCD zXzx|h?(L|a_GNGu#^9ZR(^qqi!v_)*%QU#Ly1=L|7PuD>3>BNeogI|#$Z4KO-4tb|Va_1Y&j{Sl z%iiEc{1ND5S5Sk&{#y1w<_O^qPA*`uOAVytH8XdHgYaVzwsm)LILB{-@MTk*a~S13 z8hF_W6cB{T&SA?RFxv&q5BSD8Y-Wcv18L4>b}+LuJBOP=_^F%g`PMt=Z4eG{LzugR za6bq$*t*&wK=^#O=4Cr`6L7%+>iJ$R1{ZU<)p^hxgmIlTHKajU3|uIHYxyVG^iQx0 zxCaiT1*GjAJTG>I!LGrt!?=ZogR;HE%VCz!K?3EUjU{;y8_A6NX7Sbvg(O~c&6 z+{qjXHl+=E83Jhub~n-t;exP7!Vt*+u804_YJZa99Q@(eAV8dP3S2hk04{Z32B2%5 z04gyK05wPgOCZ1d?KYM+xBv$LbZ8fT_&o@NPV0ptQjKq*iK)B|sU7N7(83=9CHz$7pStN@$99&iN4moXrC5Ml^9 zgciaGVTJHOgdh?SS;%dO21FNf4`K;HLR=yDA&(%>ArX)`NGc>3QUa-gyoIzwdLUmQ zQ;;RdH^?Ch6a^RMG72>c6ACAa5Q-Fv5{f2@A&Mo6BZ?QwW0X*omndl{1t^s$Z&5l> z22dtZR#0}qr3*MvQs`AEJ5&fN4ONBeK`o$8P+w>eGzOXmErQlUTcLf>N$4u{02LjT z7!`)fjw*_(fU1QGM|DK?MGZ!cL(M_0L~TaxL!Cn1L_I~rL!&}tMH4|&M7xV-iRO+L zfEI(6g;s&~0c`+n4s8z|9i0@N8C?)v0bLi}8r>T`7(Efa82t@;5Bd!HF1W&h9D@}@ z977eu1j7X*0OJ)#0Y(!>FUB0kcT5~iT1*~HIZQoFB&I)R3}zl?BW5q=0_G9cB`iiP z5iC_KGb~T6Fsux$TC8rYIjkdWLTqMiacoU&YwU;EFR_cTKVXkxZ{y(L(BTNxvtSn~nPh_Y3Ye9v&VO-VHoGJQuuByj;BZc;k5A z@rm&{@D=gR@cr=P@hkCr@z)8k2p9>Z2#g542x17z2)YSYFJWC`x+HzcL97ZlnZbKeUUPbNHUWe{a4

    oD6$_OHl{Zy7RVURhH4U{qwFC7l>i5*EG^8{)Xsl_XXc}pjXo+bh zXc4qgv`w_jFcO#)%og?%)(qPOmy;-5b-kK;wd?8;9W$L4oj+YM-B)^SdLeoX`Y8H$ z^qUN{49X1m8S)uM7_k_I7_As%8Cx0mnV6Y$m;#xqm=>AIm=&45nDd#(uHjviyykc< z?OOkJ^y@;`ZLTL?|IC8I!p~yG63^1jioz6J|rQrLhgM&IKoyT!-KXU> zS<25fvI()4YX+fKJ@l~I&$ zDMu=gsW7S_RjO10RYlcE)vs#IYL04k>geif>hbFH8r&N9HCpcw-!ZJt zsMa+tSFLy21lk7Lx!U_WGCC1DlXp4q`rPf%rO-v_R_kHtY3pU_?dr?wN9)fS@EZge z3>q>UdKk7FQ5e}8H5d~Zn;4hfL%XMQFYn&5iKMydq{Il_M*oU{OI) zJJH6`?J;~Y$+6h6uCbFZZ@;X5#qjFItK&GUxPkba@x=*L3C|Mt6U`EPlWrsxB~vFq zPd-dRqztFZr&gvhrNyM9r@N)kXJ}`%W(sBIW>I84&pOFQW>4m5)>yeNg())*{#PzE!HVsZG4CzFnlf=A+QZsty5g?=ydA#V7tx6dmED5f z)t`kw*Y=3@H1Un~zzbUC3SJS*%}@Tl&1L zzdX0%uyV8-utvC+uzqd5eB;JO=jPqbxhsZ zX3UZjc-Uf+7rfX(KI|3TZNe8BX4H>zgEWw;O-Jc#}-FZ3eoh&)Qhk9X7E*?%E9(Is|-PsfAV&cJ$bfyPeybuL( z_jk3w*aUI+!Xy_>{b)-DRLKe9aPf3#0)F$vMV#&z6Nf)F{ueG_0mz1z^9LI&v8(Lt zFc2pq?PPA^V($b)IKxbA5hl*&W-xa4bFHxcpt<8}YHbd8`8PCToaeg#NuGCvHvrXK(F}s!u;~cUv>lYzig%n`i8`R zyO!DlG^!2)X=d;4%mxF$N#cC|EDtWQ7CU$Hzp(hj*ykmGw%h-M+9@mlul2LD`|TQ_ z6w1mXGWKxS^Ht%BG7>*Vh?zYcVJ7msb3}yrxOq)@O!(LZO?kQ4xw*~Y?80XJX6&XG zyyoU`Zb4yQK8p)w|H|xK+yAVO%)oK{b6h!qU!u7fRc3Hba9_^F>UGdtxPR+l1{X2= zW0d^x=O2Sd%GTv)_D9|paX$V;Y)z1s;&dMDKR+b%v%7zZYK!>A;=uDDi3FYIXWi6shIcniSR3+EQ^0b%~)0MU5lBRE6PgZpLopw0>Rl0@AK!IEu=XllJ~LdlD|{F{Mj zbbL&brLN_4aiHMiG!Koh4BOZwB2J-l-ZNA)#a;~kr_`;k6v>z)8&Q*lrh&<=70#6s zlmQ>-<|39d!UM1n=zU{uIpv9%Cz)n8xfHi+u$y>H8hoQ$b$E(h-A~*FPT#+4&`cih z%AFzh_V*K*rtVDA-e-2ooqfHJ(V*#!s}nm$)V^swHZ;_^GV!4i zFRsV8byA#2ZDIB+(ftY4*tn$b4@T6w8mwWcRt}>xr@Ey{J}8s;^EKk=<0CD2%e`Mu z62fWjiV8NVuqJM8Jm=)o&yh{6iC1$w1JrtFaE50eEu*vEdiil@i2q8yzUM@h)8>H$ ze#RJbXWKGYd72jW8GY!?)M-*?6^SSN)Jh9?cmhQQBLBYR3s%2(S27DUP;^5c8*>UR zgtB(;7E^NsP6_WRwu;qS={jG`e7$RQ$$QNei zXRVtgdSpoVsjw@$`~_i#lVLl(uo0Jtb1TX1A=<(X*WQU)_sMF#P@c4G#gYAi_QE=j zrtr07YQM>HKOA4IWAiPHN2`ql@AJm2oKf{Ff=K*VAu(??O{*uJL>}EGo!B%k4)`|O zN_ihALbQifhK`9*Xmmf|DlB)mt+?YmzoGsIBaZsPjDgj4&e_BWlCQ~Pk=p_aXr@x&ykd?V5$Bm@+W^UjW2FJHJUXy;u{BbH+U@(}pT`Q*-IIxAgx2hCG_N zEeO@ZDZVZ)Na5;bp9nM)OQA@gpEp z(Nsn=!5rjuuax{pYh$U&b!#P~Tf=>9vxv(4s#ZTSU;g<3WBJ#J+}qtYy$!&TKFbVE z8McnyLW7_6g|nRSG0r;I%wlsKbvmh4ZAZST?X}VRrvz|bQIEpU>lC@FMXhe^)xB4# z`Vlw$7Moqa-S!`ae+TSX1a&Q>`H{%D%YP6%v7QD*EgHGKT4tD-a!Zlc82xJK<`vvj zl%H)|H;)2jaf|-|Wv74eqv^dz{tBP*cFa7m>fRZ=KQKj`!uoJXn3KUCUufjDzIF)9akM|{X};T1xQ$7Y6!Rk` zfY09GX9Lr%Y+ZP3PnE6z0JXKYSo5_YfbU>A>&I`?6<=JUoTWOCPM3W@Q}bGUGl{O_ z3RE7tUful9b@0FZ71!gwk!xpvsV{@=^!s%aNd*2KwMb=*j4CiiR@_g_GEOl`;{O2n zDUZdE5$b5~BN7CdH_e40g8u-Ar+n9mm}Avx{{XXJPd)uVQ}%8O z@cm$~h|H+MNu_I~?Y;L;$oGv)_Wk&^ukDiO$J)VrbZ;%85JIXCJeW{9JP(+Yj1NkW zNBzD2E$L8sQT$5r7nd6eho2;O4hi(!fMeL_^sfQDnnj4RaRhM!!P`4??d?l%s7Tmm zk~40qHx(H?`ubOoPZd?hQmE|it!?ev=jamQtffUKhllMm2Z7Tx=zrj%KN{WkuaC9m4uRQLI6h<79Ii%vE5k+GNiG$a$=X0E z*bX}PsWkb*q>#?OWINTNcWur{&&$R*?^?Gh+D>mxZL0qOf%lUx8_KP3tHS>P;ZO8E z_8<5u55_??ox^zhR@lIAg7iY`CQfh;&CACa&mF6t@o(+N@l(dPlYOVe+KVY&%?V(* zT%465M^3}*?T;aQL`POnEo6jb621Qbh;{UwnNH~?0?h1Qz)0*LQ z0p>I^MQsE(4fn#y)O9@m72JN-2JY5fKI$3m z(bhq-W>JEH&j9g^_O5@z)9YR`^5oR6AfHcZ_6Z}C&GJD0B6?u*0qwT4ol11$%LjMp zdeFdBl2=_dJ(ot)%SRn{ug*k4NZJCeSNB5&nSZ0 z)=?>y7=V(NJ+O0uk_B@@ym>jNxgXiJ9Mw2~UZ-E<--{66+(o)MBZfnXrbt=*ul9-K z0CyzUn`qy(M~1v*tLi#shOwt#Txsr2%t)bH2-hU%*y9{}*9-exe%N=P0QAYcF)fV# zCe{2+r<;wpGAv_F-lzm>fOg%J+~c-u=0AtO6#Nh1pNDa2_Ih@--aPSRi)~{52?0>d z`0~+5&9?x1jGE@DoooB3*{zp%{ztck;qPfp&b|3w&$936eTDGH_Qmkui}g0K@akP^ zH`ez}1d%{mKPDGa7I0gth8;Tl8v1L%zXx>PPFC?YuM*vPapfVolWZgoa!00roq6Bv zTl*O4z60>?m34om+JA2gF$}j6Mk8sW3|cjCKZs|O>MNW70D_AC&{~$2eWiRO@IA!V zy56i3O?u@RV@lIVA%;dsk;%&A1OcAHy-sSFTpY1jhpy>%eUF`~hmJQMuow+H%gJv2 
z4kz}6{h((3q5dZ?hWrzBx^IL$FLWc+uWi_ae{~zK=+}@6j1ivv^{A2Aj5&wxB7rhGK;rIwXzAC(@S zsUMK%cI+q~VLjNMa(T(GB>lF&HR?;O_Mglr+cIAvkG`T9tJdHA9f1K^=R%j`3}6WO(E(tW|KM_;Jv& z9Gd-Jlod#OM)lhMXTZwySC++y`~j&)uG`z(642dV`AsXwl}PeQ8By4Q(zEdI z$N@rw&wBAM+n>aqH277m-Zq!;x5dWhJAo#tsE3N;)uSt&$gmb{0Nm~igN`ejo?$<2 z3R9BneNt(3dcNmfI&Jd3ZMD+B&5Psz01v~bK{dt3sMhnFR{J!wOXZ37s*S3Dyu*M8 zbBgk7;*D*xN#+jFuq9gr9P~e0^KE4$k&N2x0 zuA4#lkD_>{HE6COzJ--s1ds-Ay^aa>2Dow9JXL6^PBM+PcYF2Ir~CuTLO6!gNsGTG`#T#EQ_zB)It#3_5^0>CP%UhD(K; z%#vnBEV3z#7Va<)(z%&#S#B=3Ni5!G@ZpZ_`%lZyQPQ#B;$_X$D;leLK!YMLBsbLe z{S4Y&Yy z3&9+B>sPeM65qq}NpSv5toYcZZeR}pe_GSgwFo?qI!(7wupokPBVpW%;l)*&ai;fn zZLh-j2+@r=@>AQ$%xftvEthP|iH_xdpEs*|bp0v1ov+#S_ua7;R@z)SW(|Rn{_hQn zn&MTG-ZoWbvtxoAu0PN8rs^6bmgyn4k~xHjks#a@U>uKrMRDQla@10a(%S2#-~2J% z7|K-nB4e9>wEppBMo3p=m?6vOpcv`t)YYpEC9bbh7>At91_&4fudf-Z5opsvxt3wM zj_2$dKFJEtFt)X_)gav)`T-jfsK>b4DD`x|*1~lOJgE=Z__2Wcgp8yib4rwQW*`dDBvMZFEg&j4<^*=NYc1uZT3b?ci6A zF|?9h20G)eNc8MJl-nD*e7IWTRY*_B7Cb2`dJmWJ{HvGoFN$s=pX{xCvSSi1PS6fH z_0M|R(JZZJf3!;{nH&Ra@W#MbIRYrln|8c3J^673z7spWyZi-YZ7 z(JCUc^Ixd&l{E|`lIrgLPvw6cEc03tU&;Rfw==^30Bp-ow&)gd$7}Yf+11e>{ZkZe z*lPK=z`qe^@rH@wn>}vVPt)}WpX~6ekRzRQg#h$yj@hrEzu=$}Ua!O37FNEWYMM>il_QbNR?yjPVQkXx$55Cd!~7yK}P6J2>8 z9+aQAo9buA8MH81Sm7|%_m#a7&-{A$*QEZ+-vl*36KWB?-lV_i+HJC*-d$XK1xNuO5S@YEZ}{mqYR)G07C&_PxxRSIHPbz|S9l`0wIIi|{kU zw;I=qE|%Xw(&G}abN9{(+73GO!0%s|`oF{t4#(mA-|(D8bsJqH-($)|nWwsu=EQ&z zxJJ*nle-ID`~em1<&1FAdy(&U)$Q^;jxVY1s&JQ^@E*@yRtAmcg5 zwR`^n>}&Buz&{^>F-f+?8c}fu`2vw133VKRbBxzF`!ReX zUk&_cn@jOzntamUX^(%S-r7p*46<81LhcesA>;`>Po6LVufEPB^>r$fT0>iRK0_;; zs|eG+t=sWE5B5~}J*;?2{%eUWEVR?8%4T5)%-?rlVnzTgxFio@UqP>fV$`*Xqib}! z0SL($^RxUQcjKIURqYJe%WoS-vB~EkSll*tfyl}4!4HYwH*vdkM8 z*}gIk(}A9A=f8=6@NbvHy;fNtMDZ7hJRfqZS~%dB?Qp5V+6ul73Gcs;di@5xxtb|f zN#RMPjB-HagTTkX{{UKbzT9%t~@UhW_WowF_48SB-QZ5OHI3^`H14--C6Ha$h6H{yMa>h!uuCw(>UONAVB>qWc0) z2TJBOfA}HKiLD9$0EBzvMxk#sE+dcoM#+{kdl1o*Lg%;{$4cI&ZNvD58Z@zPHt!eu zXnhS1EzDxMr-tTRT}ggS^O4~Q?k}*j{h<{|fznkQvB~4wouYVN2=xT}WOo_mRFV!r zZ2p-1D}M9-30?a}XsXMnc;~}pe`-w|Rgo-w zF{WzJr~xLlzk*$xzV_U@5((oN&3aU8VQMej;^6JxO?P zFvo8*G-qO$eaxerl1CqprD$pTn%WkQMg7vlhIa%ko`7TU8LvFlJTvj<;ka0J&x5`w zV+?XUuMVFy{V{?JVSrrcBtdy9a*JE(D8JO+hZ$w@c9fi7W|{5t_>S7kM@4A{ghg-z zW78dZ?s=?#vT>>ve{$^a({-WhVfdpfl+>Mg-`3xm*m#e`jcKDvJTgNkost*LV zbsL#&;F8pNeF3{!74d?9h~uqu)XSYYMqbIQbhnZ8m`rsUMSD7tY4&UU&eOwEHof8a zR}yZE?;L_bCjgxBnw!Ko-)5Q%SmlvbpUQB+3@{n(&|;F}&eQu2cto>9kc*Z64yT45 z{p&M8zte4PCDZRL$&zH|=5{Q~6P?_0N$JxS9?>b;^xx2zMYzgM+}iM+<*X1G+Q#DP zTHyJT&4IOX&OJqR+GJK*GPF>_(Zt7l$UbF2bB-&Kwb!8*wyhjL@$x=YjFJ^QZYQVs zR`r(KMDrtsZlmb?nC212Hj#|}Ij(xqrBhhDw|DX~<#AVgf06Bfv5nMPf9*@~r6!St zlK9s~K+urG{y$>lC!Bs2{SvBg=Z>UwujdE$GL|L%slFEOkzfl>6qi0cWs;l@{_rbAos7TKj5lYGDZ76Myq@F zA2#|SxXK|C<1OWu?5%XcHe!A;5ss^m883~x6(XZ-XHM|?y(F~UT8PAvPPHc=+AR7;{|h)yNMijA4|MIwJDtEkVhjJ^fk$a zr^)WOw@*ZUU1@Uruvc#1p}_vpdKCIri>Je-&v&U`THm#@>0Gw`&4hPsbF|=N(APKn zLS1N@=8*@2^yn@l@fM${*`?mgTtZ<1e8YfF=3-kvO5W4_P2(Q}>)Q8=u49tiXm_b* zVV)9h?u_b;SYVJB44`q?)gOj3>Ru=Kf#a_pc`0Xoq3U++ZrAl1yNcCOMw|`{qD%Kqa@$}hk7fJhD$7mslAxP7M8=M2zo@%%4 z@B2+j;lBe|_#eXeQ2zkJS>ncPaka44b|V6L$9{(ZW0DU}E3Ehv;N|#3@MrCwCb(GT zjv2LWLRG>kc2gq`-*hS9@@wNyh8{8h0EPMSJ6ia$sarLr$B2;HXx7rkH(A*(?Ysg& zEt1E%+y`M^&VQLx%CLUPd#+k}-2CS)!NW1lDq`wuPMc53zKHDfj}F=RXG*ivZFe%m zle=I-j51bNfzwak2fL{310?N>q~O@ftldPZJ1ayp(*zZ#f9KQ<)p(Ui{!6 zUYZdlvZ?*#sR1Tr9mi5fOm+HK=1=YS;;A(++SA7n_=-sFZTv5z>2TRZ@{*roxQ0n* znNiLdnN`72*J#gd{0|Fc)bmWw*i=_`a(DIpPkV@S97Yy|t5ekB{w;pdUNHEJ<1Ie( z#fyD;;K)tH_7QpQZElj`WfC}Wo_UE`QCXCaL)4Ew_^;xv3&%RUzN6v~3+fs)&Aoi9 z^4ygtc?LeYAShPnIKZ!|zh^J_C%2BjXb**&XT%1#@ZZC|CscK~(^Fi~r`0v4(d7lL 
zZdK%vDOpAUO{6Yc@cFaH9|Zh;;Sbtp#GVuJ1($`iT_?k;@j>_TYTj&8`8ILhn9<@mVv5eu zz!jSd!RNJnZFS-OQ^EfL88xfTa>8q0?Axf|X(NqeiO@%Y%kJcE!RIINtp5PoFZP<$ zwMgv$0JH3GZ1lZ0@$KP~WC}GX5C+>S=MpNG2WbZ%FSU9YtiD*)?4@TJrkmgT)Zn9v zmKL`>m7HDI$o<0jVe$7(@n!YSlQqVgH7gLSd||ptqqlh(%utCJhIRm+K3-2X=^AbK z+9Ottm&#W$tNH=R$K#s(D*cTA0BcBo6wi08*y`4LRn*eQ4xyyYxuCl9CPOv+0ALAX zkL6f+Rt1<~^Ixq$vhTJ2^I{<5#zF@3l!7=B{EtfdM4^dyw z+*!dD@>=eEs#KoMFYCGc9!JD<@s!e>z3#`+8fDaC>30?p;IGO*h?T(~l}C7Ex4dbi zM)~=dAKvvflcQ^qTV8#VStJZ_xhH{;GwoZJK2^r}yx7`Ka$i4y;=U4{YKz{bA9Y{v zK91`A$fNc@v?raPHU-RNr%~FVpTW?}rs^71-IO}Tjr2>p3BX{-JAotgtrfX0 zvnI`iVOjl0(wlo1+qELaq=rafQ1~Qg9Zf&3N~iB-qp>b2rx(=8(;~FF(bnG8W3@|g z$>qZq#fWs-f!KEUtlO(~(`|%As}1VJM;?Z#~=P7%EC*1pfdo zJ?hP$hCjBX-xPDHjf!$X$2eYd>5A3wpj`}si5*r0^z$d@uM6h(H9%pGQ z*+0zZi>V1G1-`FS&(5R$lO&5PMf;f0ESbY^4Z!~Zcbw*=)-+8*Pm#jcmATg(eQ zcW@-ENm0NIq!k$%`qs3%Tr=ud4I4BO%t0y!Cpj%6Q5V*RtFqU7C#!t*L4q1Ei&Qw>@;W3%68A;H@REpPWH?_Om<(zQU z-Ks@DRFBM$AAi9N{3kTy`y<7^BD0P*{{UAsQB3iI*uh|{y9d4nX4?M%!31|)z?Od$ zH7FyvKQNa>4Im!569JE|EA<}ZM$_$Pf@vlS*HZaEy$UxH0So9kJv!o{O)4>QAU7UN zG4A96fsWblUY#C1;uz7l)-NaBC;A+fGJLa{a#FMMZGNZbeyjfg1Qqz1;p=(sd_VD0 z8_6L}z)b|U%gG=Ntd0l1RFAD>!T$gRxBaKQCv6(oYyJ_|A&o&lX}=Kt-Y~vod0;(w z`d8@>hdf(8lYMxzPHwF3BvS)EGad;ARDF5&&2~_D!aJthAd6~^z!8;0)3t5?01+cM zhr3cSZ@()24vaoq8WQHEDMk57pOT&~{hWVlZwB5+4BiE^)VK@gSkG?pGWN;Jo=-K0 zdwucG!FMp*>t6|cLK0)NJ*Jy9yN`S{YQX*mzh19rU4GtHjtJa_!CAR74E^l(=eMP0 z_-{?cvA9^a+(s}E6Y`Aq^y^yF@fxQs>){~!H>y1;Vliq~jv{|OkIsMCo8!NUyhr<7 zd?dBM_<5_^y~mGqC6T_%Dy^b6p%5>UfNfq0B=lf;uj&Yv2aI(!`3I}~Rgc5}03QAn z=x|$!ueGlh>BizXoz%k@+Tb}P9PI=rUOx)`elhZYsIS_%8#0zN1uCwRq_3jA?z^9d zc$bc^UmHu?Po2dz_@68P0Kr;59@}{P_Ia~ew->V2Ba$GlS1x0986U(xVs`si=XJ~$ z)}9v9;k#R?WRh~opca-5ndGi|5uDfP-~1I$+H>#%&f?vjWVO{JU=^?qw z9!qwS5UhDPz?6)Z$5K9(*E8DuIMuY>Ev#%-lHh^((urQ%AnIxxctKtX??Ta5kvy6mI|yyN{r*ctuW%=6b8^-ECvpt%z$8DaCuP z%HOHQ{8RBhxo6wyzz#$J&ott;}0bAMid5-Ist*S=eQj^SG9h?{{RZLZ70Te z-VD~QXSd&VW3SxY-MNvnfH+(peqzhVVk^Ra8Ee;H6~AiD3t7@`8q#YLVO741$&D?N zCX8iR^!ZhQJlEb|uw9z#T33lR$*&{TFST7u?5lY3?3kGU0CE#2w&dUrFh|zAGF%*I zS`}(1?&|$I92`ZDcw7YOMJq{O{cczP0D_!+TYVGZC6|CS`JHe4eKCbv2gH%t#>%ot z4^g?e{7rbqk@i0c=rGu;4<gJT)-g26^7n$v z=}=!3z3Fdzy_57I@yEtXPluiZn^x36b8BZ4lLUl$*9_SU9Al$1WE>BA{L=W>;=OC) zAI6^)_@l*^z8`CU7+c#Vvxs7rYkLg&am0>B)pTsQJRFnLx7B|hHCeymr*uyb>$bX* zrmJ&lZQ(oEt*5xsH0z0xqgk&WGyAqTK3a{#7;juxTk(7LEV#GuZ}yjn^!*v`H5lT) zyO&LD+uO?ue*KYO?n9I~PNzI|&3#`D=`+5~!hI*q2Sd#qWHvqsxdZP}VQ(byphxK#w; za%<-8H&WF6X{t@AUwDG@7x8t2#dRjCtefs7+C0M-_wu<1_;+z%W?guFJb$A@r|JF) z)ZfD@>!@4n4|@PuZA?z!{hZtg#D5Kq#l`i$qpj$I z2(^}38aWvw#8K~4%QJ9tItub%6u^ReJ5{%Z#jGM-k-IwO(mlf@^v^$sdio<=_-}RM z&k^5vfLO;gEo{PgWkPosk83+G_mx-w03FSFXNiL^+vxW8vD#fp{gk&&7Ug18 zDPVGVUIu#{SI}T_k&M2so~a&HS~7pVo~!XZr%=>k(e*txNLJ(RT6jYfFU+#AeBk5` zql){7_9XqV;qZ5jWzsJ#f8t-_q=xs(k9>@}_9KWy(y#heh=e`;AgW<`ccU(1R>qzRL0 z1!e)52I@DTTGKVk;hA9F8mn-@vH)JFPs)B7)5nh0wa- z7gBw8mgA|(9YL=~wz;~JC7BQI=8&{xA?Oz*pVGe${{UuR+ULYtPmHZ}T?baxd|RR2 zNG>GO?xaZWZ6J;qqT3Ue%%x&iSmVg&2dS^ruh`4>qOkb0rpB%fw|5oxcp zYl|Gy#$j*-auwhMP-6O*qO=?AdzXx88j$uH=$dcl-?q)g`u)H;3hmDza|_1n>#t z-!-46>;6rgmbV+@nF&;KcOf9-Irqg$9m?x=(perxe*V$aI?)nmw$MgS;um8{hmZoV)QHF;BUfWGp}f%loAC=f7I>t3gzD>QrK) zw%5BaU21kLtkW2-7BvO-ow?3<{3{bt)O77qC63Kywq|uMN_@O`-Pe;=^+;{?SR_?T zjj}iH$xWniGmmPMNAQey_UN%Nh2uV8FT*W&(ZIMp%MP4Zt}g)Iv0bvrLQ zIhTAuGk^{ooO@SOV|#0H6k=&NNW2W352Z81z7|%59g6#MrB!$U9`%0X!tt!CvMNTo z$aEhtJ$U!3bYtxOv=_bGV+cvx-`v6RCx>L7A#racM{AIFt}r^0`1H@Eb1_}d3=bu| zjU~;A%kMZv?Tmgk+iH4)3rU%QL;MQdhQ}W;A5Yf2x8gRhsc3#5)wL^4GHc@uW^c1V zcF5sKEC;E^c&=$wttqNejlEgDIHmi^Y+StYt;dQq`K_AXIoVlV%C0tGgP*DC#YcB) 
z*BVGMX|e_w1jful{YQGm_&4$I!haET=`K7!GhJWhOYkrhxNZxqK!z&DMM6pu;>>Wz=qZvS#9H40UMMI z40Fam&MT+Wd@FUQ*tBbEt0Sj7&g@{-Zy5Yi9s<+kzSQKkvD0l;ou{4FG87OA&V9$d zeC6>Q{t9hz1ghQ_wUb)=LL*0UBw`q31dzx9QKY1IZa9;|1AAREA+QUHcHT9%0Tz|qzWYI+gzGQIQMLd31 z1_shpWc^2_e$X)9e)cQ!x8Y5TcxU35hhU1|mDKGKLt$OF%@_pwfye3TU#c57j=A?e z>-R4Uan&(6YVfGO86=mLkI4K-!@`axmnT%)`_C2qy)7(X;1%@O=?szET)`wlLGp*) z!5-qjIdr?#@ZG=sBn4kp)BHY?-Cr1^4J620{aE!Qr>;$Yi2ne=MdgOy_FvcJxLwi0 zV3nAS+ZAR)#17|ZJuBglhnjAW@gKtyC9S=_oo2Tt+6a}tOMSN@vi?!dEU-$>n_W@GnD3=-Xm`q!T-kfVol zr+5DV57W8zF^h7&td@@VwuIAoX4~S7-|&({rL~Rn_=R+d~g|yGLL-M!GE(S#63Ua&&6vW4CwdP zHkSG>uRf^gleCR8HM+!r7RK+Go=>l8_uu#|hs3Y=R9~|K@sAQ_&};x*iU7Ibw4P7)H~A)!SV!uHIKXbM}7G zAy0+6SZkQBFSQ9C;`FKW43jy`h)?&F@m+_9=Xdc$mUc2U(A$#io3}DAp&ds(E9d)9 zh?)n$4~||g@go_^zBmDe2Wdm-LlGnO%edFx2pifLFT)Eg!fvl zx9v6JPZU^QS?XR2@x_Ws;oGQR_kue&5x(J{zQLMKnA&Ue{Og8s^e9xu)L!FVwe)AL z9xTG*@ot=SmHYJQciuSoQ)lA62K!CcQd`Sg=)&Ab5ZfUt+m+Ri-C)=t=Yjzo*NuEm z@Xy69AK`wXrTBxynzxO-A95}H%e_9&%ha@(#nOlSNf`S`WaIZ?P6!8~HS~tFuG{ME z_Kz}6+X6bF;1%H2Zx8CW`c%ZB*y`DDtVYi2HI}0jj zo)?bf*M|Ia{gyTF*+0iR9ruZKD~}cUcU;s>o{gyeo9z0OI-H4!VhXtrziW?_b|i33 zexi8eYhCGjtbb=!dr0l)ML#g>l1JClzG43Wf{tn)Z^0i0-$`?03uwO)v@J;|y9H5R z8RP|tUzjSVnVcNsIOe@9KNnY*62k_P_A_aDy)}LBe?EtgiQ)QG@>xk=+DBis9Cz>H@Dt$6Q`JaI4lA@`bn(OVz1NYGk{ zhY=Fs@sGNCA8}hB71mFNe-*XQ56N$Ne-)kG#ktkszGjbbnWA=G!#-Z_12^91n)6?X zR!^yTF77*NZlb=??%F*-u4LYa0yho7EJ?xJlZySz3xSHK6rIzv=IiKwb5dT;4r@ni ze^)tqE@koU&6$!kRkF8xK!{K83~cSc86}ywa(}|P&llMZJ6C43TbOkIAXb@hS=B(3 zr)|WM>MNkpbj?#uw~qZ}fn*c;K}RwiZ3HnKjO1~ia&uW;D$thF@`Um($`&MZFIrw*`PonA?d){6ju-#g0F-Zi87&9mgcPgTs?re9f zUl4z79|(LS@PvBCyvYr!TL!n09ZJj8DvtYz8Oa#rSK@DiziYn|YCb!-)O8IPPw^I~ zc^OeWe+tiSr2hbEG?3dy ziUA-w849X6>)RFhO>YV4LbH`xzT$iA{Flq2`d%N7D$0v=VIzNiFTR*bLXv(d|VwqC=Qy5@Zz$Lh9=@v}hHbnp2K1_I&i zq$~dbET>S#m9juMQIm|0YMr-^;TBf&M>c`f`)&q0&Kc?QuW9sPdGtR@ zZ1onK?5i62!f+ZXQyW6}&%gLrNgC;}#AA|W^DW#y!Jm}lJ!|1Ff?v0Ph4qPSbWJ-} zwY=1`y=`rlNlVR*&kg|E9eaghKMLu-9e=?=d_D22#2zTp*H!-jgnv!&;=rkWE3eyi zN1Qnk!vgMh3K5GPykmo0H28NPTbWdYi;nH+t@vEhrx0N%{_#mky^?9a!(-|<_@{2Z zG#VaArFnUF zAM06#96Pn;E#GCS>Shv+95ki!)6C?*;Fdlgw3k-#7QKC_TUq}AYupQ4D<*%QSIazw zRAVEC9_Nbt(KIm(Vi_7T#y>YPrrZJ0AI$z$`ET%_#dd!VzAS0F-Hqq=ot}aqxce%z zTYx}v{{T1#Zccf}8R$r_x%@@@P-gpvD0La&(h$t zj(317yh>G2mtoT+u1^@|v_EH$g3a+$U5~_?g^V__z0TydLs@N@=G`qsw*~ydFaaC6#6A4ri%<3XPIkl5Be(Zc);WjVN+Qu z-&)1B<8LJSh=6>+oOKmqTT`fLJCZ^g;Wsmd#_{QzucYa5rl8YB8xJ9OyJv=60seDe z9V#%DC(onqkD!e8_H7gB@<(m)Q$?EIG(`J#up7qaTuQhg*Xg`t1CL(y`I~QGx1KEU z^wFavtEj^MS^La9#jn#OIbZRw+B_lmaJ843{Ey0fOxCpzNBN`1zi&&6dwn5J9QW3AjP`+>U!i_)Vt8;4NvS)GXTaJ7VcLjE96i3 zDM!TnKZ0KjX7Ej{cN1&gHbEuqYbTi{y^==7os&3i&H&>W3twGIx_^fJLT@iM3s|GI zwJ|ECz}>N>GP^O(NC)y5^smmZ_$YP#zr`Qg_Ty6U{-H03uNzH*%4>B}#s-gg5urSQ z7aZfC!mlQcD%I&mO8d2Kv`=@j>u1@N>EY)}i};$qbgX!9!M}-KAowrh3w?Ia#DCcL zmV{drf%7-WazQ6K&PU~6QfMEx?~XreFN^;G5ByoH__S$9;ey+(VdA!Qdr~r%ky=)=O=4)9yZJBErg8xW;!$ z-Cp{w52CKWVoh`Qiuk4AJzK<X#N9zo(MqB^#peSDU3Ea!6DZ*P|w zcwq&)B>E3Nfiz_>O%8;wAKx zLwR_%*1E*ZrZ@|3I~iA(1gHRjG4Ee=!Qz-TO*+={*{x-@xQS&)ft8iifN{|BM;Onw zeo%Z<)1&adu?4-Pc5_FiphZZOJpN$rdvG#(*QIBzk9yH1i#%r*Cy1l(#oKGrm3Dh{qZ2UoU^bW;AIv z8*c$!NZ(Bf4g%g*q`QZ*ioEx#e8l-mApQGMvftBpPsrqH0`nOX4uLvE1ao2 z^LE_%3-+A-qI@snuZViKyt7^E3#{q|qg$(Ghi$sax-zJpxC885`@p}H%%~PxJ z+vatbc4JTQ)RSCXUR&JWMd!gBv-yhu060>4JRE`Z4!Ep4e-hZ<_;Tdm$uzBS_>hRg z$c?n)ACF$Om*v~(kWVGxKWZ_o&=CIsb+USn;Na)dvF@)V)@5lNHr^ev(Mjrm6U}?n zFtMpVS#{gx=5XRMGK9U}q&K>Q-D51nZSC7~ zBgW?r4%PYs%kG;HP>r-ey7#Oicz$CSspNhU-&tw>o=BmHoB*Wb&Iup zxGJ`QA!cO*l14c7uXy->;(b5G{uch;)9qp~WhOFiJBqpvcH;x_?_7VxOZ$(38oq~h zapAxCNwl>{*&lx4Yz|gdbQlfSBl=f)`#yLjbRUK758Ii=)Q;vb0wGs&$Z?#4szArJ 
zXNi3DJ*r;suB^HdsX}t4P29!!lW%e2EfKW~9UeQ_Y~zw^kqkS$wBvRTexIFky03_S z5Bw?chLPeA8$seNV^~X~Ewn4RHrl#ISQED+%>eED*-&%ABDQ`lcph&IXc5h(T}NXY zOD<*3(NTlW{m^wi`CH0DkHNyMe(t zCb}|pPP|%+w|C{(;LK%ND}Cp4@f-dM4dRV^T(>?QSo|ZW*~{kJ3pkU`jO{7Ie+Xsa zfFIJnvi|^rWPaG+4`RBLLAvnAhkP@vNwnOg7Ra$&Hs%6Y8)<)%GyBgmiO)_e;13S` z64g9)c9UuGo9kFelJaX@y4fc%_1h~h#v>RF-xa44WVjq=23Sg(UuPZsu9oZ6tdA9m!eXOd3h8ft$L{Ba zB(>BnZOmpCid+cUc~?{4ckuu`LBqau{3^f1uN>*04D=g~b5@4o?3Pg&Hc@$v8#3j0 z0)A1)ulQGjf5AQeHva&^-1uh-&8JVJ&vMhxErXMr78eeW9$4pxR^a2ceKn~3HqyLv zq0bC*ERA-dU9dMJC?N63`d9Lwe7e?}Q(= zZ-@MQ>Cj%P3)$Hj{>>p&WpG#T0#%jr2>FG5Icaj% z%e%`}`XUuJv>@akwD6M{o$V`Da<3f;>T zQ+%FoILPb|C)1kP`1SCENcgkim$$sPwSw)HT`!bvf0K}*y#o%O-t~#_Ye2p5^_BEj zQ%JLmzb_IvZ#j01k&ZF=3iR`=|(dVY?R z=#Y@hR}8~7^8UgwQKuz!rS0aAdR64-&l983=V|;eCfD-T#3%vB%AlTdDgGK{ zxwyY!HeqP|zc2oJgN(5D6<1E4>w3o0tZ>{Zl0}oTJY@S-%?9qu!&KbTEON)pi;b)P z?mKkP<6WN8DvYHKY`XS22-`j;Txd#GOC-ertPKMv>(!VY* zwRP0|cc$F!eb(1j{!=6E&SgXWKRW$KMZ%A})7HOY@Q%Gn(8NKl6qn?FSK?G{T7~;R z%^W}N6XE3XN5Xw-&BNrvB#bMSJJEwLBis%;9+mlLjrDJ*FeHh87 zTEi56T*rkwnU3xW&IUo}(!V-x?e!}w^zkS5#nW7B3e$Ot<|(&@7&%eXoO*FzuiiKK zZvOxi{{UuBA9xn(Z%s5hvplaAjmOG( zbnLeNy$_MJX!Q>ic(E;&n)2c{kL?dN&zQkUIV6+M%uZXsHT7rg(W1fe1H)b<@XRt> z%8_h%t{V!mOown(Bxju9{{Z!?-Y$ROoxTVCpMPWD5qwD0F11~9O<8TA(e(SY4A8u! zKILFXIAh07E2i*I!A}h6dJl$V({xKc9`er4C7s|VGji;|F>&389Gc6GxS9~9^=dKq zx_);IgCwaby7GSNZEL>Dqyx(=8*^HHctyV++aUtn-1nIsP1a`q!drI(#}k#A_d&G>UMNFB?Z+*QFZu z_u2G&e>i!KaVd`lcs!HW@vp~fW;K>3IFyoUr)0GGTVw93;c3omdfXW$?clgV5I-;&0iT zz`iH(JeroH;j1|{i`GGNa2Uqd3K3KmE03E1bKA9kiT?l*I*k7SbBLj3-Y)n1&yB6Y zHCD80b6>uVe9`mIgT4#tTAsNA>J6xA_V$-{&`LFnMVix7Wj{BNgc2nYEA3;?TvwES z+FuBCmH0hx;sYddX?KywsKsd&%+ahdOC+fh2qbLe$AWg_fHFH**n0l}?2BtHx7xfr z;?Ed(cH$?*Q^BcyrUg4#Di|gXHv6d!j+o}WSN4|h4~VaPAE;@66ZAImH-#h8&73;! zhYn}7x>*E}+}&HBnJ@!K`>c#P&r157yNL4qbyz4=<(k_1Ka>9ef_OQmL65^yZld>m zuKx6Xee0ePkH-H12(4^1`)kQAH1=eT+!a{ElEHf8j(<9)bbSX*@O8b!w)d>ip)sQ( zt+mafY2fll`g|*TQ}zxBDHnoce9a1=2XnO?2B20kM(e+Dsmf+k- zRJv{-d2?P(@Q2~aNxqMD)wFw~mPX3TRg`3~#tum7=xg1yU0%+A5nM%fu*EcR zlCYHA5O(LUK9%Cury6tThh5F-#?iEySK{A;uCKftnxt{I^gm2%SJAADs|!NmP$Wvuw%eEGKU0pKtC>C$(!2}e zdmD`+#^OyoMYR#mAA1SfLAjjr2JOW37#w@okIzxklw`f#_Bv`bq^fOx#UGD0pYV|Q zS(f5>Z6b?LxVX4kCRI{QE>Ac=G7>Np;PQPdCcM6#5&)8xYem}5$f`RXq-WFt z-n>8IhwS6v-G5YTjcZoaY_!O+EUOf;$0TtnJO2RI4p=D+7aZ&&y^O9{dU$uv-JZH9 zb@(1`JKM&(l|8)w0Lh<3e#*WB&^#HU>s})8zlt@6cdt6f=j+w#>#b7~l=c4;cjFx&Hu#KN7Wl8^`un8WLIfZ^V{z ztm`UjF3Bp1EBg>;+`SKa>UcGZ;2%!5>)) zvAxpR$FS2h-Fkb8^$W|W(*$`a;E3OG7(vtPp5)bUhCUfOZoeh5nro=$n5<1RZgg&{ zD#`Fm>37$YLMAPdDap#@_5)XM23L(5=as7G%&?H6F_t~Jugx!w9y5!tCi{7iW7fhZ^LlzGFdSEUKV=;;366rGGE~0JKN^6<6YajRu<(&7x@R&c;Wu`R4M( zukthck$_G(z^|as^17JJ!t~>NB)`h%lft>BJSHBZslMmvr26HprNzuMYtY4O6h{78 z3k76FW1Kcda0vN(mBo2BnpYhEDLBGt78((GcL0J%GEWb@b)AY=1mCxAH^9`*Jw?DybZH^y4jdRCun91~pX zYilG?zSZ83iC5f!wey*_NeUIIrz_7*Ed3ABvRuk_D9)nUtvmk!Rz27D`grZXW?zNg z502d}rd(RAw#8y5y1UYvTB-v?}H>}Orcl8V^CVbFa| zPbG`p3$|Y_IMjtjE%Po&9X}e*doL?bktU8g`7DYTEDD@g$W@}K^Ej;$(HSP_;c83X zlp2fc-1VtegkLQ=Cn8Nl#J7t-*=(%*!;A?GJ5(^}+3U`GRXaT@=U0hg)>aI%7^jXI zOGhkcB!l0R&IzrUd`$xB%C9lphhpv-!SBaz-<437@n^nm%(6!Ab}mY`PZ{acnk7z} z#KtZSdEamNAv%tx@21Cs{8s&fz7}hK2-aW2x(%m>d_f9JeQ$VfUO3)J_b?7O9vhX$ z4`E-L5o*@@w}|6rfnX~0c@xWRwoxLs(yDR7o}#}_e{KH&inrede`j5LT+{90xw=UO zf*G6$;PZnalknr8Tvz3fgYC?)SVgO8^RmYrmiF<)%&OA!A(XaA!S(1X`VYhWO-xQ6 zrCRkIl4(Z!>AL&2KON&LbSdIv8gaGOy_Nc&o#H(o#CqnXqQtTJvO}~LZdHCcQK>faK)4Qn5UbiI36@r|oo9X1ODdDaIkf>x1$A!7-@4A>dRJ*!*xd>7s& z)@*z;Yjq5I?zm-&z{ePPCXNnR&cyEdPC9UXYwT}_pR=}`bEMndOxlIMp?hSoO?zRe z-dHoMWbN9H2UCoCA8Pb-o+`rSn3&bCtF_Wwzdu8dE68eQG#|3tzxC*O&%w|5CnP@( zyft8R3@gj$`QiL76?f80beJ}c@UB$*# 
z60%xZ{%7Yu!V9fotmcMoo>D}N(kz9QTL5xDOncW|@eARewW)Y=+f>rEXyd+yhK6ole1HNn7ek%UK*Pao)hV#JpK5R3%lIF^N!)pZtYPNbD6Zq8~L*QSFd@JC+ z_mIr8d2H?@Y~*7+W0G@T=3m8F3{@FUjYQIGdv)Dw{{RelRkA!q3&x~d^gea?OQ?9C z$3Gn|+gjA&iq}rMXss-xoT^Ir3U?stLN*81zLwSPe#xXtwtr*PwB0>-PlxheG(AhR z9;>wD>0ch}`X`9K5PV#>@fM+@-|2RDu?d7x+PsZ*A!H9cz12df%He-Ltb2cAuRMLaxlsuP7zbwys<9t}#>XeSv*UYG25 z9y-)C9~~Jkq_w)!V7gnIYYXUZqm#{b$VUpN?u?$ijl#V5;-BpE;zY8wit60nY7tAe zOPL)ZhstzF;v8omEM$Y%kV(ns%r4y9>#X)b*63-Mu25}#~X=OU7)UefJJ)}!^;m9Do;hfnmB3N zwCz>uQTSz|c*_3(#M+E^kjrg*sMy;joO4@wim_Nh<*H#u(0sWG1-T=U(A1wEJ|5oq zPUpiHO>mbQPNi=Z+%ww7(X?^r3>ddOH$m-RS9qi2&X1s(JVB$`!J|zbm{+(7wmX#A zRPO7^2N|!J{73OWUihD`={_;FHgkAp&h-VYn3zts@mk2!-6Vr<23kZ=NWncbUq2jG zDzbGYrN6JgvFq2R8c9=|z5es1_(S1~FCFW;6{|-jL{{RrucZW-RI8qCNBv7OY zybN5vG4nAv0E~1$O8p1XG)XM)BV~zWRZY)A6huIL)2!m$!GAazP(9eJh}8<~D((wVTYz3hqYG za5+Ezs>#u> zmKOQQ+E*AHepJ4%9(gAhrokC<{AQ$gYh!;TibF>fLMGgKEz}Y0E1z4a<EwBwwEO;N8D#TeN{gb>uMb?T^nD z`jr~uBpZp^Hk0d)EAorPTIbvEVe(}1VGv5O!c*klK`co6fJf55R^pMGC2l@o4Su=d zOko;0hP1qvhvq&aPnx9!yG!#&KkQ7qz59b>^ zWpid7zqA^EW7wvg*Sf zLF->N{{X=%67#?w3AoX9Z8qZ0JC=DZUM1ZF@Li0L{5zC01?!b3ka1sJe#F{u{3gE; zEOiFA5yu70w$QhkpDi~B&OVGeuaV7W<(fuvpH-(#HT@6Mvafe3s`^GhrJ2|M+S-WL z{{U`F`+J!e$+gm~3En%;GD{AZ;j38QIAs0OgaJcirzQH1qZQm}+N=^n(mL)|0F_>c9dqx<>0gp@ z1Zk|!Rg>#?+g5#M1;!b53U8zr`I74p{{Ur09L&+PTo4L@{;M6&}aBkcDMEn4Am6rf-+E1%Hfzdrsmd>p;; zw~yw!_=l|BYLn@x*7mx5ku))CO2n6RAOW*r0|B#;Xv*^Xc&tYzD#|Hzz22vpm}4tp z<9czseO{J2@7cq|8c*%j@xNBp{8^yh%VVX*CcU8Cncr;i2lA2`A|sHe_kd$@$;rU4 zt$aUyFT@)!?MUvLNLgAysDP~4BnHnXp1AfI!LKOz3FE&DjU_Z+4OrW0I>w(Ib}eNb zvU$;4va!gKj!K*{$83E~ddGz>ZM0idUEgK-7X=&T-~c~NRdEmRW}(T+2<*K-%-1E4 zEK|QUp6~Se31Q*M46c&0N4Mt%xFr2Ln!l>)7TSE0J?zY^6e9ti3F(j1rBv|^?3bFn zQZa8QAgsL<83RM@)CkdQa^O@h)$TKeV>1d2uEFt);c}>!#dH zj$?v1z&m4Zgn&8pt$jDbHh&4cGkvG&dc@7*tzz!l4LeU*KzZf5bwzN*W93rJNX`P+ z^j`>Nk1fjMO02(mE6Q40*IRwtpNw%9FY9>EXti1-zGuJx0Kq-J8?=5R*L+=X8(v!4 z%AOz6CWz!lk?%Kdr1A4P+Di4?-oC`}%XMfvNj<|KG30+%XJ*OAO6()mjj~JUBO}~rByp41r%L=&h_eXe zvnsWw^mkDI01CCAk@~(zgLpg~DD~=gekaq$qPD(L!wiic94QQ?h~(qBtzQFpV%En? 
zTSmQkogDy}4Po?Fn7isC#cs7Wo%cN0Y-#;LR@<+<;VY}Z|Fbq4=`{OwGkvsql6fpz zRE!*|f!MJGeg>-n7}8(7dM9H!$-8o^o*(;Ud_OwJ#D9l(ntawV*k5WAXt$QKLS9RY zX`^F~GNkovwgzx|;=Vimss8|B9ZTSEiLWoTcyG0T3EV)lT+a68t}TMU1(kVhb|aqE z^{4Fx@iRuX_}}o~!&jP?pQP$iJeq{?T}pR_3=ts0;}SPG=c1l7Uq%f^-&fP&hwOXq zf)J~nqA*Fscx;o&<0SBF?XVnCjme>ha*g98^mq4{Gq=xk=W$ss4>9MAw3Jl0F~l?Gvkj`Ca!Q_f{}pzr4q<*{5~{{WF+hhlNnF*PMpqSpI) zey8c!dNHF%*{zbXX75TDcZ{avF@P1eHqb~NeZLyXz3@e>dW=_4-3ZK0wFU{``r@P1 zFK#rb?j(h6=7h^1AU5U$InO@-0HsH3uiV?g=S<#oIQzSP@Wp9Ts&P~Df1xUy=xK$G zY_C4gBr&46h4L`_`UU{=g19;SDY5ETT6740+>=If8*hG`ex|tnN5r>Jt0KbQWHT-# zbxiF8v8nAfg|xNZ9Klt|42qli_0MDNS^M#)yIX75;l za|S-6pkJj{*1SH=s)-@;x5ERVT^{P`rS5!>oGttE<#6_ZJ|_x}LF1sNygPidyZbdj>MtkJPm`M3uias6uS zkw%e-W5Y_P=L`Hz-2OGf>Dsl;wyd%*${G^BZ}3Lo_x8sXxZX56bW3qHyGJtspt}ST zA7hX6nwk=dpEUJu;ZiO>>`|ZNZ7NHBC(8k?rivZq*x`;lXZqLcZMGKPKlesB_OHk< z7i#0gk`Zx@Iok=DA8=(S1mn}1{dm(R^K{#G+&*ZO=huN>vv^rjQm0-mbz7@T!J< zzS&fu?LX5?oV*Pxe%5sAY}ev^P5T=7rrkedPl;Nu+2fN|U0eG??kPsp7%qznDEp|p zus_+rucSZVpWhF(iw}+dAKeZ5L*jb|ltyLS9m^*F0GU6#oREEfUrO;SKZVvmvmLjO zv>2`~uk3tO1=KLio^;PMzU_dF02d(V*b4d!_9wBrx4p2u7g5~D;gT%uospwqk-fJ1 z08TOWuMZ1~ql(O_;vlaV6#X>if0_CfI&q_u;vp^e`8~fA#C|sEaq9m7@KfK47aDwX zeV;+Ii$`~f_QcR4!G&f1@R-O+_a8ysM|&i8O9{fuF9@!FUEKk%EYW;Hb$9;&1r_+& zB<#xRtY1ScE(!Aj&mJ9DXbOXX0U+Rxwd%T=`!(vdzzZ8LFMuT`tQQ+KgM<|4y6Ql!Z1F} z%&f9R{kFRD2d!Sz^#z&*2*Nvf&)zGJN$=jQSnCUHB3p>$k*&xj)M6*>G2gv--&(1M zyyqo!+uvJsDPh##ElpE7+aHAbPk_GE);dgjWS7kx&fCiqCej-$$EHv9tGY;<8#a#M z8GOcM4CE@Daly~w#d04W{{U+3Kj40=rAMl1I@Y&iu1vO;Nu`j+N!hlfXCRUDU}pz8 z^)=+b8a_Js)590Ke~B+7)_fPG!z__UXKeN^g;?7^9Ps9+0$O~(#O#kdQ{ibM{f+7v{GbH#{*zHivA&neKyyZ{?wi= z(fmgkO7dBtjy_2PbFdu%%|~hRR@2A73?T4N!W$c34e2-X82nk{ z`$r*f#0ljhMI(*#PRPR&@r4_K1e*1-z7v)Lr`GccD<^cew^grx%d0psd__EDeTEJW z&v$+NkC?8$1$dLdo+I(~wug1P4+H8l#dE1k51DKcKX}SAdV~zE&Q$P4eH-vA_Gh{A zH^#3F`1)z=buA{tL5}-NHtmlp;%%UB3^EGD7;(TPXSI6w!cPWiKM%YEZ=~qfxwVgc zGlpQ&-EUBhEXDJ_N5)Ax&jj$ZJ}FvDz&CrFYL?!A@Cb&BJoC7>W@#7qpdL5i05W@0J+!f1$8S3AdxqVN z{nP4wF;o8l!bK$VL-umBsw-{-_(wdSU&g9k>Xz}^wAa`6YRuVTJB46=qd2e5>CvxN zYH?cro=4cmH@i2RPX5JJhHHrCSqB+r$UTo7XPnm7pRCy1mXCLtCkg-@UoV-%it}TkFeA zymS4cFseju4(x;X0OK|HTra||wjol?VAVFXYpdA#ULvoH!{J=;no9cmA2RsQ<8H0{ zQhaIgPNm`f9_rp}>+LsI(FOdG%_f{Ji7SQxjf}!f;NXHYUbsFfd|dFCfMfHt--uIq zk5RY+`h7FR7O8D^9Bm#*NcYAgW^4h=af9japR&jN5_{t{hl=j?Z3n_W1N$Atv=;Ec zi06B0pfR%_mK-rCHxbvbTKZvs;F6ykyj^!=Hl?h58_;g7(P56}-$|Bs^22T<@$(Q# z;~eI{R;`1>)0H|HdQHZfTdlfj`k$Xrrz+GZJVqf$ncsUoyCd-?`{A9h#&3ylrKgJF zv9`WZa@RWCZ)+L4kr2(iGnuX!6$`okY@7kNr>kl|vv-30H=)kD$HfgF!g`;Z9;c>U z-GbJ$PXY5HZHhr)N}%b$_Z9jd`!W9j!MM#Vw_YIeXT49Hq{?DHjH28I$e${U@_;igQReK9V zYcP^2z=>p820NVfuY{NOezov_impGlE;PICN;t%EMwT&3@cbb;fTI{rIwz}Jza!|~Pvd@v09Cx!8Mh~!7*Ih&oM#xSn!oKU;XMb# z@=0~8$EjUH8_R+-iIImL``6f7H~bQp;3Qu={6B+KX&+>g>Q|R{4exrB$&wu@U%dDdc za(C}#q510{?RDYH8@pdImP2xh`-PG}F3pK$Ekqr`Ug`0LhQlkUpInPXw zQD3Q+nl;3*w6WaCBpC^kL1I0-^H}lTWb4r;B_!Ap03fN!F)Emg78?I~!onJu)~Q zvst%4@JfG+SCFLBc++@=`$=>66xoIRuRQVEt?N%SI^r zJ)bZVg`6)be;nGL?4)7Vf$pZ>{`2(|muT+3SsG zWd-}(8REN|Ex>)!7hyO*G3Pyb^%dWK&7KyM!@9SLEF6hi-u@|LWrxYPmwb6`owB)6 z^*mS3?fY%`A5PS-uB|nzi>){1mdRqXXI1&YXq-v6W=00%(>eC8qr<=OR1XPw8%)%- zn=c0I`mB=qV%pjpq>V^c^BZp9!;IkPp4|5NJg?b8q-RkjFZc)Q8IDtqtAzGcVJl0c z)6c2lzAL$gKlm!=#+6wil08c4?x9vC3$XHIeB-*eT>96hc!K`s{xthUJ5fL@w}w&Z z2wlyO{DlA!m+eTv2O2{;z+sYaELxh{A zg{2%#RT)MrM%B|^mpd&l%o4BKm&~(|0zHRujGylw{{W3#@s5dbvBz|=w03YDNMwDO z^yBiY9vc4uf>nHO@icC}Ecn@>=yvx~Vnooic9v*{GEQYG3FtTj8O?e(f`8zR-VL^t z-FU0w2gFYnUCe`YbLwAaww1aJM;1!-=uhQY;k-V_)%T~%ua{2mL(HwjnOzI}N~!76 z@AN)(_`l$-2gY6zhx=Y@J4Lg&7goB0OcH1&S4T;s$rxqWeMdby5yn}5!XNG`eY}kH^0r_;cZXw~oKEZ2U8> 
zNHjQB8?7GAZX~g{FgV_mTQA>bZ{T3Vzk2mg7X?oR?-Pv5& zX~N0`iMUeGxKAsN&bb{3&3?9N_Lov_jiH7!aWRuVQVAq;l1mZEJohw7phGtz72}UKd}&V){_BTOe|h;wrT)o(wWhg0nGeFP z6T-I2-e`SeU$li>9taB?Gj1J5K(AJx{t0jK3&umk{{U{+V$g30S+#;LS5aHBjjFM1 zL@_8Qc3GI3{TI`;=ydVBJ6}DxV_2bRZnYbT-g!q61M;XK^PhZm zuToqWl}1r?=_cFTdt14jxbqz=IP%5%bUp^rf8dHX_K^LTTKK8r&1qXQGTB>4Z{;1w z_hFfsu==ilr&Hkn0Qe#o!QCfNisMrM0EFknw~G@7hW`LuXb_GE-Ha}O6WsA%VqbVN z^IpEVYss!H^p&_DWfpf&42R|9w%#+l2M4cuo5Ow^pHG=aq_+2(To(CBrk58M?xc^L z^yGg&mDL=h2Sz^aIIog8@f=y0)r*X4JHFZ;3-G)C361bC;rhC1e+#@fG^|6rUc9%2 zu<69Ez#clTIp)1{M))(}&xZO<!V;1RQ856kSRB@i9b+1CdTU&OOZS8i+ zk|T)}e)H3xUY^|6E~BXGo+I&f(T>U};nCB|wsyE5=`-@I#0~-{uRN2UmEHANMv{wG z{{Vt{weyT@qaJ#V>8iUc_+P`VV_>Z(lMTj9ud#;Qka7ZycEHEAdB5#dp?Iss{ua^v zHQ{|?{{X|9w~DSB7S$}8684ZPZk9(9hyER^r9e4h>6+=hWvXl1rLWkmG}){-LK;WU5mC;TTgZGlM z+UYaRt2)(VerA#AnoozJ(JkY(^5lQBN3b~nNgy5E_ayRazO&-(JHuB7Jx=1v&6(CH zR$H`JWsVy-}V|@dv)^yu8xyQsi7u36Os8X)50+jaXz5L0%1gG2!R%$L&k|I^IIIQ)=ED@h+V3EH^E& zr32+tj-c*d4`My4=fI60!#@jr38#2-!+I3jHLi~shFfcEVH9gG@w1RH&c~hK#7=wj z7+gm`EGN@ogngU$Pn&)p{{X=YJO!d8(|D&>)I3jp%0Viqw%ZvUaJ^1*$PORljjSIv#g#&O3=NZO0>&<%yi2Q%2ctcK` zQn<6cv$fOVXVvEN!rrQJj%i^>OO;K0?0W{x#4pq32p#$+e@B zMF+|wC3xf!P};=|w(`2j%_PJp?>h0%;Z9hhF)z%%b1?`*mLOxf&mWCvDl>Cd-*#g! zk3rU+AYuE(RJT_`!!cy>)c!nEuWu!}Qe@i{j)x%u$UQmq>-DNDdbaZIWsy9fR52fX zeNWc3Bk>))s+Ge_SuMNv(s4K&#+)Ubxl7fZd zXv&~m3~)L#iCiyghWl1;LxBLpbt(?5?D;y>_FPmebL0JHDxXXB3%TltXc7g~F1Hbw=^ zkX&5mC16KQp|@w+y;o52CY#}ZEhd>ETgevQ2;_VjTO;ILo?F;+UlM=8Z@gJHtD*cc z@Ei`&wzI6I)z_5YD@MEp_a87ng?SmyA{1!4>*D+mV})=p6PVMXKZ5%D9%uVAcy`~y z{{RCt*p@#sX0|g+8yv|aoE+oafr|Z#@V@5o-kx)e+eQBXpbGr7@Ew(|ou$O~yIFMf zREhQv-m-C%>N=mOuhgFlNhJOl(`0P!dsvyV>QED2H-|sGo+hl7oSJ8&;ymN+DX8tK za{eIEup>Nvb@AWrC;K+|N8;DU_187u4(XmEy0Nv6)e_{85lL(TkL6!tYAfbLCeze# zeJkbPAACjdj)mj9oj+3Xe};5@Y8PBXd8ow=)c*k8D;Z$K5^=yiYwU9jJu0}HSK8fA zkjd*+r#sQVH*~+xh0w{uBH>(eJL{R+8Z~Xl9V^B)9J*R$zGtJ#s#n z>AXMtG3ZwMznKP$q3P`$F~t?Vt&D1j#gFec0fUjo20HrzUW=x9)503evqxv*4Ifav zh!>S$xK{Fr87qb7J*$uSouf_Ti!`x|#rN7xqN!`iG>M>&J zx=D7rTIhP;T*lRIl%(Ztb#COEPltRJXRZx5O_N2`FXMOfEuvW3Si&5g$_C{FfPKy@ zaV~$eXI8bhw}D}1!l2=I3=lnduaEx#V$axX!oLrHYS}zV;eUz#Hu1Ihhiv7wzin4+ zQHHp0nlCgG(y$6LpzIjN2GNhWkAd3aWRXSJojGEOmEBhn74JK4*8Mtsj%cQ(d24?sr|4X;*CW&%<}(~@e(_9Tq=3ISKD9s1sU#}7--l^S8Qxf&3)%C9Dua1=#&ZZ4ve=27MC$KmbkPfa>IYoiyFad#Bal8?CUEz_rO>s?f9 zKaN)W_Ar!TDRay4K3M&;{{Uw{iQlxIp=BR~{wGJ^9TvVwx_?L04e$F4WKDFW#t6g}89|-s=NLNml<;Bjec^N`(1pK*>Mx+39l5@^0+J}Vg z;)8k#A(lkQ#mCBb9)uJ4{{Wm;XNw|#2KYl%y0?xitE+uB>T7#hk!BY%h?^0WEAs|8 z9r2pa9Vt_%%(t`HsM+15eqR2~U+`ERJI5au{4L?#R!dzegU1;B`&ZyU#eaq#De!NM zwQK9I4Qd`5)Gy*X&AMOM$*0JWtfXQWWd~_0#{;jw74@g=TmJwBMe%pT$-FP&-ySmS zz7*5tNoQ+|FFyF&9EhNe$q1yCQVw#z?Bv!eN>ur6bgk9z`JLXz$xg@W2gDzaUK{*t zX0p9fH(a1379b3%EF0xKiu{V!zha+=pSFMPM+x%-a*N2_?7`(n>i)HFR%$K^>hM0qkMN##ydo^hNC z+lNt-@#M9J*6g%zhVvI_BiQsOIO3Rf2omS_mhs0oB&5udl?#A-{XaUG`!9*exY2U= zQYidI@GAJT;r6ki>iXo%ruc#>Y%ku{F$AAxb!g;J>M{$4``F^NybWy+g*-QBrrX-L z_I{CV6}7aP&QTN~`j7yk!^Ji?&~1}ZwDQ@uvMlkj!(?EL=aQ#^lf_hn?EW3yWvAb1 z_p7+b(93u?!nytx&wiPrg?YI(&Az^db#N`MFMmUQ{7D_n(X}wD%m^yvWmI>o7gq~o z3#4qnFee|c70Zti#bXrC_nL3nbB1=FEXGL5T<56IPPEtYj+1G9Zx!aXrTKGn6D{;A z?me=@rFv14wfU9P@2g&iEgU^XJ4z4kW|hx~8qVt1&bFE1yZLu7A{=)=pS4=K))v~& z%NS_blNtN0*(35i8s;y2cca0%hPj}Qemd}UaP1@an%%Oh=RF4e z_2=qpxc#!e5LsWix%j8x3ssG_W?gCl=Ht+A?eFRI{{UwFIU^ApN!pk6zM9T&k!5QygoPW1`dWVjG;G_CBlIc9!e}wF<<+7P# zI09HhG<)rRsR45g#DZ}Iip*Di8^kpJd%h+dUdVE#HdOC09Zgzx<8j7 zbgzuAziE9R_O1T_f`jW{@P)3a;%N2DBld8XhTVsdBTq1_m?(fY+&h-bqJQtSZaMfyJXHBT6?^DLlledAMKj=vc{X`c`L0-EZ>$C2M?x`mP%ZfvDbK17JC?R-R7 zHdP7Q2|S#EU$8$8uVcITf1yl_Jh`;%c@Um53`$2@_INK3?43(Rr0u^`=Q#OZ6kFe4 
z(A&9P*7)_!e5dBaB5~Uw^rk})&z6)%@kL;m_bVr zuhAbJm*Of?ZPRu%BKU8j*jd}d4UAI0QxrSf)TY$jPFL&NOM%_?JZf`yQPP%M!0OEkjKI0D@Kg zCDpto{f{#EwtX@=3Pl9E!?8Hv5=iW$0Cvx)6^=jPjQ;=!tmpexo#&5iZta7J)iu4* zx{-h!6gz-k6iLvweq@JJ3 zzsSMX{{Y~cU$h2^cjjq7wEqCaTR5bN)P18?mggh^&jc|0?a6Jx?#2c?44+zqN&SI; zY7I8ZIi&rmJWR;NgqHAIW+@boFj>PU+>8O9fE@Eh6iWpOD>nY~O8iZVNjC~_POy*p z7jJ)MZ`!U+Rm6X^HQcs#+{GK&M$@hb8FmD99Y6hbQrGre{j)AS!E5_OhkJQ^(z>mL z5=Y3!#Sz9%IpZg;IULbNcVaNGr%Lf}qTI?{MP6`=bm{z{jtAn0{1Wr_sQsdJ`&}2t z{{S6)DdHy0x3Ut%FN&eKt zbMoC=(SMPZVgCRG?X~k(TmJxyKOA)>cil2w__j$2e!)l|E=D=Q$G1w=)PLZb{t?#< zV^Hyj#!0-txflE=`iyE6bXM}*1|53w&uS>Eqec>!yoD^&6M0LREx+KE9}37iyQP*S-g~$O|5O~jg^{m)` z;E`Vfbp2vWsQwvgcM>}1f7|X+iR;=?$G1)>qKcGZOkuMPG zmj)d+;D8lllHGDP@!J4mzEb_0z8>FrBlf)i0EH;u41O2*>rn8`@7Xn!*7RFxWVb>V zMk4Kw`3M_!<2<%W;}lU&loOvL*xsfhtv3Cfn!4!O`fa8E0KqrD3)sZ3C&90VcJoOJ z`H>A485S1%-zg&*2iqK0&Y$}Rd?kj{?CJ2Q;oQi|h=1apvQ{di9G@?o<1|rE4+U45 zOQD>z3`>VIoUf-><>qf#{>8r!VQ9^-!=C}$=aNQ;XfQF1V`dMfK0nz*z|peB4}d%# zsH7ED((Tbx9)VO*MRY1Ka&fknG5*xWb4%ZUd;b73zP0-}cp5qFZDsI}hVG=1$yl3A zh;TAaTwvgPSB-woKd|S-?}nPDsdE>J{2!*rC6(p1j5c~Vn6|S`b0o3M;g}f;#$rrv z#v|`V6^%SCKW(D1xAw*^oFx@D)7>`h_ZPfT`wDzo)qWiKQ%kY&w}YYA^w^+_Th(qn zITXiRDJnb2>miVUR61=|0}_*rfNSkf59yOw=oU8f7~X4pc;a?N3arYY5(i<&1GN-a Wt$B#_`XxVyW%y9FEEU4s)KxH|-QcMa|k+)3~NK?4K`@+R3y_CEKV zci+4JytiO}Q>*){udA!8dU`?4)6&x?0Gf=1v;+VG0sQ5OKEN4TY{b?_NzyLsj+aSO{00>+F%ug8rsDZ%yBey_M z{lb9RAeevZB7<@7Ay|IOm|%Gv^7-GGHB;O2jM-Tnx+{02U@@P9A0^9%c?=W;PxsZXR|n0DvSg{kLt9 zPJvl}WUnyNUp9be6h`?=HU$7;0RT8KJ2Vpu3sWrAZ<`?lW1;`Z$#D>WVIUx(0CA9i zILqRp0f>KD2U!;n^QXNn9_|-E*kkeVzhp2l9^sda29^>3;CtgkfAaeS0Fa}(zwKdW zV#>t2BP4?ObUGI;18*AE~4c!~XZXJKOE|s0O!x&Spt>&ca}hv{4+}+A%4#css{4Us^X&jV*HnQfBI)O{}=K8kT46wK36t^BXz`N35a_K4-~2G!BUw%BV#po3PeDB z@e&&cmz;ugMj@=@k?l z68a`AJR&|JF)2AEH7z|azo4+FxTLhKuD+qMskx=Kt+%g#U~p)7WOQbBZhm2LX?bOH zYkOyRZ~ybb;rYea%d6{eH@A0BD**T(^AQAGdoSlkXuhiJALl0b-^+a9mIp3*v&4u`X$HzIna4DK_uXSj0VD%;!OR**HX$*kpd;Fpr|Am=$XEm81 z(zw*>HR+sCxRQdmFYOfRO;$zA=*mHX6UVJ!rCc{!V{zN-c`!#mZBJ@an9XGOBUS}Sn_M=g848{?nAt#)8X020?0Io)JXIJYcfV#Cv8|D;L~=bI`??6G_6z+8MCo< zR#M$=I1VK{9Vl9=r&NWx9S)01?bD7w&otvY8IyrW*)+4Re}AkH@JK^!29}=CVJO;5 zNk;<5ZT+ew)9?yE0v&-1F2`=3E%_89u;1{4c>-94#vu)xT?I}ILBgcV!7z6V9?z?$ zm?wZR)|)MAT)gpuh{#bU{gw&P7`M1ZK`hXBp5)f-y6ml62C~d<6~curATn|6Fw~{l2-<8$0198l3 zd_Nj#p*T~#^Q9CvOmYg7LaXy5QX$8{=oa#ng0a%Anu&a5clGmJ;AQGmHQ}rsRHGg>8*FGEB6o) zXZ3D&*IYR>D=6^`H{!iXkk;a?Fc0tW*dDDd1Dv9q*TS11KCa_Z_W2$1{5KD>Oq(&z zx`gDNoXtVhD*$2$&brC>Do@ zfQ^PnxTdUw&FHJeY1p>!8IVx*qFXB5&djpPpj)#}Oq`hFC2nrL9kmyUP#H}U9-iwN zzIpT#jTZ!}lbzXz89xCuJwaRqshAYyvYV&Rg@hi*L(Q(F{kNNSDz3dcO9U;Z?H2Jd z{J9=z3eqzjb}1vaI8F#jnbmVN>uJBYI;^+NIqHd$KuyYJD~{9L#yJGMax#g{Pmkv> zllZXiL8jh7N$do1Xnl}_ihpA-CwO`E@-dJ$^*q1Kim8faicn(n!k=f6$If6&e>#x) z0{BeSafBVS605S+(%}{Wkz)`9wNl?(;_WS#WE$eg7OA!?Wd}Fg=&pLmVzEIOttMng zrW6MxlqEVl(sC3t4u-Woj-T8w=X5r3t+uMSe#s&XRi_(j8l{;)g0l}r78G z5^cE&4o)FRBj#Bv;M0y-?a4iJp}`4#u`V-zzv1h=`ZxYQf^mD6!`}B5eKGH61vd>m z>>LQ;_k)(nxrwr^I;Z;D#d@0O8S{~rTI>(hp8%V%9U`aBAbOzkLWl2miJ$r`oQ1hF~EN1vYdvs**U5lqiKNQH)oaIrFDulx!Ynd0J`_j_zWExmq3RmaIa_ zA}bRlV%ztI(Vi8@qORkx8_fPg(3iZ{3of9GOIgf=CTOBqtf8C7`Rv05b>nUl1A)eI z5xy1^sU5sl#wnB4BemiPOL~|x%t#JY8)iNjqRF-xl(vHv@xgr3kZ3NzYUaZS?+0s) zIjI`Y1muoL)7U3K!&-WEi4O1DfXllLu0gUe2YtU~{use%T3XHQi5*L$ua%Y4PCHfj z93Fb;IR@(kPF}$(jpx$zdjUd>*s+G0?gTM2wT#A1rdsOjNXR+WjfM)VMX#tF;MPsP zDK6e$74?I06VF#^^bU4E)bPGBfKzB3 zUY7)N(`+UU%ATb7?g>>D1KdHr@G^6495%47G-Pih=U;%xd}Hu z%GxvGEI^M$-xy!IN}0I|ua|A?_#p17RwUgFt5OW(9Q*ps!6SQe*eEt@H8(QvW!WT$ z9rCm7Nh?`?iLH^P-Zid1>K!7P!^$gpUG@4G-_w~#j6q~I{4qFC?Dn!P%&Ymmrh!|6 zaid@3Y1IzzlDcO1PgMO)mssz+Ct(a~&(-pT+eE2FNAt4YyIl`$>o#@Kz0{Pp#$<^# 
zVbE9Q^@S^X-IY{r-o?J2hRM8ks7u~Ca;gbxW1(E<1`_aSB5kFe=xXR6o)5qiiUwrX z``ty^FIg>e;u^0E5G6Tdgs_Z!g*sJUK7hsBc4U%tBYz7KnXu7DHTd0cCefdmHwM~u zO6Jm*_zBRTJ==Yfa;wR|O>WOKDF<0n#|h;71aW_qpd^GiEg@Qjnin((}MHlLfD9Mq> zcVZ*NEkT^5KXomYSIaNT?anXkm>mZL>nmx}5Xm5{x4-j*k;mQj-Y6~5&J$L2Hd=__ z?uva+CpbmPAT?WQZF}cnrfiJa#ZX8&TY4u!9h6i+oG=@p$(9WtNgHpuYWCl=xIK&N^1b4;RH3{G#haFxsxKyQxn?Ux$Q*wvCj zNdKuRCm8un?*~PNPD%gfe7ASZwB0@1YWg zw+nd`lV@=Nmsp>@xF@5`{0`!u05pB)b*MFxBu*`4_=LHOY1XvFGj@gAMUX2mVzF#Y zFLAD`*(iL+^E47BZo`Ima-IPCE^QQ9EN>#DvJkMpBE~{RJD|8}oC_`|@=~Q>+ja^f zZdJ3%>&Sj>+gLu17Y&RS3lSvbD*TSDL@>A|6dd#wig4SRAA^p!CW(O?^G3+}#h=n6+*$;3jamiB!zm8Z_4x&qv2OABM)S(tSAWzpgplcIwG)eW~j7*4cj(OOz zC#G$O31Q_J3I{|Q8mAexX{~5-v&dY$HmgihdVg}1(_0>_q;R4rH;WU;+Hw9si>ES8 zJV%@f?v$(V58n$WPnVIK``S=h$)2c~fwJxlaJZ z=&?Tc^w4}qOsQ3@$GeNwOBDBYg1u2!LF zwFDQf7D_O`KnfLP;E6c6XGnQ&WzUpqVx|`tFZ>=#aKF%t4K&V0=5k+#8OCC3CsodI zW-Gd6_sMJS0u4wwbXoFsHR!~9ngd=hI_a|xg+jPFIBPD&>=r~0#iWuqp#M{A-%n{MPtWqoNS-Uxj zXVod0_qmX$atkUqUQ#~2DZD8+!uX4{B|3+=a8j0XRW91 zWL{`C@9MvfUF3)zWX#j<4^JJV;v!PMQFWz0v7K}K6&i3$h_EHVS za|H4a2&|8UK;aunh?`qC*RS;I%&=Rc;+XiqsTG)yde;|T#Zc5?X>X4i?{#Wk<2X1YpQS?eYAoo3S7nqu`~WvL#UMYlugt-9JXz|_8eOz5$AS30>V z$sX@x5E0dVca}Zb{qcM5qXC0x`7%`?v;&~*(J(Th=bm?gx8tSG-4Mf|eLdQCO#V^T*bG6?9j+nTm$SU0R5$&qZpk?JyUalA5S7* zG;^bZxCG@gwR|_#u`Jy|O?36CLCofM4(hVsBFRjt$K)j5=`%e^9(&=?0gtQ3*4gSS zT2}>_+zV+v2h>Ar;~h~{|3)&zmbiRGCq-llze+kQ6>|OM;UlT(@=;Oz9T)yX5+h`7 zropluKlptt?M6Iy}t*aluF$Dge{c5s$3AFO$G+IaNPr$Ty?Oh#~!62=Uj z99Z91aZKA^*(z+^7Dp_1Ocopu6Z>>DIleztYYZ#uoO>6zrO`q^a-ml%HnJV>IJ~bc&Qv2|1M})*IMCa&^+ia}4@BiNwzwP*qnt zMf>1^XOr2fgQQSc6kY8N-T|w&o>>FAonf|4g%5D6mhtGp-J*fRC~Siiq)cugR9p*J~5MDuKQoB8jywYG}3^%aISbg(J8i z55nOW_=+Y-(~26v{sP2_yQz1DrW=scpWcxT(GdA2MX5Vk^!<2%pAfYjAKbf@RlX;H zN&KN9Q=sk}ZnA9xkGqAgoK0Vf!FB0bTb~d4o>X|5UfGUv-LADQ4V`tm;`Nf$Uf+cf z2wCueMW|n^6t&OPAG^0?J2gi4R%CVkv-LvyAUOgCS!c2QBySuYi>x@lnJ7zm6le;z z^99VgR^2LqEY{(nux<9g7Dkaqg{$*jI5~^89+u}+YTmAn#C^qs_)4|&{IiuI&_2Tg z+HvNcw-nwU4Hv(p*Orem9jizBl%j%ogPHr3{KWeQ7%$W}Hy?3sC z+Fsn9&vex0pCh<39iWQ^c+fi(yarZTrY9$D={w8#vUiOjKXcj@rX=1EQqlInILq^G z*KIWr_(GzFIt1IG+S7*8F@iPMtjWSOtbKyL$FR;XvKahS>9i{M2pK4mX{N6rtY-ovdJp4=~tCbQe`+hcW&W2&1Sz<<-Ces$_ zo8ubpbg24NB6=OHI}`q+T8k%wBW12~b4^*dF(b_g-MEy^Rqy_RDhm!v6ot~2>V&Yp zS&cM*+iGueIUKEXcMGapB4SRx5A2r+y; zGPJ2#2g`L|8gjN~T=3N$nV^#fKh(I039$40!-scFQUKhjx9{|81F`RyzVKU5%)Dy8 z$x!dwOdKaweH@LRKlDPN&(q0#0-#>8J?N;tQB0P`Co=@4MFX&=gk6oC0=C)$g(TDl zm*LJUY~AjTMwl&E8k6!vhiE%vY5?bxyc!M9-+vIRGgxJNCuAeV&9&dDrO{~IKLy2& zO|EB-#U~dctf>rs=*Hq$i-wH5%I%j)!nq4jhDVGSK9~zIp!8|(dU5Wvc%59|N0hmt zG8T)2B%aLX708UH7Y1FTS52=gy z7KGFvBvH$kV_t0pDcHa8aFpcfY{mYJAULeSrw>+8uwmm4Sej}Y8YvM}JqtJNx`ei!zJtv3`)Wm|p35lw@w2*XIzA{_^f&JMUlFL}ES1f{cS`D$7 z;!TmF_jRd_p{3%eR_^^{*X91%M%YgBOeO20fYyY0&!#-=L3n~}6VXW(&4)9Ez;Vf? 
zVZyhx;FF~38ASk{4eo-_%Yg6<;Q53QA5xweXWn|~)cu3>BRlq)S*rfa$vId>Ww+Z$ zm*UOqDB9#?eG!o~qpss6^UqfE-GOXWMJ5CtW6?~_EK~yfC?k7~FMC%+sVEQ>ZScaG zp8)Svh>$Jbqh7*K5%o(ynq|BUZBYB(X4}p0lvcWZb`qK6F57@}8RFMlj4P0xwrYuZ zy$L8LAX4Bm0oh5N7^^&yf{t2gM*}1 zC2Hsw3@oKZJ|Kh7_2(hVZUb6p#fIp?UObDpsoVNx`fu^Q6BfFkmY+5ePAyzq9C#QR z?VK5mOzn+93?}xrj2=b~jLZy7i~v4C4+kR?Ymf`EG05D~j-T?hy_=HQ(v+VP{BXo1 z?;r}Yu$1z00;zf_sF`?Kn{b;_3JM_ddGL7HI@p3-jEFsKZS0(RJoqVplJkJ&AI*%E zV33ok8IOvX#4iZ=j-T?EE$;5_4DPH9_D<%E%-r1Ej7%(yEG+b31iiDTor{qNy`3`< z%z#MzlS2&TY~p0;;9_ZS_lwKnA6)-3FK`EVUaUWcC+1`PVZc9>|2skpdl!3W3wwt@ zCjKk!@9v*Y`=647UG<-7f3Lz1LSjBJj7QW7WaMJ+q-Jk#BOv@^wTP8|4^1p?1F{9# zxi}M>**g(yTH2Y~yF1enyMdgXe>jGPfr*dtzf%0qoB-CAm;diAwYB}}+MmYC%kzlY zo4Ee4Nm`8fhp(ML&i1ZOCZONJU}|q-Y0Bg9dj>pQY|N}iEJkefoW`t7^vukrCiL8< z?56a_W~?BP2{R`*E1MZ5G2{Pm?MzMn*zDlyWb-qsOidU;zuX00GUi`Ko0{;L{vAs{ zoc>pU3EQ~*y!q?S#_~s$@z@yIne$V6(1ZWeXyj_+LMiY^{420*EdL_-8D+m1_!$3f zi2q3xg1`kAf9CQp0r;cmPwWx%{kY|kwly*b{Uf}8I__7Wypc6X;Ad?FALCE(&xRjS z_Sc=gDac0P@2ZL4pr6fuP>N2L;I%dSk#_=W_73zmATyU=$bWYI!E332*WbqVC!Uk# z7xJHte~{Y$3Hi^)Kghp<>UNec0@8LS7Qcx9OV1w-ReLiRcOxgz|2}_yrpix0{K5Xw z7Bg}Isae{B1b!A6{3sP+CuU~mVddgsV*WMoAIKjiZBb)UCnbAl%O6v>v@;hFwWkN) zfJ}dB{9FGYO;LMW2Pcs8j|>)&Q4*IV25-*48-6MO6Y@vzpI-g{=W1~e2YY8%r@tJ= zOzaFYVfbbAzx4dET0+6q))?d@V8+P7!1W9HyX_C?x2lqpy}6T-t$?rt$jJpf2(g2c zrM;7-i|2nC;y3OOfwZckssjjQssfH&J9Fp1(_9tg_%CTL0Irw$54rpc_G_5`!1h-T ziGy=cju0N|Bn5CLjIWdJAA+?1pdG^viU30{)#0VW1kCSq}o{}Ld7 z_y19qH}d!$R)1--FmV6SB&ITBWMSZ>{>AWbu)hgFE+8jxbjgEvOaUb$7x10kZ>4{B z{ZWv!G`Db3{rTxApl)YvXYc+S|IeQP{CWMCtHCEJ@R^G7&&dmX^XKdavI8I3oPM0l zepdQGLO}fZ{dGaXK>fI2p}}=HaIoO=AMgtv0SN&f9uZs&gouKOh=dHj;1Hgnq98x} zAwTv0sPtKcp!-$n1FoX_tJ3FpoeuyS66!~t4+H?=cfkw*;`l~#>FYgQf z;S#1XP?~8EA6odi{gC`)&>Za`{Gr}kH>O9wUQmHz)a0mV$_9Cg6g4x7&qbHp=^eYY zzw7IE-Aa_=T~7mgKdrM%zXBrqx&C`&La8Zf=>3J$85#i(F|Rn)u{#Ntdi8yjXRqTD z@JMKqIIxekkJ&5i$Hl4^(dkBcl>j z<5N&)Jy}WuuE9TW06)maww?qd*y=tlFCix@m2v94n=4>m* zNA1@Rd%lcW=CO@3mr19kjIXI#m^Y{vm&>Az%(?326XqMDZES0sZ8&VpeC?vNi`4R$ zmc2m3`pED0Oic>&X>QHFhIBX5Pv+(s)_h{l1+ltwXr?(`*-q@-LB~5o~-1?g6qq#dT!>OzD-}MwZ*EE-uJ%k-?o;P~Q ze@0?d1n)B$8APX|hb5CHfyD=L!4a&+%~E@%eV6GCYz-VKiHl!oGY-m`V5K?Q>x8 zdX9U$btwmWW^K>C4!?#n=*p(u;FzrzyIS&#zjLd{;*0j1rhi>KvUgGsJogY)b_S{?8B&~@uo*d;-sihDM?n};z&bBW1vL#1e9{r%uzy?8E z4xMR-xT=+mG1Ae-SEkBeSo?V*C@7Z|or{cU!#-9tX|K9~|L^C~R|hHfmRUd6O}e#_ zNtKS}(~_>f$|L)$1g>j!{{0-Y_*6-G3JGB#SIf6>Y9pSeg`(Cx_tIBm2BBlkl>EL^ z{AhktSr!WEP1V!-GHF+I5u=!x29zXA`#YyI0?+V!VWVZk{J#q4=)lnS5)g>YUSy}U zJpq(ufVSZBK6*3&1T+LRBm~UQ-2uEeKtVu)%l`my=ors2v0jjn0-0IhUsA9Nle7KW zB_LoR({S+!t1Bc)tTQ9LI$89sFN>-qNj_OLHRijwBj_wUwrVrE^l4UAa~4%4?{Ypi zmz@?!lDMa1e`x<`fG+5VexiTIdntdET%rFZr;jsvsQdev(Y$-$od4@DdaS}5x&iPR zt)}M+>#~j!9H$%M_02LH!qP1%a39>CeUf+ES)(c1HDE|^u4+xmUYU3e;m%-Z7hVxp zw%AZ2jZa$s1hCY`!QOO_~^mb8pC7~ib=T(TvlCUr66mc1DpjJ+tCxX`G-?Y!2(epIzqvKU3{ z^5WJtfc4RQvTja&&U9bgt5ris{DVZWXCViTVv`6LH<@K<`FOq@9imZ+Hlau>L6u^o zdOZS47svF_ryN4DIQDRo1w&z`oUN%5Y<7)&KXxaIxCxW2z6P^=ee~fOuIPlK+R2$( zof+10uH$zx^~-8iG4mB<0ap_I9z6(y7CGbX^)9QyboHa2Ay|89g^gE5Qf(kTxo#FW z0h;exOh+}DR1Uxs2{B=eHsW&9y+zZDK}kfBl5u<2>V#~mG!1QTUP=qr0ZQq<2BM{f zvSyjf*O{kdJ8L^6>MM8+WQP#woq8Sbpeq+cRE1;DyV&D4dB;{&K-`$fR!1=Q_XC1y z^}K@tru9%1mO$(-kN#ko7+m>|mW#@hRROnIBf_=v?>)weuFCL>Nb&8!#4-HP$&f2{ zhi-*hoggt%QIX_RyCW8Q?I%ES+g3$ZW@i<-zFbM$u&uC=^11}a8&Ynz&o=5;=u^;U zx0n-ZMm(0LY(_4FWh3#GNT!U@oTo7)4;sLgC=4d*;8wUK%C8Pd2fB2D%ug<2F$rzMD_8`}b;dm`+Tt197XC7BqLBxM6r z@1#SFkBfKE0pc&@bma@*H5oNw;Km@y4v8b}N zfuY~hh*OU`lk`|w$l|F#U%)86()G&-wO4W;8Uy?%v}35<L5#gKZKta8>79GViePNW2MvTr=;UJYCemZX2;!J zUx(XznXtK2hSAqDd|~`j%T6ar_U%W+1tz&DXD+wygW1_uqG^{;dCiNI3GXjwWjekS 
zD2qjV=U>itjUQIe`g2r=qnnH?^HTsDcBN6G8T*q<3BT*M8!X>!$5p&l5gx5LyJKuj zbQ4)UjE}JUtmi=kF@x{fj8YxoPB(jEar!!=%t)GTnIOPaQ&w?gNBAYeI0pX8n7?)a z);)*hsRjxAwMN)I>&+9uRphhKw zJT9kNzfVWKZq59iOQBEmYxaUlu>)kT+*IQ1 zJC^Iy$PMp)oiGt6h3PJ`v1W{=PBl{)GrGmS{CZ_?&#J`{W5D*DfCOD8iUyXH8{%c6 zR!M})BL@Ys$?V7eTI{LaxpqY{H@x({3OrhfJKnb@{$l0P9itg)%R+TA7BsEQMc|EpYp8#3S%zI_GQAvqQu`!X5Jz|TqyJ7J zf|b)dg{qqI&AI;kwVQeh@=shSAK=+Cv#=Fh@1|c&^ABhym3<_vDvU?5TaYw~Iy$fH zVV$KMDt)cAi^wK=zk(&PO8M1s9xq~%jYC;2D3SYxH9q!05Pjvh?5Qp0xFyNl_EME* z3hotsNm{Xy?%kIDC%~4jv$idNww58fM10@$K2WNHHS|W?B~T%AM021xgK(xyux7CC z&Lj$oG|^ba+%c+b{e6rxDYWAYI&DidCkoPfrxV#{SaL<0x>BPEzJNrPVXF^hQV%}< zXbHm$?x{KBw>oRa_(nn&R+8u6V@yY7+N@b;rJgx*Qn%<4-x3sU?nLL-^?t*Sj9r#e z9!=glayA1Pk6d8aebKdxM;|{QvjRd6h5{=7c^+T*s*#g9(x#aiShHqa8YHy$|iFwcV85mACxn z%AFMkhAT(v5|VK(_O{u2!xKhl9MM4QO>rEjaeWove126`Qw@j4NK^TSsE*u|kyw93 zsy>aQU|jxD^?7G_;35`K1n(hq34M-y=-mx1eQt%JLPhdbl#KRdZkY+C?bYI#MtqBE zkfSC=@5Ho;}^=qp9*kdiaO-kuj2JaqkrgG^$&b z+T6lm8y|8XzC83O1!X9fdo_$J~NrGo>^G(6pBHJvT;thl!TG=1{clgRFSLY zO9{h>=vvtO@L<&UcukqnQy9eMZ^a6u>7&NU4A>;6P?=Np4O%Y=&X7c!>$G$MbG=a{ zN00?_1Iw4kPfQ>6^A~F3vZWJf9~E)Z{2`y~H^(}D;Xt7?gxD5pIX!#%}dgg7LX@m>575i>KOcI1_Iv}mlx z2h^7`8Kj|(MSzs~7g6cGv|g;=iu1QTAhTPfay6ARy0-#8ShONAu}Yv8_5JoquhT z)H!^ES8z9r!7sZr*-7ga;oQTk(Q+9KGNcfY{{q+VDrct1?jjr{FSZ#zePT{U(%?A# zK(x*m(WZEpbZ`0m*rwCkiGHp=y@j0C35V+veN147R@iAM4{m5xW)e=ontB_$GV{fV zLbM;_L?w}IbiY0Q(Dz~X*L<7^OP6Nt+smXVRj@bFzW6OQ(KuKyzPVqR#f|lx%W{+2 zWNbpI(8ZYfV47zb0?J=4dxeFmkb3IQmPZ>%u|Cc;cLnb^n&J&V0q&dCr8werr18`i zI$Jz4k2=<}ZH`AbLhQroPDIs^orLN$YWQuixyDAsrjMKpqT6AMBdg`RtG2|~5qYlHndS*-^6u zLS0^AnzPaK&hMreuk4v}$hs#y1JW2zYu?kFCsgI3d$WIjDZL$^5@7nm$ac|%>O$u; zlhc{0a8($Bitp&>qLJe*Nei?SiNqL(hUI5g7^)qWoEo*AD<8^rSXv9UyV#QLL1G<| z$0y_6DEPGZJi4mLtHY24$4ehp^I{&HBsx*GxFbPIxu2yiVNJjCf zuRy4)m&KglxUkf=(Q-RhJ@U-hPCI2JHZ+AB$SkZacjLXm*76VV<_P>?q4El%jX@cM z&t=@Cc6XSs5+zOyeY}PtNIPMPz@Z|0oN?vtyV_MnyUPqwQmnJro0K=Z4%5q!>a5sy z$eh3)lF2ov*9;`w)<}cZb+t1u_AYER#}l=cF_!C(Xh;ydknM8q>&lDMO6EG>;8cy3 z_p0s@wLFoe=7_#9Ij_F_ROPn2nLAkbv3T{F#;USd!$?p=BEmZ(%=wHA3x1(R2cwZt)Z55sWUTF&d0bwbk^3crox}#j6t3OapnivfPtzm* zc0NtbQe`WG1)?aKYs>^9HQPJsxWl)1xcr(1;Hufk?pjYq9YWbnXBx#ndW>)W72D!# z*hi4Frfi>X>&xSznsFT-i9B!y@6kN+z7c@&J*rOy(%H_215_xMm&_a-JL-B&Y{kE1 z2`_utYGv@oNcy@_zfxhY9V?6(t+AM(aaip2{D{jEB|VR|Lzm_(=fpt=#o=6lQx1w6 zOssvdDJIB6wrXLagF|z8cFSW-TD%(lpiV~mxznW*C+m|*!!L+9-DDQ zS*>1b&!4^B(iKnmEF{uLiyXX2^FZ6sJwAnBKZ0OdG{jtC z?;**v4h@Cqx?^JnL+?*tPSCfe(5nf&ERbyUm)dfHrptgASoeV~&>ADhnsS#Yz2hD{ z9Fo};>5nOLlgIMrduJLl6`RmZ6=N{~v;I!-WEGFfFq^t&yvbO*viB^k@%B~mvM!9O zxK&X)Wf|3?0jKp~*r|HSvh+cJQ$-S@9afvl(cF@Lw0YtdyNctX!1yLFuM}sq0lC*Z zUhCnoJsBKytlEb7xG3#BlaGA#i<3QwyyVsLzSasuRh!!rmYk}%W`*_?958MQ_{|O{ zP9B#?I+woQ-Dxk#^^NS30`OQmn}QOkuJdxI1(&VdT2{sNRN&0KXgTuyXq(&uJgAtdJO}y0b)=k$@)3N_#l;T<8=THo43>Yb6Zjz-G(cf z@Ww{Dnd3rjgA8F2mbbPvLVnkKxkKH{W$iA5NBT&NFIH4);r!=-aKV}s3F!Xm9IQ~8 z&d~McbQOq8^V9wLG2o|Daq!dWukW5gA;N;ccLo4|`3xNZ4gT)gb5dbt7#1T;5yyac zXh4;?T2|v8SS&I&QB|ku&9fKmCeDF%d3--ZdzYZ(=b%22 z$A`~TgqjXb^X3TvE3yYvyeP=Iapn7X^<1zJ9bA;q1gt2=r4fK1|H2UgK?t{LNhUq^ zYEv&8YZLzHJ2h;9&W()ZLxiCFW{*JfR=2sC--GD~n3!izfZO1O@H&y=^(J{KuVp_x zv&%b;rY`I~Nk8)-Ecz7>FZEMDNiii?dJki^Z#J>g1(j1KvX*l}cB!Ow*Zxy%Tca3& zQq3^_u?C?B8o|sHV(SW;?rN^;dvU*UqWchVfiV~FnVz*IZLt8>CXSwf@d5E@bhokk zOhIDWfhbgk8o82Up|n~~l!AnX!m4sSuco?5BeE)kTl%o?(IyUTLMP_(O%RS8Jw1@{ z@3FLVMb2iXtA2qa|= zNE=~0@Db$no9`aNj_*E@5Yuk!y`!HVKiVkf&MHBfKM)mxSR0(%mhx;Y?PmHyi{z83 zqYjNi>|PE28#EM%fOfe;y|2DejjQ4NApz}5jm)UNN7KN3S6||y1vdW-J^=Fx7 zC>c;QS;B*!=~K)2qfw$p(S~fgd&0<8SriP7CJcCE@K9-Nnry$ShwS9llsUB2Nnj`7 zS%_zuNvOoAM3LSjDn#v(#6bR5`T1x}%&L~hOXJLYfqP14I1RFc*71Qp=V53mOhNKM 
zQ_7dVX_G5@Fdfvul-@pA0ysYzy zN{7S&MsWGbj-B^x3IlTHMN@W}ri|C#rn9G;lLNRkmmhp^=qsfaWeCM{z@`%lM17Q~ zIK@C2Xbxvj7yf2?EMKD}f$Okm91(%`5p6P2g4++bcWXYcvnNtDhgx93 zYhV~1L3`+ddgo(9jD_%kv_fwxVosNY?y+hr`bOOZQ5KeUb)|iZt#u}iKLU>$LH^gu z5~nOPU#d7*Un^o%9lQf!;_UeymDWazTzG<`-hxAW?bUmeFG99H?3IQv5_7{k5Tafs zyKQaaaEX1d7dR2rN{MI2Y%~1(@i`Lo3}uZ)rr0^^Bw?AAvMr0Wo27NF_8oPL(nVAQ zMO+9YL6{1!X@a1c2YIE z>5|*-jn24`GsoT>QrVbWeK9KtC(RwnMWNnFOeS)9iP(gQDcdwIFt%YpDzeE=I`QH} zGAki?`t9mYt_`=mpX7N+K@=Q1t}pG$scrU<)|}!@^oz zhDBWY2PCXw)HI?C$dm@8qi4l|1cF``;iPY^C5kuLQVH$mn?x_Gkro7t>Yo5D-*Jb! z9!lrP03VoImk2ML{b`LXfr9?h8f3s?}vF5l-UJE^)`oUg^nw~3RTEse|@5{X4rHWvPtPR&R>y8uj;<~;b znHw?oyO{8n*s3p2fJ!Xgn>-Ck2i1EE!-xFb*W^QoR8Yn#{7R~r7~g`|tZ#HFFRS5Y z?z!i7U-w^Rzv>_|&q2^SdVU}+1Q+^lQk?Z|#EVa69@(b2h}6;|AO+#0nVfbdpnnH( zG)Vr{t3f2A{_Wy)Yu+SbK-=cs>;~PIBPV`EK@2HvE0kB)OK5VAiUu^wH6^l~$&7Pk z+YbI>&y4LFeeXb@vbn4B4B23vy7{hvFD3#j?hDzvpQ$%X5FlaaqH$k@W*89X9K%(6 zsSdlMP{ra~d@1Fr@Cuko*V3tNQILN;yKz~mC7Y8<2YMaCW8o1gG;pLdsrZFRdM&UW z>A#>F^k3{?Kkh_PkUH3?Qvq&DlifyIOKL%;P?Av{9Ow*SL+vlrL8WSjFg>UBI_~iQ z%tB;Lx#og?t17GwsW`6Qr?3N68ErOE%_5f19mDH(hohLcp zu&<0pLky}-Hu-H7Cp=y{{6P(yzCpOq(n+}3-k<4PHHDW}tjFs?8aF#V$oJsCKCy)P3B8LcKL_$GCJ>(gTTcqBz_u6n6eDU^DIbnh}DG@{z(UB2R z^UxM|*d`>>A$vF^HOzuIJuj&1t6HzgkF%7(j0NDj&l-u#w|RJCi$;7Oyb1$;`7l&S zv{5ZuD0QI!sOGbBtYeA$kp4bjp(S+mlLF9dbS-2!K?0g0Vx3yl^3bLFb15UJF3jIj zd}a_4k{J`T(l1KaKj--QT}&;=K|~Ec{W@l)*`7uZMX`{?N+!U;397=aG85&TT+*`I z66I63#U|OaX-!VP(ian0K6k1(Okb^X>8s!YrgS{nXsZ7xKVpZk$~X$sziPlQ50_CnCq=S)vUuO5&b z`4kCIR2$1B{h_)u;F2vKQ3KUV&YBl@?lAomQ~Ysns2=kNcYYYmkfz4RlR2{-59AhV z8r%Urfo6&*Iir;`an@^yEVKoAh!Y4vE8qs_wc%--* z+ZaqOy%loM@mj4yCU9AedvR0PiNa)|IqrPv12JamLY;IVwF1{D1HM5)w74w+&DOR* zi&kP|e{7sIB6bcixbd7%G?6Ia}H2gg-c@PoeDrYsFq zfp@WtC`wyW>y0Xd= z$HmKjQM3x52Jv`NPD;HP7Rkbw_&NqE36bSq2r^7VG9M>YVCU^5?uCn|zO;?mSk1nK z)K`MB!mVwOA&j>h><{ACVyqI%xS+$=7CYK;(Ay5$U#u>M*mnOWp9emm#y! zCNZR~fku@u$VBeOKD2J)EeAsAC|0sIydVly?G`N=y+`48O+h7BVjPc>WLsRPhYX<< z6{olYK`%g_=e~53HQh9IH;bBD!2wlWNFAe)o<=JcAzc$KA^Fbw+BiFX#?-zQ%4%Rz zk8Sv}I60Vn)jjD6u+VZ`An{dCjO7Iyf*ML``P*rSuggv~@5YLcq^~_1Y$;Uz=QJTS z7Otm=*eJ!lVv@f~hXplKe-o)m-Fc4LwD{IARftkuLoh^}QaYktBtmzVeFq|r$6n}D z@Jm=RsD(GSpC1!lBdNU(&cGaSoHn&d`u`BqnFxyGm!!ZLxdV z2Ty%WKy@y%s*{A@XXPs1PnHFt*+ENINFwuv?VW=pR%11kz%g_IacH|eTe+LQXsL@L zkK2s~YV7=~Lkh+_9NL)*Oiv8fi|<1+O-fSRRq82QPXI@;=LD+fd769SSc1lCX%IS! 
zULfMf%TrlvwHO&AnWpRH25@C^r^IWMM#z;w=p1HVC9hBA7;;p%R=PL^R&_bf+Zq$4 zR}_1@-1S6n^Ip4<*Nv#f$Mj4zq9O0f@%qz;+)iz>H=;EZ(4NoqhG*s*G(cNX%Ph^w z!C09!2XYq8->9YDKMutmkS11Sj^Krtk2Q8}y|fu`iiO6#&St4IsIRT?IYdt^v)Wt2 zOkLKGg6;}jeCDuVs@l`=h7c-NPPz!+yC>VS0TMw3B)Jf~Tle9rO<_cnMSwKaM`9+4 zjF1Jy;KOg6;7Ht5gO7kn3PHgpkeJAkl*HLwNR|lxn;Bw~G=hYc)H3c|IO4a26@MA^ z__i4(wStopN@sEI*a8$v9jP&(o|52s!PG))m=sE>Ou~V|(CewjrB;m0d?Mu|MLk`8 zJ{w$Vx6Ke2#rCTRqmTV)Kb2ooA*T}kgropLcm?kKqBvJ1=eD{QMP zd%MeZ4Ty6g`5JM*6KC3OLs$pc8l_VqN>9^DC}S-SxSi;6)RZ9!3P_#BWP^edok8nK z!6+@Ou+>yDZA(hP1t7#8^~GBbVz#k4nD`oLW0@mDdeVa65H_StN2j|;Q%F&5l)?3y zbS#p!oi*~-Pdefo8kisuK6H35xpKwAM-jCu0qUKp1%Qidlo_AJOx(6Z1(wTi&bX@} zX(~w=Cz3bnCZD0G@f6-joaA zBn-)`rIJa={Hwz3!f_6`G591>d#02Izhn(TthCcg#R#8RsCqjP4un(Q&cbc31g$H| z1d19)@Je->WJHlk7R^5O;&S1J^^WuDNnPnmRvYIkR0A8*eVOhiE(EUul0c61LD&uE zlIu=@u&sG^rE#4BEvgcLccX?5Hs1Xrii%X&R!Ks_ftpBCb8sD0bP=%qDZrpow-B5G zOz43hp{r$OlTIWk6o`RLCJN2Muu+=Dq~hCV2g17wj`Vs1K@&vu0x44vXykOE8l>0m z&)E6i5KT4@1)lhC6Hh+iblK zGC5FkH`WM>-n)u@_RLO%_7+<0r%u0280gul^!BoS|F_sEfhZO zDPTw=3XmiGR5Xs!TBOZuq*U$& zNd$@my&Vp;BcimdPnKvwfd&WIc=)KPy@cEltzi_dJ2_5CFum~)|QnW64E2K z)PLBZbmsa}opt0OCtq0WPQ7-*${Gds6+4ZfkF9Y+q;K)9bfcN8)VQ=2T*FPLsWJ@w zwXBff+Y{VV_lX&$C)wH~qs}nT$^sVQWoaW-B#K(u z3wF^$RB`4_E0HTyWP+6ehpk(sTr|3n3M25+S!r@f0F@^!MLMs$&GUj<2%nfA#+y;* zIC7ExRkT-|N(u^yrQ$&8UU?K1k@-}JKN=myK;EOuK%q+mc&oO+Cl^=K)Y6%CwIC~0 z#@{+Z@qiRkHQKMai49vMKsr>`Dmc!@io8MPI$PySWI%GbHpoj)HUcmN$WO+<4Fl7? zDb<`%-m^}umAckvhNG|_3R`WweW~RnHr#GiOxw3OvNi5iywRQ2K|~m&>AC#tnh)@z z+I7!U8^w5SvfOQ_47fs%)8jy6&H8#!Y%DWsb+~Q@Id&bW@LfPqKxrCzJ5j#6xDGOo zRjx4t4?)(EezAgf0F6^nRTLA%-)d19ks_q(C`O>~S{jg|CbF8JP5V|wFn?Ot?*25N zg+)Jb!V;x2qbN|2>E66a+(CDX#85w_Tj;pCM8xQ-@kKTcEFd$c(Wmj6j;4|cpM^t! zFYxyitQrlK!THo3W`aTG-k!8*rPmRU7a3O#C;0kR`dC($P7tREl=rJ^@=)VQU@feA zDE|PZ9CzM;KJ`IRGf}ysqs#!LF~K)iOKDH7v?wPA)7DL0IKxT>Yi-tZAeQm}0A_@@ z;s)WxbC&KCioP0Y6~gwQgs+A^MkJ{0p{MytUC>S;wr(%nk;RQF$%a^2gWO@8?5s*`(bJx3LN>2{Ydad<`)TatsO z)SrbKy3$^Q9Zj{1t`+PfrB$PF~?zc9Pc#32n2Q)3L~Q zrMquwVS5FT$6)|!OP}&JhC~IHx@Z|1XT9rRqJRb+P>pyjo`qpe% z+cBQMvH*(H9K8ijR2fss4wTAu^9onqst(%Holg&l@2-T_t)_aEQgY%gCsH@7`=tZH zRl&qfBm5)mT8P+E9>D>5NKqAXjp1O8Aty@ltw9beA@mi@E1Ogey(lYW4d#a#NKOjL zE>58L75CRSH`ab^T>=&Jis2&F^S}_MQISGpx@nW)-U#*EPa`BL4d5f;(xiA(HmTq5 zpGf^G-#5bx-R^ECUk=ra-Ya(ZUPOe)r3E?d*DH;d{C<@UcwM>=FB(frm3`Ws?tWC?iQC>= zD%Hk<7Pgm`L{8sICh6BMm(clit9yo%ah0;OPEPfRrpPHNOyrtL*q)U>UzKa2t73;V zZAXYtXr!rW$fwG#wFV&8um?($tzO(FLyB*Nh}r2K_ud_5N$YM0ihKUxU3v0;=Tbbf{B` z5_(dzDoa^%;HWkCn603l$>vQQQ}7+>T!032PdAv>2yCNGS4HPtD^@tq;T02*K9%iH zv=p}ajCqI}(#8ebYnI{_)|B zJ$u%k?CAyIO7h)5;ipR z4?cKz0HJt@IP7UHg#Ft|2}!JDTGEp#Ft4RYPL+@6-i`?aEyYNIKyMYON`Tf4GSQOV zq7qPJxZ0Pr$4Y4TK>a?&Hs$5BFAj@v1s|0Vi-ufS-W9b33$?VLe6&DU){VuliQe9> z-32S+=sCJiTApb>m92RRGgGFO0PG~xFC_!%Q3_Zhx`Katol-zjlG=(=b;!ol>eX!J zvx^*L=7QOY=~mYF{{Vk#*^`nbc=Q5_UTMZybg5z2X94R!I<>QSG81aw3nUW@gS}`K z@tROcqZ*SynVKmHR-NlqfPHkWTugjLVA|VXSS#9(J9zv?l#sWYgSR-Z+p&kMg0;M~ z*GS_gq+eNq>hf!w#cUHD-Wt}vG2!E{?L`6n^JD)2+PQp5w@EP#rAj{bdLpXE)@szP)%r}BwK17U!gIR5}ByrzEj0FR)gvwP#Vmk1x;Gp%Irh%nW{ zIU5Z$PdBWZ33K^bd|O6=E!#E(cGjoj@}F@UQvj@`EVWE`Xn< zO8i^L9PgdkG&q-@GGnK=xIaoxNxM?}s_n&3){ZuV9n~u*TIaN6#=C7rL&F zIxH>5BOlI{#PI89TTi&O0x3(fXZgLs(X~Tr^eM0AU*^vXzH-@kJ;r4ur>~o>YvA4( zl18I~3|2qQUL1OVS4a6*FNau<{{RYV2Z!0m{j}0jZ7q~fP8g?DE^ONvEhq?PaV129 z>!f~FYT7Ss!|i|`bynceE?%8QF>L@K9NGx;fev_FUkAi80uYS`+P^oU2lLs8(|G&tXk8NWSQEMxX`^=H)k z)@xRdBh*Itu86LY&Z|f!?ZUtM!8V329y5+&G2y5J_unmcyTltlRru6+GjA7e# zd<0^Fve3)<2$~&}KNd^$6>aMtF$61c>SRRgGX|zb(+^V?CcGlchC^6|$hBLfQk3rr zjGHVOGc*xvT};!{Ct=*MkCe5U-ktbBq_h&(w80jX66_o+FT{1QhW>AvMS(v9nMX=( 
zw*LU-%j$8Q)4%ZsTzpt6hw4&7^#y#ULC9eXaXVuuvWgY%V}@(iYN{HG9RmLV@*r9l zRWQGLMV{l({$~xErBukgwCZ5x95*bD3vaY+Nr0rH zgKKka#(Z3DCZ$O%3zXh0c#F;bbauhp*O^$3iuUe&rV;TkP2MaIgk?A_`xUtNuJA}; z^2#-~y))K{4ol9P!~&Fj^$Y=5RhBgeh%+`MST>hdhi0HVd3E1ej8anY5@ZgOkKz!= znECTAbPBHRzw>e85Z5b}*{ms4Daad{dih{ro(CjCfvsVWi26l`cZ#6Ha~;NoX69eH zf^6ODraxFr8O6j?YW;`)>Jk*!Hn{C@Po~s5w8JJq$`~IpkXuHz(A`dMnhUHxWf7fE zZ_9~804>eq^#h6Ni$IO!3~J*+oD1DMxN;p}(T=^qV})$o--S+Fa3-4zyz^0Oo`?rh zzMBUbJBKl1obTola#Oay$SpJnRJ9bV4yvova@>{PZnY~w^c}|9aElmHt4KL>F9ptx zvspLl3#k`y&#;vXuI^tXsy^+?DE?Y6xL;VhMjtZgJRC*11@{9I78T4jN(H(g?M>!a zL@B3_aZR>_uQm~B*Nzc4a^Sy(e$KxPcV4*>g@OyFB-3c-eJs3JHy z%W!}f$-+2&U-2;v>%yuxlG;53xEWO9af_2{FMO>;b-%`Ata|5#dLBw&yYBpiSw%vV%vahx`02Ml2bo;M&2Q!!@P{+G`yeUG$DuR%29|)}-F4d- zXrV>)^>NH1BFSOwKbYx$HoI`Ih(=1UU*x%o1-}hLl(1}Ya~{;wl~u|>1KP&;F}4VO zW?C3ob%_B;tx>Zt8~&8eQoty6B)OwH5O$$jR}&L%S7LCszfy@}jTu9C!yV#L3D4>k z@pwSxYtMYa!Zt9E2pY~u!8ez2#wE)QtNe(-D_Cmaa)Djc!aNG)cl=DKMH<(WQHJW8 zcDu?DzZ!2CyQuth?KyWv${ye?Wb}>EpNsf_#@`FAlZH2BkH{*^8@~ArepnSja%)HT z1tFm2Ue;YPoi+8>tDFLaw=|~_g`l?E3YG3oAC3CSF*R2T(D!T_a5;j~-&HTT+kjnt z@8&g61Bh7(g%cg%r|L;tS#{vlj*o%iu?% zvAsle@peiY;aO?_01??JtprsvraOB-$w60&&xmp^htWN}H6`7GBw+ZnH$BMq|eARi+r zcN2w(v)*@3yB0eT3yDLa9K+AcEEEYrYX%77hY>E1lOGo|HgHI69IO(`VI?wJ7P!_j zE0`mhgl-oAXoA{|wia@ITyW+Xc#T5$gO~-m4kl5K1*Ulh2V@1XR!8_ojodX_hh0IM z!~hm0;~#SiZzs$gxJ_gp2tJLQ(Gf`zXw+ZX97-!3h_nD4=5ly8R4cSyGgAEc`XWqD zC9rKU_9SENg6Oi_7!F{WZ@q2?_awA9J!kGUEO}HGkD-A{CaKl?CY2=O%HPW!PPz?O z`zdAw9HF*g&vNOAZU-=Rj-fn?BAd^aa1z4`JXOkhc!h@5ZQT_+1M(7;W3b|{;vc_R znRK?>^9==*GXq-IZhXqWfK+I&K#+OHB2#{W@zg0z-IT$#y{)oP>1H_UFpBmf{YkZV zFLqn$1%iu(L1Moq$X3wSC?E<4X5tRxq37J{+tpU$8X#k7X1*e|SQ~~xs=0(Wg%6O# z#he%4JBC9&hbX#m%fj4)9l}3|Mcf<|8DQL6zD4xDrt9g>3vc64g3Kd9z4GCPkVk_Q zuV&I@F(}#7*%=!#F3^Q#8v_Q!cy4hv!Z!VfgaY=P6^pvC_?k|T9jM*YsbGa;agv@l z6td^M{FVTwso1yWZW32couU1qvEr#oY_J(dO2F8oc8@8e(P7R&Ib*UNfh(czV)tl- zbX>-?+FEtjbC+_XhU|VPe3lM*C(N+bK&Uodbpv4ghUB+KIh&w9ub9@fQ=TdSOCv+W z_cJaic{ zrqxJzn}3lZ@5gB>2(NeH0 zmX;r0x^WVx0koJeifra#!tTU8^bn*NQCee~RJV9n+#KJEx5Oq+Y=rQX;uXVywV{6$AG6(G8HA_eijI@Y8ot6DVs+IxAs0bc~0a~2f z{7xNY_{^Rkg*R3?b#pi@ulx+) z&HyGKHVmn;Vz@?&y|DO}5{^H6#}d?MgD$FynH&5i)S-2TWt>48N&#}ilL)ssoihVU zH)zeX=L!~IQprW0Tet_u@|M8E!2t6ipG;J*x5S}kQ4+Wei5E_93Wlw7tnoLouvYxk zprj$Y=#!W)b92C@v*?9gWWrgjZ!!8ZmBN?7xF@kp9#Yo{=mWy|MLy<#Q@rdsc^EGd zrQ}InQY{w%jyI%z$I6E}F5`sYhYmQua1FW3tw0+V%Qq_rb_}wLfl(Pq&}j7viMv6r z5MiiKUs0@$kdoD7)BVo_Z5ZWq4p|nbEqH~(qRW;40KCD-RZIh#BWDn>R^w57j`}yA zqU|fyIUM3wiW{dLq2C|m(lMpx>)>gaFhL>$@5F0>iNS2tK!o6c0=Ll$Roh+8TP7Gnfo1YRXt2k|U(YGmse}zSQXlKj}Xw) zBBv$TOX3)pL&_24G290TwKZ5rhcE(}GwvNI%=ypkd;FPiYJR}@gbjP<%Kd_@MB4JP z`G8~xxS78&AX3oca|?kmiEb&^nCSq?VL_VaUJ_~CG-t=G#?P!z^!2yk6rx40?@o(+#ygaW4n<>rfaxiH7vG~^;Ohi}EQD9*aHGc=hPioRd zte%;#8uPgFy;QDWv%#j0KiP)RAiN)h2Qa9GF^qd7j4blxyYmp>n&ch6V4NTw#(N=} znpgIMHv?V7#zu~y37ZSzqhtd*g7T#@&sQvPNyG*og67m{>cZ zcv|0FCAu8hf ziY8+u5{hzpxWol^0bG5@D9ha(TR2L@Yzsz|6vmO=;C7ir23}!ZMg5f)FOIpHWiqI~ zJmv%(w$wa^;n-yeD3_N0GYutW^78kX1+;9b%AV#VumlRx>C5RC5}{*Ta?>bjS)%q! 
zqrlTp`8hy~XxTX#PM#)85)^a8ms zMq;S9%Wf>Tyh}dBcrO0{xq%{BO2z)ih_T>;S@Gg zUtSXcB;I~yP_|arN;;`iHh|xP--tjM%WY)5s?2nx6bK5gVrWeRxUDyw(QH)?<70vI z8;fM>x-sSbP4})1eUS9U49K*2XXZHYFuGqB>zF&aVZh&EzW3gs0{*3xFVI5rb^id8;uI>8F;3***Ffs!SThd61lM*%PYEsVs@>_Fu^{A2 za7XZ5FRj<_feYe$H-uqk1K5d2_1@JIj1a+-Dv_ z3NCEH#JC#pthPl4G#5E9F|wmMe*vjPmP$~+9ZF|rFzvo8#9wv)08B#JbS(K}UlXiT z4e(*8Qu1BJJT=S}70|j4Sd2+jn)-(Os&fILG}E8iisSw5y7-$I5f$4}HAGv}Fjffb za7CQT;CmG?$Z|G-$EB66A85b;22;X<_y_YFKXn1iLG~=a+>aH|Ot}0rFazD~+EZ@S z7*N$!BvWufzYz)6=Jdvh{%w8)EaVsNjIF}N^_a4_PxyJh(Yl-HEdsg`YW1(YlV85CRyv>OMN{w7CX;RU?rB|@={ zxUx4_%H3kU#aC<@%&XQQ5TW(y98vm!>q!`V zBEn^jS>;>91A4wz^-;tSd?pG2OM$dQh|&u<7U);Z5SFrItxO^vZQR6=aN-OAdY3$^ zydn6zg{wB3Az%72DA_LNT)TAC!!?2-oUAL{!cGP-+OFlqZuPKml%>^TTN)G}v6KZc z2%m{4Ec0xB68;8(r=d|ayF^~`daa&iqU$A|Vzmwil{|Z+v=cP=gmx3P)y;W8S``%&ka0Cp1Jro&Oh)`ojE6QT@4Jsn51*zbElLtyb%Ht6#jMvd$v zkJ!6Z0IOm;^0V~`P@w*t48xuDGTXqi;K2+H@Rwbtf}^u!rKxTXE&{4b>E2hO-!qL< z%RV-L*@)_oH4!TN7`VKwX=((Z5sKzAI-!xrh>@sui+s(= z-fs1mAJm~@rpJCWc!!O+dx0_4~Gc)^+$?uENWfG(T1Yd#Bcjb#vE~T*yv{Hs{07eJNeXc6qmrX9GnvdcO zDk3XAN2^@UAOSv@g%_3_`GxA|OBQ(+HLs|#8+2BOh`D^<)Ncau79Wc97h9kFH47!5>r^_oH zu*e}6TsK|+0M`gAo)CV)`ZZtvv{p9gT*1>dBbvhI8c6zTJb}==RvHK{o z>+HH|;z*I_l!_{_{{a8Q03{Fs0RRF50s;a90RaI3000015da}EK~Z6GfsvuH!SK=H z@&DQY2mt{A0Y4BLIZmG6j5*k?4_CwPWh+%L2UKpacu*i3Ce}sQU1JN8qS>r35?LEzz>bc+F=3nO_knrv9*DMw9WSV03YE0wcKG&Cz^hn`3;9gVfZz z>o1g`BATFa(evs8t!XTf?QY)&3U1#xNFW}RERXY_QX-kp9~#a8>`B_U&a&@|*L)s( z!rBo+yE?^eLpUEVjBr?B;7)yfXLzI9HjiW8ZTlq;bB$%%=&b0}e1ylWagAHf4R3ZB zMvaKF46dcRLyhZm@uJs}29G>BYpjIRkXbjW;|WeR6dhS^nsFdhA~4W#HFJ_x2O<9eTmexKl_)j;0Ba;n&8Qw|4Ry6YQZEIf0AD+=;A&bs|fTi}YE8{+H7d9JW7 z1S9qF-NQX9i*4)HKtL2LzdkX8v;)*N&v-(4qbYh#cZnnlFK(Rp!nFuG9j{^TVYPFH zv9M{T*!ajp<0KQXA&iCA_}+A{9z&OZSOG&1b(}t20@VN@9ydFV01=YtJ>NdPV{q>| z_xSc@qzUp&>IZ7$I;wRd-%OhGj1`Na0`dcneK@hoS_5@FU*`sjk`FQfh?l+dk7~UU zM51g&*v+g11=5Z;9z10Ql?kA4YI5a*+S;W`Izx;YVs1e;cqcxwBP|+{PX>VPnN-3S z2;Lu@9l<#WJXejGq=KCXk>SW{fcPIdsX8b_{$!yCjB*C}+`s_=0QMmKVn7OTX;+W- zvc^J2lcU}aOc+EazwS>a&Kd#HqHDfzP=e_Eyg2=4`dteMDfmWkn$kyn1mnjpMX(VF z9QoDG%z*%29wpC#kVRdm-VG{L=QEm~IzBO|tzebG`+hKzL}ZN{J=cxrHdMsSgK!B= zn>xZt0I?DhKjpwAh)0Y&x{bnKIjz9!sMu67p>h!v z3D(bS>fvt*WJ4D~dn=GRfk2>6mv55;#S%skQlhpZ1?bJ|lK25nY29~RMAe{3PX;J? 
zGz**31T6%ho3PwF#K;j49vmp!kgX(#TL1yyc&6b6!m7YFh-R?im`-U;ogxjCPgrs| zAsr1bk-}9-DdbSvUyKy=8sVsWc=^c@8vSE?_lEC%4}+V4nk24QZcZ^s1s$aMJ^uiC zq!n5_5p~}3&;wL9h5j&fir9eYKfHJVkrn~kVFd(EQ~mv7fD2%O2A=j|U@74M-{YJ@ zszgx|S+5rc0IVb}*NAUeUF&;Sli&Vkp$SF?Z_V#DK6bEKr}u{0C|eGf)2XbSiXoGzWsXx|~zQ&_~9AZ#!fyeHivtFqfwA<=bR2;k}7 z4$V=#WGKLz9zrtN>l0+*>M0;FQ1oG1hE$t$JJ#o{8PXp}o5N7M!H-o0j4@*O$KEEC z4lG1c!OTcD`=xazfsdcOh|_|i)OIu1)w~2-1^8)zZu_)gm#XEJLz*$ z=5lOCfDMM(jv++A*}M?(g$A7x!JvnHS%RvXBX?2TNv?ISa;YG9N`(*&rT+l9+N?#% zA!3C(+I(RGFoYlRAGa<5pAHf{v4bcW zNU@O$TL#O3D7(wDh*$Li$}$8s2FmxO%21FWaQ<-nX3*L?C%^o}FxDLCG03nYqek?7 zf3=;{dm)C{FmS-TPI3ffx60^y?;NwT;BUif@x;yrU9~!C?^Cx)Vp|6aKS)%n-gU=LCQjJ}IydI{3n)!@&}@ZQdC>Ma#GY^yDvI z6D_3RRd^2ZqpfD8D@0AFNF(axv>@oVkEOm$yB%c;fCYs$W6In_ETIo_?bw?n@Zlm> zQq2gV0Bffi0TAxtQgC!{6vRO@1Fj2B+^erW;DJGHb5A_1IfYm;0Q+nqrh?5_3tm3Ya0L@P%!s}#UELxHe%_k z$3>*0O+*G{0b#c_VQx)VzK7t*bn}3%Z$YO92m-gNWnM1+yUC!BP_8sNHHiDB&=Q0M zc<>NhiegfE3qEU3TAYqD6-Yh`SM`rt5KRcE{Q1Rvi+Dfuz~rL@i=lPy=G@x0sLOrK zopuzr#9lEcsYE>t=HB*ZB6R$w%k1ff(+Vv&LRM;J}}ZX9T4R{_`Wf46};+` z?-7z5HAeirYY(JIDi4v~1SS#$CiOqlH&qSk91QFMjVAJ<9DE3|{AXAcymW;kzbA|x zD)Ewn0D#6{F-PWM*FlY;x4w=rTuK)5-YSolHxy=&A;<8ba{ljt{qxTO^4*rf=(3do5=lzXy zaVzc8n_A&$3AVq{0<(T{L?H&c(7iJ^l?qUask%fanU(;MG=iZ34n_IEA?k~pr&Jq{ zE(mj{g^g^8HGp_-1=H*!tBTr$>DqowFAUXeozM4Jfr%WDJp0~ChGj|Z{{S-eWCO9| zee;~8K-ke|tP^U7Tmy2Ypab3mRM7N!*SrE3?h&mIkKQ#AD}@L=vg;-=GoIVVk-Rqs zpb4h;#x=q&0R^|&?-2wT0H`Ci3A0%}Vxkaof&=-uq_={!HvNBjvGC}M3**;$QPBV{NzLKfak ziX-G=NDvF6-Wqaj%8(Tm(cTFuqjG`W1*tY*Ru{jLCiGlR#tAQUQ?&<|BM>We%7hpz zDZJ|lHxIGFU&b*42_8FNfAbi~GrLR928;nkjdao~rPi{MK})Y?@*CNKAcS~-^uJjk z;QO5ikVd<>U?SzcK`Ei_)Q=?@YEZ!9hNrD)?2a&u>$rW4&K}#K(fRapYNJMP3)!lA6pl(pA(si%- zlh6aOwMp}L;}xS_XH=($e&~d7WOizo0!OgV*{{UQ}7A_1Qn~ZY02s3(~@DxCeB|6frJtO)t=s_FK^F#%wHeI&eF9#SM zAYGbv>7l1HG)NE>5`FW|M_2$00s+Ot5ZaBSnSm(*DGi0Xrv|y!E=rM%EiB~m&IJ?{ zp%@0#n)8-pZAu>Whu?Ta5PA-c`26Bo0er!BBz%FoQvg^#vbE^Y1NzJTH2MY2?n8Li zYJts#cl3Q>@@RK~1=C)C=3X#4^q>yR9;PVBsw^FykbPlsk8YWbX z8_4AYe<@*xUCPGy*_sdl4u#z?m0Jf^QRu`)1e9RU=y8X9myEe6XApw;vpZyB(MzN$ zYvUI%gto3)gZL+mc$~pW71$BEzpNTUCWZ|*;Dr^Qao^z0S~@E2-&s&4r23<8CNtC{ z;LEUh0?aibEjMpkE6xg)BXy_KHB!}C(TyBjv42q}uUQIZ{{W(77I89C69qjf0+GN7 z({HNZC+{AGgzh|@jBkt*V^ZNfzi#tq^{ipuVY(ypy*{&CX%c9CFs(owb)C6x!tAEE zR?OnSCl-r`Dt}nSDqtOJEffjsDA4tqC}<4)8q$FF!G z4x!Lgnx94u@^AyYjy{mE3x-1nK)SE2wAfJm4!FO0=PK5({eE+z;CF@%pcPH9_8Q^i z8&r6pJ6-Z|Gd`qgQNe5|@5XPCL6k<0H#}=JUJ*pKRo>0*ywnM}mK84SfAb3y8OjuB zUuBrQ2ZU!Yho3mkAjEBJYP4=7=4CJu@ZV)&bkI|* zBp_M_j#FFw;>3!nh#ExQj?C4CqMpl6mtxtPYAmqRw2c$t!^jKG_!ZZ2*BZ$Y01&3w zfP!CHEJ4(qJ@WI45ft5-M-C&%Cs&&Hm&tqTg72$!^k6Eu0IKdgQVKhKWLN+#ATr?H zgTg#GRK+9#6&wS7b((e^?FA$9L~%Ec+?0kAaEBz0@rcrdu+W#l zmB*40@mh-1^>9H0R3bxxC7%7|s8Mo`&4Cd? zQ6p224k1J-Bi8xfj87pRa*vjA)ejWl3@It*HeU^qT9Z$T!?)O}&J1S=0|=bu?H zDjjV-2e$MC29%N5q|Qq zD=lHEee;}}mPif{mnYXCG~WlTZ89E7!+&!G(2!6LE1ll5g{tK$qu$-c$2l080M@pM z^k)kfdqF0bCTqnG&O+v^YlyIgb|eTI;90(~){{msooW&P0GP-Ea}do$7uz?nus{u; zsqYW?UNG#Man*)zZP`Fn?vL5q=xJ5s?40xFu5k;|suNX_R6)EY>*a2LGIQz~? 
[Unreadable base85 blob omitted: this span carried the GIT binary patch data for the binary files in this commit, together with the garbled start of the new Scala file that defines RecursiveFlag and SamplePathFilter (its diff header, license header, package and import lines, and the opening of RecursiveFlag.withRecursiveFlag). The readable diff resumes below; the withRecursiveFlag signature line is reconstructed from its call site in ImageSchema.readImages.]
+  def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = {
+    val flagName = FileInputFormat.INPUT_DIR_RECURSIVE
+    val hadoopConf = spark.sparkContext.hadoopConfiguration
+    val old = Option(hadoopConf.get(flagName))
+    hadoopConf.set(flagName, value.toString)
+    try f finally {
+      old match {
+        case Some(v) => hadoopConf.set(flagName, v)
+        case None => hadoopConf.unset(flagName)
+      }
+    }
+  }
+}
+
+/**
+ * Filter that allows loading a fraction of HDFS files.
+ */
+private class SamplePathFilter extends Configured with PathFilter {
+  val random = new Random()
+
+  // Ratio of files to be read from disk
+  var sampleRatio: Double = 1
+
+  override def setConf(conf: Configuration): Unit = {
+    if (conf != null) {
+      sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1)
+      val seed = conf.getLong(SamplePathFilter.seedParam, 0)
+      random.setSeed(seed)
+    }
+  }
+
+  override def accept(path: Path): Boolean = {
+    // Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead
+    !SamplePathFilter.isFile(path) || random.nextDouble() < sampleRatio
+  }
+}
+
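The sampleRatio and seed values above travel through the Hadoop Configuration rather than constructor arguments because Hadoop's FileInputFormat creates the configured PathFilter reflectively with a no-argument constructor and then hands it the job configuration through setConf, which is why the class extends Configured. A rough sketch of that hand-off, for orientation only; it assumes code running inside the org.apache.spark.ml.image package, since SamplePathFilter is package-private:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.util.ReflectionUtils

// Roughly what FileInputFormat does once FileInputFormat.PATHFILTER_CLASS points
// at the filter class: parameters are read back in SamplePathFilter.setConf.
val conf = new Configuration()
conf.setDouble("sampleRatio", 0.1)
conf.setLong("seed", 42L)

// newInstance calls setConf on Configurable instances, wiring the parameters in.
val filter: PathFilter = ReflectionUtils.newInstance(classOf[SamplePathFilter], conf)
filter.accept(new Path("/images/cat.jpg"))  // a file: kept with probability ~0.1
filter.accept(new Path("/images/subdir"))   // no extension, treated as a directory: always kept
```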
+private object SamplePathFilter {
+  val ratioParam = "sampleRatio"
+  val seedParam = "seed"
+
+  def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != ""
+
+  /**
+   * Sets the HDFS PathFilter flag and then restores it.
+   * Only applies the filter if sampleRatio is less than 1.
+   *
+   * @param sampleRatio Fraction of the files that the filter picks
+   * @param spark Existing Spark session
+   * @param seed Random number seed
+   * @param f The function to evaluate after setting the flag
+   * @return Returns the evaluation result T of the function
+   */
+  def withPathFilter[T](
+      sampleRatio: Double,
+      spark: SparkSession,
+      seed: Long)(f: => T): T = {
+    val sampleImages = sampleRatio < 1
+    if (sampleImages) {
+      val flagName = FileInputFormat.PATHFILTER_CLASS
+      val hadoopConf = spark.sparkContext.hadoopConfiguration
+      val old = Option(hadoopConf.getClass(flagName, null))
+      hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio)
+      hadoopConf.setLong(SamplePathFilter.seedParam, seed)
+      hadoopConf.setClass(flagName, classOf[SamplePathFilter], classOf[PathFilter])
+      try f finally {
+        hadoopConf.unset(SamplePathFilter.ratioParam)
+        hadoopConf.unset(SamplePathFilter.seedParam)
+        old match {
+          case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter])
+          case None => hadoopConf.unset(flagName)
+        }
+      }
+    } else {
+      f
+    }
+  }
+}
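This completes the new helper file. Both RecursiveFlag.withRecursiveFlag and SamplePathFilter.withPathFilter follow the same shape: read the current value from the session's shared Hadoop configuration, set a temporary one, run the caller's block, and restore or unset the old value in a finally block. A small sketch of how the two are meant to nest, mirroring the readImages implementation further down in this patch; spark stands for an existing SparkSession, and the call assumes the org.apache.spark.ml.image package because both objects are package-private:

```scala
// Count files under a directory, recursing into subdirectories and keeping
// roughly half of the files; the Hadoop conf is restored when each block exits.
val sampled = RecursiveFlag.withRecursiveFlag(true, spark) {
  SamplePathFilter.withPathFilter(0.5, spark, seed = 42L) {
    spark.sparkContext.binaryFiles("/data/images").count()
  }
}
```

Because both wrappers mutate the configuration shared by the whole SparkContext, parallel jobs with different flags can overwrite each other, which is the race condition the readImages scaladoc warns about below.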
diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
new file mode 100644
index 0000000000000..f7850b238465b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.image
+
+import java.awt.Color
+import java.awt.color.ColorSpace
+import java.io.ByteArrayInputStream
+import javax.imageio.ImageIO
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Defines the image schema and methods to read and manipulate images.
+ */
+@Experimental
+@Since("2.3.0")
+object ImageSchema {
+
+  val undefinedImageType = "Undefined"
+
+  /**
+   * (Scala-specific) OpenCV type mapping supported
+   */
+  val ocvTypes: Map[String, Int] = Map(
+    undefinedImageType -> -1,
+    "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24
+  )
+
+  /**
+   * (Java-specific) OpenCV type mapping supported
+   */
+  val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava
+
+  /**
+   * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte])
+   */
+  val columnSchema = StructType(
+    StructField("origin", StringType, true) ::
+    StructField("height", IntegerType, false) ::
+    StructField("width", IntegerType, false) ::
+    StructField("nChannels", IntegerType, false) ::
+    // OpenCV-compatible type: CV_8UC3 in most cases
+    StructField("mode", IntegerType, false) ::
+    // Bytes in OpenCV-compatible order: row-wise BGR in most cases
+    StructField("data", BinaryType, false) :: Nil)
+
+  val imageFields: Array[String] = columnSchema.fieldNames
+
+  /**
+   * DataFrame with a single column of images named "image" (nullable)
+   */
+  val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil)
+
+  /**
+   * Gets the origin of the image
+   *
+   * @return The origin of the image
+   */
+  def getOrigin(row: Row): String = row.getString(0)
+
+  /**
+   * Gets the height of the image
+   *
+   * @return The height of the image
+   */
+  def getHeight(row: Row): Int = row.getInt(1)
+
+  /**
+   * Gets the width of the image
+   *
+   * @return The width of the image
+   */
+  def getWidth(row: Row): Int = row.getInt(2)
+
+  /**
+   * Gets the number of channels in the image
+   *
+   * @return The number of channels in the image
+   */
+  def getNChannels(row: Row): Int = row.getInt(3)
+
+  /**
+   * Gets the OpenCV representation as an int
+   *
+   * @return The OpenCV representation as an int
+   */
+  def getMode(row: Row): Int = row.getInt(4)
+
+  /**
+   * Gets the image data
+   *
+   * @return The image data
+   */
+  def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5)
+
+  /**
+   * Default values for the invalid image
+   *
+   * @param origin Origin of the invalid image
+   * @return Row with the default values
+   */
+  private[spark] def invalidImageRow(origin: String): Row =
+    Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0)))
+
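For orientation, here is an illustrative inner image row built by hand in the field order defined by columnSchema above and read back through the getters; the origin string and pixel bytes are made up for the example:

```scala
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.Row

// A 1 x 2 pixel, 3-channel image: bytes are stored row-wise as B, G, R per pixel.
val data = Array[Byte](
  10, 20, 30,   // pixel (0, 0): blue, green, red
  40, 50, 60)   // pixel (0, 1): blue, green, red
val image = Row("memory://example.png", 1, 2, 3, ImageSchema.ocvTypes("CV_8UC3"), data)

assert(ImageSchema.getHeight(image) == 1 && ImageSchema.getWidth(image) == 2)
assert(ImageSchema.getMode(image) == 16)            // CV_8UC3
assert(ImageSchema.getData(image)(2) == 30.toByte)  // red channel of the first pixel
```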
into OpenCV + * representation and store it in DataFrame Row + * + * @param origin Arbitrary string that identifies the image + * @param bytes Image bytes (for example, jpeg) + * @return DataFrame Row or None (if the decompression fails) + */ + private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = { + + val img = ImageIO.read(new ByteArrayInputStream(bytes)) + + if (img == null) { + None + } else { + val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY + val hasAlpha = img.getColorModel.hasAlpha + + val height = img.getHeight + val width = img.getWidth + val (nChannels, mode) = if (isGray) { + (1, ocvTypes("CV_8UC1")) + } else if (hasAlpha) { + (4, ocvTypes("CV_8UC4")) + } else { + (3, ocvTypes("CV_8UC3")) + } + + val imageSize = height * width * nChannels + assert(imageSize < 1e9, "image is too large") + val decoded = Array.ofDim[Byte](imageSize) + + // Grayscale images in Java require special handling to get the correct intensity + if (isGray) { + var offset = 0 + val raster = img.getRaster + for (h <- 0 until height) { + for (w <- 0 until width) { + decoded(offset) = raster.getSample(w, h, 0).toByte + offset += 1 + } + } + } else { + var offset = 0 + for (h <- 0 until height) { + for (w <- 0 until width) { + val color = new Color(img.getRGB(w, h)) + + decoded(offset) = color.getBlue.toByte + decoded(offset + 1) = color.getGreen.toByte + decoded(offset + 2) = color.getRed.toByte + if (nChannels == 4) { + decoded(offset + 3) = color.getAlpha.toByte + } + offset += nChannels + } + } + } + + // the internal "Row" is needed, because the image is a single DataFrame column + Some(Row(Row(origin, height, width, nChannels, mode, decoded))) + } + } + + /** + * Read the directory of images from the local or remote source + * + * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, + * there may be a race condition where one job overwrites the hadoop configs of another. + * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + * potentially non-deterministic. + * + * @param path Path to the image directory + * @return DataFrame with a single column "image" of images; + * see ImageSchema for the details + */ + def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0) + + /** + * Read the directory of images from the local or remote source + * + * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, + * there may be a race condition where one job overwrites the hadoop configs of another. + * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + * potentially non-deterministic. 
+ * + * @param path Path to the image directory + * @param sparkSession Spark Session, if omitted gets or creates the session + * @param recursive Recursive path search flag + * @param numPartitions Number of the DataFrame partitions, + * if omitted uses defaultParallelism instead + * @param dropImageFailures Drop the files that are not valid images from the result + * @param sampleRatio Fraction of the files loaded + * @return DataFrame with a single column "image" of images; + * see ImageSchema for the details + */ + def readImages( + path: String, + sparkSession: SparkSession, + recursive: Boolean, + numPartitions: Int, + dropImageFailures: Boolean, + sampleRatio: Double, + seed: Long): DataFrame = { + require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1") + + val session = if (sparkSession != null) sparkSession else SparkSession.builder().getOrCreate + val partitions = + if (numPartitions > 0) { + numPartitions + } else { + session.sparkContext.defaultParallelism + } + + RecursiveFlag.withRecursiveFlag(recursive, session) { + SamplePathFilter.withPathFilter(sampleRatio, session, seed) { + val binResult = session.sparkContext.binaryFiles(path, partitions) + val streams = if (numPartitions == -1) binResult else binResult.repartition(partitions) + val convert = (origin: String, bytes: PortableDataStream) => + decode(origin, bytes.toArray()) + val images = if (dropImageFailures) { + streams.flatMap { case (origin, bytes) => convert(origin, bytes) } + } else { + streams.map { case (origin, bytes) => + convert(origin, bytes).getOrElse(invalidImageRow(origin)) + } + } + session.createDataFrame(images, imageSchema) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala new file mode 100644 index 0000000000000..dba61cd1eb1cc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.image + +import java.nio.file.Paths +import java.util.Arrays + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.image.ImageSchema._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { + // Single column of images named "image" + private lazy val imagePath = "../data/mllib/images" + + test("Smoke test: create basic ImageSchema dataframe") { + val origin = "path" + val width = 1 + val height = 1 + val nChannels = 3 + val data = Array[Byte](0, 0, 0) + val mode = ocvTypes("CV_8UC3") + + // Internal Row corresponds to image StructType + val rows = Seq(Row(Row(origin, height, width, nChannels, mode, data)), + Row(Row(null, height, width, nChannels, mode, data))) + val rdd = sc.makeRDD(rows) + val df = spark.createDataFrame(rdd, ImageSchema.imageSchema) + + assert(df.count === 2, "incorrect image count") + assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema") + } + + test("readImages count test") { + var df = readImages(imagePath) + assert(df.count === 1) + + df = readImages(imagePath, null, true, -1, false, 1.0, 0) + assert(df.count === 9) + + df = readImages(imagePath, null, true, -1, true, 1.0, 0) + val countTotal = df.count + assert(countTotal === 7) + + df = readImages(imagePath, null, true, -1, true, 0.5, 0) + // Random number about half of the size of the original dataset + val count50 = df.count + assert(count50 > 0 && count50 < countTotal) + } + + test("readImages partition test") { + val df = readImages(imagePath, null, true, 3, true, 1.0, 0) + assert(df.rdd.getNumPartitions === 3) + } + + // Images with the different number of channels + test("readImages pixel values test") { + + val images = readImages(imagePath + "/multi-channel/").collect + + images.foreach { rrow => + val row = rrow.getAs[Row](0) + val filename = Paths.get(getOrigin(row)).getFileName().toString() + if (firstBytes20.contains(filename)) { + val mode = getMode(row) + val bytes20 = getData(row).slice(0, 20) + + val (expectedMode, expectedBytes) = firstBytes20(filename) + assert(ocvTypes(expectedMode) === mode, "mode of the image is not read correctly") + assert(Arrays.equals(expectedBytes, bytes20), "incorrect numeric value for flattened image") + } + } + } + + // number of channels and first 20 bytes of OpenCV representation + // - default representation for 3-channel RGB images is BGR row-wise: + // (B00, G00, R00, B10, G10, R10, ...) + // - default representation for 4-channel RGB images is BGRA row-wise: + // (B00, G00, R00, A00, B10, G10, R10, A00, ...) + private val firstBytes20 = Map( + "grayscale.jpg" -> + (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, + -57, -60, -63, -53, -49, -55, -69))), + "chr30.4.184.jpg" -> (("CV_8UC3", + Array[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57, + -71, -58, -56, -73, -64))), + "BGRA.png" -> (("CV_8UC4", + Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))) + ) +} diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 01627ba92b633..6a5d81706f071 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -97,6 +97,14 @@ pyspark.ml.fpm module :undoc-members: :inherited-members: +pyspark.ml.image module +---------------------------- + +.. 
automodule:: pyspark.ml.image + :members: + :undoc-members: + :inherited-members: + pyspark.ml.util module ---------------------------- diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py new file mode 100644 index 0000000000000..7d14f05295572 --- /dev/null +++ b/python/pyspark/ml/image.py @@ -0,0 +1,198 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +.. attribute:: ImageSchema + + An attribute of this module that contains the instance of :class:`_ImageSchema`. + +.. autoclass:: _ImageSchema + :members: +""" + +import numpy as np +from pyspark import SparkContext +from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string +from pyspark.sql import DataFrame, SparkSession + + +class _ImageSchema(object): + """ + Internal class for `pyspark.ml.image.ImageSchema` attribute. Meant to be private and + not to be instantized. Use `pyspark.ml.image.ImageSchema` attribute to access the + APIs of this class. + """ + + def __init__(self): + self._imageSchema = None + self._ocvTypes = None + self._imageFields = None + self._undefinedImageType = None + + @property + def imageSchema(self): + """ + Returns the image schema. + + :return: a :class:`StructType` with a single column of images + named "image" (nullable). + + .. versionadded:: 2.3.0 + """ + + if self._imageSchema is None: + ctx = SparkContext._active_spark_context + jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() + self._imageSchema = _parse_datatype_json_string(jschema.json()) + return self._imageSchema + + @property + def ocvTypes(self): + """ + Returns the OpenCV type mapping supported. + + :return: a dictionary containing the OpenCV type mapping supported. + + .. versionadded:: 2.3.0 + """ + + if self._ocvTypes is None: + ctx = SparkContext._active_spark_context + self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) + return self._ocvTypes + + @property + def imageFields(self): + """ + Returns field names of image columns. + + :return: a list of field names. + + .. versionadded:: 2.3.0 + """ + + if self._imageFields is None: + ctx = SparkContext._active_spark_context + self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields()) + return self._imageFields + + @property + def undefinedImageType(self): + """ + Returns the name of undefined image type for the invalid image. + + .. versionadded:: 2.3.0 + """ + + if self._undefinedImageType is None: + ctx = SparkContext._active_spark_context + self._undefinedImageType = \ + ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() + return self._undefinedImageType + + def toNDArray(self, image): + """ + Converts an image to an array with metadata. + + :param image: The image to be converted. 
+ :return: a `numpy.ndarray` that is an image. + + .. versionadded:: 2.3.0 + """ + + height = image.height + width = image.width + nChannels = image.nChannels + return np.ndarray( + shape=(height, width, nChannels), + dtype=np.uint8, + buffer=image.data, + strides=(width * nChannels, nChannels, 1)) + + def toImage(self, array, origin=""): + """ + Converts an array with metadata to a two-dimensional image. + + :param array array: The array to convert to image. + :param str origin: Path to the image, optional. + :return: a :class:`Row` that is a two dimensional image. + + .. versionadded:: 2.3.0 + """ + + if array.ndim != 3: + raise ValueError("Invalid array shape") + height, width, nChannels = array.shape + ocvTypes = ImageSchema.ocvTypes + if nChannels == 1: + mode = ocvTypes["CV_8UC1"] + elif nChannels == 3: + mode = ocvTypes["CV_8UC3"] + elif nChannels == 4: + mode = ocvTypes["CV_8UC4"] + else: + raise ValueError("Invalid number of channels") + data = bytearray(array.astype(dtype=np.uint8).ravel()) + # Creating new Row with _create_row(), because Row(name = value, ... ) + # orders fields by name, which conflicts with expected schema order + # when the new DataFrame is created by UDF + return _create_row(self.imageFields, + [origin, height, width, nChannels, mode, data]) + + def readImages(self, path, recursive=False, numPartitions=-1, + dropImageFailures=False, sampleRatio=1.0, seed=0): + """ + Reads the directory of images from the local or remote source. + + .. note:: If multiple jobs are run in parallel with different sampleRatio or recursive flag, + there may be a race condition where one job overwrites the hadoop configs of another. + + .. note:: If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + potentially non-deterministic. + + :param str path: Path to the image directory. + :param bool recursive: Recursive search flag. + :param int numPartitions: Number of DataFrame partitions. + :param bool dropImageFailures: Drop the files that are not valid images. + :param float sampleRatio: Fraction of the images loaded. + :param int seed: Random number seed. + :return: a :class:`DataFrame` with a single column of "images", + see ImageSchema for details. + + >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df.count() + 4 + + .. versionadded:: 2.3.0 + """ + + ctx = SparkContext._active_spark_context + spark = SparkSession(ctx) + image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema + jsession = spark._jsparkSession + jresult = image_schema.readImages(path, jsession, recursive, numPartitions, + dropImageFailures, float(sampleRatio), seed) + return DataFrame(jresult, spark._wrapped) + + +ImageSchema = _ImageSchema() + + +# Monkey patch to disallow instantization of this class. 
+def _disallow_instance(_): + raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") +_ImageSchema.__init__ = _disallow_instance diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2f1f3af957e4d..2258d61c95333 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -54,6 +54,7 @@ MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel +from pyspark.ml.image import ImageSchema from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ SparseMatrix, SparseVector, Vector, VectorUDT, Vectors from pyspark.ml.param import Param, Params, TypeConverters @@ -1818,6 +1819,24 @@ def tearDown(self): del self.data +class ImageReaderTest(SparkSessionTestCase): + + def test_read_images(self): + data_path = 'data/mllib/images/kittens' + df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + self.assertEqual(df.count(), 4) + first_row = df.take(1)[0][0] + array = ImageSchema.toNDArray(first_row) + self.assertEqual(len(array), first_row[1]) + self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) + self.assertEqual(df.schema, ImageSchema.imageSchema) + expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} + self.assertEqual(ImageSchema.ocvTypes, expected) + expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] + self.assertEqual(ImageSchema.imageFields, expected) + self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): From b4edafa99bd3858c166adeefdafd93dcd4bc9734 Mon Sep 17 00:00:00 2001 From: Jakub Nowacki Date: Thu, 23 Nov 2017 12:47:38 +0900 Subject: [PATCH 287/712] [SPARK-22495] Fix setup of SPARK_HOME variable on Windows ## What changes were proposed in this pull request? Fixing the way how `SPARK_HOME` is resolved on Windows. While the previous version was working with the built release download, the set of directories changed slightly for the PySpark `pip` or `conda` install. This has been reflected in Linux files in `bin` but not for Windows `cmd` files. First fix improves the way how the `jars` directory is found, as this was stoping Windows version of `pip/conda` install from working; JARs were not found by on Session/Context setup. Second fix is adding `find-spark-home.cmd` script, which uses `find_spark_home.py` script, as the Linux version, to resolve `SPARK_HOME`. It is based on `find-spark-home` bash script, though, some operations are done in different order due to the `cmd` script language limitations. If environment variable is set, the Python script `find_spark_home.py` will not be run. The process can fail if Python is not installed, but it will mostly use this way if PySpark is installed via `pip/conda`, thus, there is some Python in the system. ## How was this patch tested? Tested on local installation. Author: Jakub Nowacki Closes #19370 from jsnowacki/fix_spark_cmds. 
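As an illustrative aside, the resolution order that `find-spark-home.cmd` implements can be summarized as: an already-set `SPARK_HOME` wins, otherwise the location reported by `find_spark_home.py` is used (when a Python interpreter from `PYSPARK_DRIVER_PYTHON`/`PYSPARK_PYTHON` or plain `python` is available), otherwise the scripts fall back to the parent of the `bin` directory. The sketch below only models that precedence; `resolveSparkHome` and its explicit `Option` parameters are hypothetical stand-ins for the environment lookups the real script performs.

```scala
// Hypothetical sketch of the precedence implemented by find-spark-home.cmd.
// The real script reads environment variables; here they are passed in explicitly.
def resolveSparkHome(
    sparkHomeEnv: Option[String],      // SPARK_HOME, wins if already set
    pipInstalledHome: Option[String],  // output of find_spark_home.py, if Python is available
    scriptParentDir: String            // the %~dp0.. fallback
  ): String = {
  sparkHomeEnv.orElse(pipInstalledHome).getOrElse(scriptParentDir)
}
```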
--- appveyor.yml | 1 + bin/find-spark-home.cmd | 60 +++++++++++++++++++++++++++++++++++++++++ bin/pyspark2.cmd | 2 +- bin/run-example.cmd | 4 ++- bin/spark-class2.cmd | 2 +- bin/spark-shell2.cmd | 4 ++- bin/sparkR2.cmd | 2 +- 7 files changed, 70 insertions(+), 5 deletions(-) create mode 100644 bin/find-spark-home.cmd diff --git a/appveyor.yml b/appveyor.yml index dc2d81fcdc091..48740920cd09b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -33,6 +33,7 @@ only_commits: - core/src/main/scala/org/apache/spark/api/r/ - mllib/src/main/scala/org/apache/spark/ml/r/ - core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala + - bin/*.cmd cache: - C:\Users\appveyor\.m2 diff --git a/bin/find-spark-home.cmd b/bin/find-spark-home.cmd new file mode 100644 index 0000000000000..c75e7eedb9418 --- /dev/null +++ b/bin/find-spark-home.cmd @@ -0,0 +1,60 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Path to Python script finding SPARK_HOME +set FIND_SPARK_HOME_PYTHON_SCRIPT=%~dp0find_spark_home.py + +rem Default to standard python interpreter unless told otherwise +set PYTHON_RUNNER=python +rem If PYSPARK_DRIVER_PYTHON is set, it overwrites the python version +if not "x%PYSPARK_DRIVER_PYTHON%"=="x" ( + set PYTHON_RUNNER=%PYSPARK_DRIVER_PYTHON% +) +rem If PYSPARK_PYTHON is set, it overwrites the python version +if not "x%PYSPARK_PYTHON%"=="x" ( + set PYTHON_RUNNER=%PYSPARK_PYTHON% +) + +rem If there is python installed, trying to use the root dir as SPARK_HOME +where %PYTHON_RUNNER% > nul 2>$1 +if %ERRORLEVEL% neq 0 ( + if not exist %PYTHON_RUNNER% ( + if "x%SPARK_HOME%"=="x" ( + echo Missing Python executable '%PYTHON_RUNNER%', defaulting to '%~dp0..' for SPARK_HOME ^ +environment variable. Please install Python or specify the correct Python executable in ^ +PYSPARK_DRIVER_PYTHON or PYSPARK_PYTHON environment variable to detect SPARK_HOME safely. + set SPARK_HOME=%~dp0.. + ) + ) +) + +rem Only attempt to find SPARK_HOME if it is not set. +if "x%SPARK_HOME%"=="x" ( + if not exist "%FIND_SPARK_HOME_PYTHON_SCRIPT%" ( + rem If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + rem need to search the different Python directories for a Spark installation. + rem Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + rem spark-submit in another directory we want to use that version of PySpark rather than the + rem pip installed version of PySpark. + set SPARK_HOME=%~dp0.. 
+ ) else ( + rem We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + for /f "delims=" %%i in ('%PYTHON_RUNNER% %FIND_SPARK_HOME_PYTHON_SCRIPT%') do set SPARK_HOME=%%i + ) +) diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 46d4d5c883cfb..663670f2fddaf 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -18,7 +18,7 @@ rem limitations under the License. rem rem Figure out where the Spark framework is installed -set SPARK_HOME=%~dp0.. +call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] diff --git a/bin/run-example.cmd b/bin/run-example.cmd index efa5f81d08f7f..cc6b234406e4a 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -17,7 +17,9 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem Figure out where the Spark framework is installed +call "%~dp0find-spark-home.cmd" + set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] rem The outermost quotes are used to prevent Windows command line parse error diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index a93fd2f0e54bc..5da7d7a430d79 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -18,7 +18,7 @@ rem limitations under the License. rem rem Figure out where the Spark framework is installed -set SPARK_HOME=%~dp0.. +call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 7b5d396be888c..aaf71906c6526 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -17,7 +17,9 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem Figure out where the Spark framework is installed +call "%~dp0find-spark-home.cmd" + set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] rem SPARK-4161: scala does not assume use of the java classpath, diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd index 459b780e2ae33..b48bea345c0b9 100644 --- a/bin/sparkR2.cmd +++ b/bin/sparkR2.cmd @@ -18,7 +18,7 @@ rem limitations under the License. rem rem Figure out where the Spark framework is installed -set SPARK_HOME=%~dp0.. +call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" From 42f83d7c40bb4e9c7c50f2cbda515b331fb2097f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Nov 2017 18:20:16 +0100 Subject: [PATCH 288/712] [SPARK-17920][FOLLOWUP] simplify the schema file creation in test ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/19779 , to simplify the file creation. ## How was this patch tested? test only change Author: Wenchen Fan Closes #19799 from cloud-fan/minor. 
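For context, the simplification below amounts to writing the Avro schema file directly into the temporary directory provided by the test harness and deriving the table property from the file itself, instead of first creating a nested `avroschemadir` and assembling path strings by hand. A minimal sketch of the resulting pattern, assuming a `withTempDir { dir => ... }` style helper that supplies a `java.io.File`:

```scala
import java.io.{File, PrintWriter}

// Write the schema straight into the temp dir and reuse its canonical path
// (the path is then passed as the 'avro.schema.url' table property).
def writeSchemaFile(dir: File, avroSchema: String): String = {
  val schemaFile = new File(dir, "avroDecimal.avsc")
  val writer = new PrintWriter(schemaFile)
  try writer.write(avroSchema) finally writer.close()
  schemaFile.getCanonicalPath
}
```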
--- .../spark/sql/hive/client/VersionsSuite.scala | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index fbf6877c994a4..9d15dabc8d3f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -843,15 +843,12 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: SPARK-17920: Insert into/overwrite avro table") { withTempDir { dir => - val path = dir.getAbsolutePath - val schemaPath = s"""$path${File.separator}avroschemadir""" - - new File(schemaPath).mkdir() val avroSchema = - """{ + """ + |{ | "name": "test_record", | "type": "record", - | "fields": [ { + | "fields": [{ | "name": "f0", | "type": [ | "null", @@ -862,17 +859,17 @@ class VersionsSuite extends SparkFunSuite with Logging { | "logicalType": "decimal" | } | ] - | } ] + | }] |} """.stripMargin - val schemaUrl = s"""$schemaPath${File.separator}avroDecimal.avsc""" - val schemaFile = new File(schemaPath, "avroDecimal.avsc") + val schemaFile = new File(dir, "avroDecimal.avsc") val writer = new PrintWriter(schemaFile) writer.write(avroSchema) writer.close() + val schemaPath = schemaFile.getCanonicalPath val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile) + val srcLocation = new File(url.getFile).getCanonicalPath val destTableName = "tab1" val srcTableName = "tab2" @@ -886,7 +883,7 @@ class VersionsSuite extends SparkFunSuite with Logging { | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' |LOCATION '$srcLocation' - |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') """.stripMargin ) @@ -898,7 +895,7 @@ class VersionsSuite extends SparkFunSuite with Logging { |STORED AS | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' - |TBLPROPERTIES ('avro.schema.url' = '$schemaUrl') + |TBLPROPERTIES ('avro.schema.url' = '$schemaPath') """.stripMargin ) versionSpark.sql( From c1217565e20bd3297f3b1bc8f18f5dea933211c0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Nov 2017 15:33:26 -0800 Subject: [PATCH 289/712] [SPARK-22592][SQL] cleanup filter converting for hive ## What changes were proposed in this pull request? We have 2 different methods to convert filters for hive, regarding a config. This introduces duplicated and inconsistent code(e.g. one use helper objects for pattern match and one doesn't). ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19801 from cloud-fan/cleanup. --- .../spark/sql/hive/client/HiveShim.scala | 144 +++++++++--------- 1 file changed, 69 insertions(+), 75 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index bd1b300416990..1eac70dbf19cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -585,53 +585,17 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { * Unsupported predicates are skipped. 
*/ def convertFilters(table: Table, filters: Seq[Expression]): String = { - if (SQLConf.get.advancedPartitionPredicatePushdownEnabled) { - convertComplexFilters(table, filters) - } else { - convertBasicFilters(table, filters) - } - } - - - /** - * An extractor that matches all binary comparison operators except null-safe equality. - * - * Null-safe equality is not supported by Hive metastore partition predicate pushdown - */ - object SpecialBinaryComparison { - def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { - case _: EqualNullSafe => None - case _ => Some((e.left, e.right)) + /** + * An extractor that matches all binary comparison operators except null-safe equality. + * + * Null-safe equality is not supported by Hive metastore partition predicate pushdown + */ + object SpecialBinaryComparison { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { + case _: EqualNullSafe => None + case _ => Some((e.left, e.right)) + } } - } - - private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { - // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - lazy val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || - col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) - .map(col => col.getName).toSet - - filters.collect { - case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => - s"${a.name} ${op.symbol} $v" - case op @ SpecialBinaryComparison(Literal(v, _: IntegralType), a: Attribute) => - s"$v ${op.symbol} ${a.name}" - case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: StringType)) - if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" - case op @ SpecialBinaryComparison(Literal(v, _: StringType), a: Attribute) - if !varcharKeys.contains(a.name) => - s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" - }.mkString(" and ") - } - - private def convertComplexFilters(table: Table, filters: Seq[Expression]): String = { - // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
- lazy val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || - col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) - .map(col => col.getName).toSet object ExtractableLiteral { def unapply(expr: Expression): Option[String] = expr match { @@ -643,9 +607,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiterals { def unapply(exprs: Seq[Expression]): Option[Seq[String]] = { - exprs.map(ExtractableLiteral.unapply).foldLeft(Option(Seq.empty[String])) { - case (Some(accum), Some(value)) => Some(accum :+ value) - case _ => None + val extractables = exprs.map(ExtractableLiteral.unapply) + if (extractables.nonEmpty && extractables.forall(_.isDefined)) { + Some(extractables.map(_.get)) + } else { + None } } } @@ -660,40 +626,68 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } def unapply(values: Set[Any]): Option[Seq[String]] = { - values.toSeq.foldLeft(Option(Seq.empty[String])) { - case (Some(accum), value) if valueToLiteralString.isDefinedAt(value) => - Some(accum :+ valueToLiteralString(value)) - case _ => None + val extractables = values.toSeq.map(valueToLiteralString.lift) + if (extractables.nonEmpty && extractables.forall(_.isDefined)) { + Some(extractables.map(_.get)) + } else { + None } } } - def convertInToOr(a: Attribute, values: Seq[String]): String = { - values.map(value => s"${a.name} = $value").mkString("(", " or ", ")") + object NonVarcharAttribute { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + private val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + def unapply(attr: Attribute): Option[String] = { + if (varcharKeys.contains(attr.name)) { + None + } else { + Some(attr.name) + } + } + } + + def convertInToOr(name: String, values: Seq[String]): String = { + values.map(value => s"$name = $value").mkString("(", " or ", ")") } - lazy val convert: PartialFunction[Expression, String] = { - case In(a: Attribute, ExtractableLiterals(values)) - if !varcharKeys.contains(a.name) && values.nonEmpty => - convertInToOr(a, values) - case InSet(a: Attribute, ExtractableValues(values)) - if !varcharKeys.contains(a.name) && values.nonEmpty => - convertInToOr(a, values) - case op @ SpecialBinaryComparison(a: Attribute, ExtractableLiteral(value)) - if !varcharKeys.contains(a.name) => - s"${a.name} ${op.symbol} $value" - case op @ SpecialBinaryComparison(ExtractableLiteral(value), a: Attribute) - if !varcharKeys.contains(a.name) => - s"$value ${op.symbol} ${a.name}" - case And(expr1, expr2) - if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) => - (convert.lift(expr1) ++ convert.lift(expr2)).mkString("(", " and ", ")") - case Or(expr1, expr2) - if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) => - s"(${convert(expr1)} or ${convert(expr2)})" + val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled + + def convert(expr: Expression): Option[String] = expr match { + case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced => + Some(convertInToOr(name, values)) + + case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced => + Some(convertInToOr(name, values)) + + case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) => + Some(s"$name ${op.symbol} $value") + + case op @ 
SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) => + Some(s"$value ${op.symbol} $name") + + case And(expr1, expr2) if useAdvanced => + val converted = convert(expr1) ++ convert(expr2) + if (converted.isEmpty) { + None + } else { + Some(converted.mkString("(", " and ", ")")) + } + + case Or(expr1, expr2) if useAdvanced => + for { + left <- convert(expr1) + right <- convert(expr2) + } yield s"($left or $right)" + + case _ => None } - filters.map(convert.lift).collect { case Some(filterString) => filterString }.mkString(" and ") + filters.flatMap(convert).mkString(" and ") } private def quoteStringLiteral(str: String): String = { From 62a826f17c549ed93300bdce562db56bddd5d959 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Nov 2017 11:46:58 +0100 Subject: [PATCH 290/712] [SPARK-22591][SQL] GenerateOrdering shouldn't change CodegenContext.INPUT_ROW ## What changes were proposed in this pull request? When I played with codegen in developing another PR, I found the value of `CodegenContext.INPUT_ROW` is not reliable. Under wholestage codegen, it is assigned to null first and then suddenly changed to `i`. The reason is `GenerateOrdering` changes `CodegenContext.INPUT_ROW` but doesn't restore it back. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19800 from viirya/SPARK-22591. --- .../expressions/codegen/GenerateOrdering.scala | 16 ++++++++++------ .../sql/catalyst/expressions/OrderingSuite.scala | 11 ++++++++++- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 1639d1b9dda1f..4a459571ed634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -72,13 +72,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR * Generates the code for ordering based on the given order. 
*/ def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { + val oldInputRow = ctx.INPUT_ROW + val oldCurrentVars = ctx.currentVars + val inputRow = "i" + ctx.INPUT_ROW = inputRow + // to use INPUT_ROW we must make sure currentVars is null + ctx.currentVars = null + val comparisons = ordering.map { order => - val oldCurrentVars = ctx.currentVars - ctx.INPUT_ROW = "i" - // to use INPUT_ROW we must make sure currentVars is null - ctx.currentVars = null val eval = order.child.genCode(ctx) - ctx.currentVars = oldCurrentVars val asc = order.isAscending val isNullA = ctx.freshName("isNullA") val primitiveA = ctx.freshName("primitiveA") @@ -147,10 +149,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR """ }.mkString }) + ctx.currentVars = oldCurrentVars + ctx.INPUT_ROW = oldInputRow // make sure INPUT_ROW is declared even if splitExpressions // returns an inlined block s""" - |InternalRow ${ctx.INPUT_ROW} = null; + |InternalRow $inputRow = null; |$code """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index aa61ba2bff2bb..d0604b8eb7675 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, GenerateOrdering, LazilyGeneratedOrdering} import org.apache.spark.sql.types._ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -156,4 +156,13 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { assert(genOrdering.compare(rowB1, rowB2) < 0) } } + + test("SPARK-22591: GenerateOrdering shouldn't change ctx.INPUT_ROW") { + val ctx = new CodegenContext() + ctx.INPUT_ROW = null + + val schema = new StructType().add("field", FloatType, nullable = true) + GenerateOrdering.genComparisons(ctx, schema) + assert(ctx.INPUT_ROW == null) + } } From 554adc77d24c411a6df6d38c596aa33cdf68f3c1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 24 Nov 2017 12:08:49 +0100 Subject: [PATCH 291/712] [SPARK-22595][SQL] fix flaky test: CastSuite.SPARK-22500: cast for struct should not generate codes beyond 64KB ## What changes were proposed in this pull request? This PR reduces the number of fields in the test case of `CastSuite` to fix an issue that is pointed at [here](https://github.com/apache/spark/pull/19800#issuecomment-346634950). 
``` java.lang.OutOfMemoryError: GC overhead limit exceeded java.lang.OutOfMemoryError: GC overhead limit exceeded at org.codehaus.janino.UnitCompiler.findClass(UnitCompiler.java:10971) at org.codehaus.janino.UnitCompiler.findTypeByName(UnitCompiler.java:7607) at org.codehaus.janino.UnitCompiler.getReferenceType(UnitCompiler.java:5758) at org.codehaus.janino.UnitCompiler.getType2(UnitCompiler.java:5732) at org.codehaus.janino.UnitCompiler.access$13200(UnitCompiler.java:206) at org.codehaus.janino.UnitCompiler$18.visitReferenceType(UnitCompiler.java:5668) at org.codehaus.janino.UnitCompiler$18.visitReferenceType(UnitCompiler.java:5660) at org.codehaus.janino.Java$ReferenceType.accept(Java.java:3356) at org.codehaus.janino.UnitCompiler.getType(UnitCompiler.java:5660) at org.codehaus.janino.UnitCompiler.buildLocalVariableMap(UnitCompiler.java:2892) at org.codehaus.janino.UnitCompiler.compile(UnitCompiler.java:2764) at org.codehaus.janino.UnitCompiler.compileDeclaredMethods(UnitCompiler.java:1262) at org.codehaus.janino.UnitCompiler.compileDeclaredMethods(UnitCompiler.java:1234) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:538) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:890) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:894) at org.codehaus.janino.UnitCompiler.access$600(UnitCompiler.java:206) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:377) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:369) at org.codehaus.janino.Java$MemberClassDeclaration.accept(Java.java:1128) at org.codehaus.janino.UnitCompiler.compile(UnitCompiler.java:369) at org.codehaus.janino.UnitCompiler.compileDeclaredMemberTypes(UnitCompiler.java:1209) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:564) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:890) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:894) at org.codehaus.janino.UnitCompiler.access$600(UnitCompiler.java:206) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:377) at org.codehaus.janino.UnitCompiler$2.visitMemberClassDeclaration(UnitCompiler.java:369) at org.codehaus.janino.Java$MemberClassDeclaration.accept(Java.java:1128) at org.codehaus.janino.UnitCompiler.compile(UnitCompiler.java:369) at org.codehaus.janino.UnitCompiler.compileDeclaredMemberTypes(UnitCompiler.java:1209) at org.codehaus.janino.UnitCompiler.compile2(UnitCompiler.java:564) ... ``` ## How was this patch tested? Used existing test case Author: Kazuaki Ishizaki Closes #19806 from kiszk/SPARK-22595. 
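For reference, the test exercises a cast between wide, nested struct types, so the amount of generated code grows rapidly with the per-level field count `N`; lowering `N` keeps the SPARK-22500 regression coverage while staying well clear of the GC overhead shown above. The field-type construction used by the test looks like the following (shown here with the reduced `N`; the nested target type and the cast itself are omitted):

```scala
import org.apache.spark.sql.types._

// A struct type with N double fields, as built in CastSuite for SPARK-22500.
val N = 25
val fromInner = new StructType((1 to N).map(i => StructField(s"s$i", DoubleType)).toArray)
```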
--- .../org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 84bd8b2f91e4f..7837d6529d12b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -829,7 +829,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-22500: cast for struct should not generate codes beyond 64KB") { - val N = 250 + val N = 25 val fromInner = new StructType( (1 to N).map(i => StructField(s"s$i", DoubleType)).toArray) From 449e26ecdc891039198c26ece99454a2e76d5455 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 24 Nov 2017 15:07:43 +0100 Subject: [PATCH 292/712] [SPARK-22559][CORE] history server: handle exception on opening corrupted listing.ldb ## What changes were proposed in this pull request? Currently history server v2 failed to start if `listing.ldb` is corrupted. This patch get rid of the corrupted `listing.ldb` and re-create it. The exception handling follows [opening disk store for app](https://github.com/apache/spark/blob/0ffa7c488fa8156e2a1aa282e60b7c36b86d8af8/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala#L307) ## How was this patch tested? manual test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #19786 from gengliangwang/listingException. --- .../spark/deploy/history/FsHistoryProvider.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 25f82b55f2003..69ccde3a8149d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -34,10 +34,10 @@ import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException +import org.fusesource.leveldbjni.internal.NativeDB import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ @@ -132,7 +132,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) AppStatusStore.CURRENT_VERSION, logDir.toString()) try { - open(new File(path, "listing.ldb"), metadata) + open(dbPath, metadata) } catch { // If there's an error, remove the listing database and any existing UI database // from the store directory, since it's extremely likely that they'll all contain @@ -140,7 +140,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) case _: UnsupportedStoreVersionException | _: MetadataMismatchException => logInfo("Detected incompatible DB versions, deleting...") path.listFiles().foreach(Utils.deleteRecursively) - open(new File(path, "listing.ldb"), metadata) + open(dbPath, metadata) + case dbExc: NativeDB.DBException => + // Get rid of the 
corrupted listing.ldb and re-create it. + logWarning(s"Failed to load disk store $dbPath :", dbExc) + Utils.deleteRecursively(dbPath) + open(dbPath, metadata) } }.getOrElse(new InMemoryStore()) @@ -568,7 +573,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() - logInfo(s"Replaying log path: $logPath") val bus = new ReplayListenerBus() val listener = new AppListingListener(fileStatus, clock) From efd0036ec88bdc385f5a9ea568d2e2bbfcda2912 Mon Sep 17 00:00:00 2001 From: GuoChenzhao Date: Fri, 24 Nov 2017 15:09:43 +0100 Subject: [PATCH 293/712] [SPARK-22537][CORE] Aggregation of map output statistics on driver faces single point bottleneck ## What changes were proposed in this pull request? In adaptive execution, the map output statistics of all mappers will be aggregated after previous stage is successfully executed. Driver takes the aggregation job while it will get slow when the number of `mapper * shuffle partitions` is large, since it only uses single thread to compute. This PR uses multi-thread to deal with this single point bottleneck. ## How was this patch tested? Test cases are in `MapOutputTrackerSuite.scala` Author: GuoChenzhao Author: gczsjdy Closes #19763 from gczsjdy/single_point_mapstatistics. --- .../org/apache/spark/MapOutputTracker.scala | 60 ++++++++++++++++++- .../spark/internal/config/package.scala | 11 ++++ .../apache/spark/MapOutputTrackerSuite.scala | 23 +++++++ 3 files changed, 91 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7f760a59bda2f..195fd4f818b36 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -23,11 +23,14 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException @@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) } + /** + * Grouped function of Range, this is to avoid traverse of all elements of Range using + * IterableLike's grouped function. + */ + def rangeGrouped(range: Range, size: Int): Seq[Range] = { + val start = range.start + val step = range.step + val end = range.end + for (i <- start.until(end, size * step)) yield { + i.until(i + size * step, step) + } + } + + /** + * To equally divide n elements into m buckets, basically each bucket should have n/m elements, + * for the remaining n%m elements, add one more element to the first n%m buckets each. 
+ */ + def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = { + val elementsPerBucket = numElements / numBuckets + val remaining = numElements % numBuckets + val splitPoint = (elementsPerBucket + 1) * remaining + if (elementsPerBucket == 0) { + rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) + } else { + rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++ + rangeGrouped(splitPoint.until(numElements), elementsPerBucket) + } + } + /** * Return statistics about all of the outputs for a given shuffle. */ def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) + val parallelAggThreshold = conf.get( + SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) + val parallelism = math.min( + Runtime.getRuntime.availableProcessors(), + statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt + if (parallelism <= 1) { + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + } else { + val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") + try { + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map { + reduceIds => Future { + for (s <- statuses; i <- reduceIds) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + } + ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf) + } finally { + threadPool.shutdown() } } new MapOutputStatistics(dep.shuffleId, totalSizes) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7a9072736b9aa..8fa25c0281493 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -499,4 +499,15 @@ package object config { "array in the sorter.") .intConf .createWithDefault(Integer.MAX_VALUE) + + private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD = + ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold") + .internal() + .doc("Multi-thread is used when the number of mappers * shuffle partitions is greater than " + + "or equal to this threshold. 
Note that the actual parallelism is calculated by number of " + + "mappers * shuffle partitions / this threshold + 1, so this threshold should be positive.") + .intConf + .checkValue(v => v > 0, "The threshold should be positive.") + .createWithDefault(10000000) + } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ebd826b0ba2f6..50b8ea754d8d9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -275,4 +275,27 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("equally divide map statistics tasks") { + val func = newTrackerMaster().equallyDivide _ + val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5)) + val expects = Seq( + Seq(0, 0, 0, 0, 0), + Seq(1, 1, 1, 1, 0), + Seq(3, 3, 3, 3, 3), + Seq(4, 3, 3, 3, 3), + Seq(4, 4, 3, 3, 3), + Seq(4, 4, 4, 3, 3), + Seq(4, 4, 4, 4, 3), + Seq(4, 4, 4, 4, 4)) + cases.zip(expects).foreach { case ((num, divisor), expect) => + val answer = func(num, divisor).toSeq + var wholeSplit = (0 until num) + answer.zip(expect).foreach { case (split, expectSplitLength) => + val (currentSplit, rest) = wholeSplit.splitAt(expectSplitLength) + assert(currentSplit.toSet == split.toSet) + wholeSplit = rest + } + } + } + } From a1877f45c3451d18879083ed9b71dd9d5f583f1c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 24 Nov 2017 19:55:26 +0100 Subject: [PATCH 294/712] [SPARK-22597][SQL] Add spark-sql cmd script for Windows users ## What changes were proposed in this pull request? This PR proposes to add cmd scripts so that Windows users can also run `spark-sql` script. ## How was this patch tested? Manually tested on Windows. **Before** ```cmd C:\...\spark>.\bin\spark-sql '.\bin\spark-sql' is not recognized as an internal or external command, operable program or batch file. C:\...\spark>.\bin\spark-sql.cmd '.\bin\spark-sql.cmd' is not recognized as an internal or external command, operable program or batch file. ``` **After** ```cmd C:\...\spark>.\bin\spark-sql ... spark-sql> SELECT 'Hello World !!'; ... Hello World !! ``` Author: hyukjinkwon Closes #19808 from HyukjinKwon/spark-sql-cmd. 
--- bin/find-spark-home.cmd | 2 +- bin/run-example.cmd | 2 +- bin/spark-sql.cmd | 25 +++++++++++++++++++++++++ bin/spark-sql2.cmd | 25 +++++++++++++++++++++++++ bin/sparkR2.cmd | 3 +-- 5 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 bin/spark-sql.cmd create mode 100644 bin/spark-sql2.cmd diff --git a/bin/find-spark-home.cmd b/bin/find-spark-home.cmd index c75e7eedb9418..6025f67c38de4 100644 --- a/bin/find-spark-home.cmd +++ b/bin/find-spark-home.cmd @@ -32,7 +32,7 @@ if not "x%PYSPARK_PYTHON%"=="x" ( ) rem If there is python installed, trying to use the root dir as SPARK_HOME -where %PYTHON_RUNNER% > nul 2>$1 +where %PYTHON_RUNNER% > nul 2>&1 if %ERRORLEVEL% neq 0 ( if not exist %PYTHON_RUNNER% ( if "x%SPARK_HOME%"=="x" ( diff --git a/bin/run-example.cmd b/bin/run-example.cmd index cc6b234406e4a..2dd396e785358 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed call "%~dp0find-spark-home.cmd" -set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] +set _SPARK_CMD_USAGE=Usage: .\bin\run-example [options] example-class [example args] rem The outermost quotes are used to prevent Windows command line parse error rem when there are some quotes in parameters, see SPARK-21877. diff --git a/bin/spark-sql.cmd b/bin/spark-sql.cmd new file mode 100644 index 0000000000000..919e3214b5863 --- /dev/null +++ b/bin/spark-sql.cmd @@ -0,0 +1,25 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkSQL. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-sql2.cmd" %*" diff --git a/bin/spark-sql2.cmd b/bin/spark-sql2.cmd new file mode 100644 index 0000000000000..c34a3c5aa0739 --- /dev/null +++ b/bin/spark-sql2.cmd @@ -0,0 +1,25 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +call "%~dp0find-spark-home.cmd" + +set _SPARK_CMD_USAGE=Usage: .\bin\spark-sql [options] [cli option] + +call "%SPARK_HOME%\bin\spark-submit2.cmd" --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd index b48bea345c0b9..446f0c30bfe82 100644 --- a/bin/sparkR2.cmd +++ b/bin/sparkR2.cmd @@ -21,6 +21,5 @@ rem Figure out where the Spark framework is installed call "%~dp0find-spark-home.cmd" call "%SPARK_HOME%\bin\load-spark-env.cmd" - - +set _SPARK_CMD_USAGE=Usage: .\bin\sparkR [options] call "%SPARK_HOME%\bin\spark-submit2.cmd" sparkr-shell-main %* From 70221903f54eaa0514d5d189dfb6f175a62228a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 24 Nov 2017 21:50:30 -0800 Subject: [PATCH 295/712] [SPARK-22596][SQL] set ctx.currentVars in CodegenSupport.consume ## What changes were proposed in this pull request? `ctx.currentVars` means the input variables for the current operator, which is already decided in `CodegenSupport`, we can set it there instead of `doConsume`. also add more comments to help people understand the codegen framework. After this PR, we now have a principle about setting `ctx.currentVars` and `ctx.INPUT_ROW`: 1. for non-whole-stage-codegen path, never set them. (permit some special cases like generating ordering) 2. for whole-stage-codegen `produce` path, mostly we don't need to set them, but blocking operators may need to set them for expressions that produce data from data source, sort buffer, aggregate buffer, etc. 3. for whole-stage-codegen `consume` path, mostly we don't need to set them because `currentVars` is automatically set to child input variables and `INPUT_ROW` is mostly not used. A few plans need to tweak them as they may have different inputs, or they use the input row. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19803 from cloud-fan/codegen. 
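For readers less familiar with the codegen contract, the principle above hinges on how `BoundReference` generates code: it prefers a pre-computed column variable from `ctx.currentVars` and only falls back to reading `ctx.INPUT_ROW` when no such variable exists, which is why each operator must leave both fields in a consistent state. A condensed sketch of that branching, mirroring the `BoundReference.doGenCode` change in this patch (nullable handling trimmed for brevity):

```scala
// Condensed sketch of BoundReference.doGenCode (simplified; nullable branch omitted).
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
    // The input column is already materialized as a generated variable: reuse it.
    val oev = ctx.currentVars(ordinal)
    ev.isNull = oev.isNull
    ev.value = oev.value
    ev.copy(code = oev.code)
  } else {
    // Otherwise read column `ordinal` out of the row bound to ctx.INPUT_ROW.
    assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
    ev.copy(code = s"${ctx.javaType(dataType)} ${ev.value} = $value;", isNull = "false")
  }
}
```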
--- .../catalyst/expressions/BoundAttribute.scala | 23 +++++++++------- .../expressions/codegen/CodeGenerator.scala | 14 +++++++--- .../sql/execution/DataSourceScanExec.scala | 14 +++++----- .../spark/sql/execution/ExpandExec.scala | 3 --- .../spark/sql/execution/GenerateExec.scala | 2 -- .../sql/execution/WholeStageCodegenExec.scala | 27 ++++++++++++++----- .../execution/basicPhysicalOperators.scala | 6 +---- .../apache/spark/sql/execution/objects.scala | 20 +++++--------- 8 files changed, 59 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 7d16118c9d59f..6a17a397b3ef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -59,21 +59,24 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { val oev = ctx.currentVars(ordinal) ev.isNull = oev.isNull ev.value = oev.value - val code = oev.code - oev.code = "" - ev.copy(code = code) - } else if (nullable) { - ev.copy(code = s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""") + ev.copy(code = oev.code) } else { - ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false") + assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") + val javaType = ctx.javaType(dataType) + val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + if (nullable) { + ev.copy(code = + s""" + |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """.stripMargin) + } else { + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9df8a8d6f6609..0498e61819f48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -133,6 +133,17 @@ class CodegenContext { term } + /** + * Holding the variable name of the input row of the current operator, will be used by + * `BoundReference` to generate code. + * + * Note that if `currentVars` is not null, `BoundReference` prefers `currentVars` over `INPUT_ROW` + * to generate code. If you want to make sure the generated code use `INPUT_ROW`, you need to set + * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling + * `Expression.genCode`. + */ + final var INPUT_ROW = "i" + /** * Holding a list of generated columns as input of current operator, will be used by * BoundReference to generate code. @@ -386,9 +397,6 @@ class CodegenContext { final val JAVA_FLOAT = "float" final val JAVA_DOUBLE = "double" - /** The variable name of the input row in generated code. 
*/ - final var INPUT_ROW = "i" - /** * The map from a variable name to it's next ID. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a477c23140536..747749bc72e66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -123,7 +123,7 @@ case class RowDataSourceScanExec( |while ($input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput, null).trim} + | ${consume(ctx, columnsRowInput).trim} | if (shouldStop()) return; |} """.stripMargin @@ -355,19 +355,21 @@ case class FileSourceScanExec( // PhysicalRDD always just has one input val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - val exprRows = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable) - } val row = ctx.freshName("row") + ctx.INPUT_ROW = row ctx.currentVars = null - val columnsRowInput = exprRows.map(_.genCode(ctx)) + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. + val outputVars = output.zipWithIndex.map{ case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } val inputRow = if (needsUnsafeRowConversion) null else row s""" |while ($input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput, inputRow).trim} + | ${consume(ctx, outputVars, inputRow).trim} | if (shouldStop()) return; |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 33849f4389b92..a7bd5ebf93ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -133,9 +133,6 @@ case class ExpandExec( * size explosion. */ - // Set input variables - ctx.currentVars = input - // Tracks whether a column has the same output for all rows. // Size of sameOutput array should equal N. 
// If sameOutput(i) is true, then the i-th column has the same value for all output rows given diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index c142d3b5ed4f2..e1562befe14f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -135,8 +135,6 @@ case class GenerateExec( override def needCopyResult: Boolean = true override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - ctx.currentVars = input - // Add input rows to the values when we are joining val values = if (join) { input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 16b5706c03bf9..7166b7771e4db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -108,20 +108,22 @@ trait CodegenSupport extends SparkPlan { /** * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. + * + * Note that `outputVars` and `row` can't both be null. */ final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = { val inputVars = - if (row != null) { + if (outputVars != null) { + assert(outputVars.length == output.length) + // outputVars will be used to generate the code for UnsafeRow, so we should copy them + outputVars.map(_.copy()) + } else { + assert(row != null, "outputVars and row cannot both be null.") ctx.currentVars = null ctx.INPUT_ROW = row output.zipWithIndex.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable).genCode(ctx) } - } else { - assert(outputVars != null) - assert(outputVars.length == output.length) - // outputVars will be used to generate the code for UnsafeRow, so we should copy them - outputVars.map(_.copy()) } val rowVar = if (row != null) { @@ -147,6 +149,11 @@ trait CodegenSupport extends SparkPlan { } } + // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` + // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to + // generate code of `rowVar` manually. + ctx.currentVars = inputVars + ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) s""" @@ -193,7 +200,8 @@ trait CodegenSupport extends SparkPlan { def usedInputs: AttributeSet = references /** - * Generate the Java source code to process the rows from child SparkPlan. + * Generate the Java source code to process the rows from child SparkPlan. This should only be + * called from `consume`. * * This should be override by subclass to support codegen. * @@ -207,6 +215,11 @@ trait CodegenSupport extends SparkPlan { * } * * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). + * When consuming as a listing of variables, the code to produce the input is already + * generated and `CodegenContext.currentVars` is already set. When consuming as UnsafeRow, + * implementations need to put `row.code` in the generated code and set + * `CodegenContext.INPUT_ROW` manually. 
Some plans may need more tweaks as they have + * different inputs(join build side, aggregate buffer, etc.), or other special cases. */ def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index f205bdf3da709..c9a15147e30d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -56,9 +56,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val exprs = projectList.map(x => - ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) - ctx.currentVars = input + val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output)) val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) @@ -152,8 +150,6 @@ case class FilterExec(condition: Expression, child: SparkPlan) """.stripMargin } - ctx.currentVars = input - // To generate the predicates we will follow this algorithm. // For each predicate that is not IsNotNull, we will generate them one by one loading attributes // as necessary. For each of both attributes, if there is an IsNotNull predicate we will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d861109436a08..d1bd8a7076863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -81,11 +81,8 @@ case class DeserializeToObjectExec( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val bound = ExpressionCanonicalizer.execute( - BindReferences.bindReference(deserializer, child.output)) - ctx.currentVars = input - val resultVars = bound.genCode(ctx) :: Nil - consume(ctx, resultVars) + val resultObj = BindReferences.bindReference(deserializer, child.output).genCode(ctx) + consume(ctx, resultObj :: Nil) } override protected def doExecute(): RDD[InternalRow] = { @@ -118,11 +115,9 @@ case class SerializeFromObjectExec( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val bound = serializer.map { expr => - ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output)) + val resultVars = serializer.map { expr => + BindReferences.bindReference[Expression](expr, child.output).genCode(ctx) } - ctx.currentVars = input - val resultVars = bound.map(_.genCode(ctx)) consume(ctx, resultVars) } @@ -224,12 +219,9 @@ case class MapElementsExec( val funcObj = Literal.create(func, ObjectType(funcClass)) val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) - val bound = ExpressionCanonicalizer.execute( - BindReferences.bindReference(callFunc, child.output)) - ctx.currentVars = input - val resultVars = bound.genCode(ctx) :: Nil + val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx) - consume(ctx, resultVars) + consume(ctx, result :: Nil) } override 
protected def doExecute(): RDD[InternalRow] = { From e3fd93f149ff0ff1caff28a5191215e2a29749a9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 24 Nov 2017 22:43:47 -0800 Subject: [PATCH 296/712] [SPARK-22604][SQL] remove the get address methods from ColumnVector ## What changes were proposed in this pull request? `nullsNativeAddress` and `valuesNativeAddress` are only used in tests and benchmark, no need to be top class API. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19818 from cloud-fan/minor. --- .../vectorized/ArrowColumnVector.java | 10 --- .../execution/vectorized/ColumnVector.java | 7 -- .../vectorized/OffHeapColumnVector.java | 6 +- .../vectorized/OnHeapColumnVector.java | 9 -- .../vectorized/ColumnarBatchBenchmark.scala | 32 ++++---- .../vectorized/ColumnarBatchSuite.scala | 82 +++++++------------ 6 files changed, 47 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 949035bfb177c..3a10e9830f581 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -59,16 +59,6 @@ public boolean anyNullsSet() { return numNulls() > 0; } - @Override - public long nullsNativeAddress() { - throw new RuntimeException("Cannot get native address for arrow column"); - } - - @Override - public long valuesNativeAddress() { - throw new RuntimeException("Cannot get native address for arrow column"); - } - @Override public void close() { if (childColumns != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 666fd63fdcf2f..360ed83e2af2a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -62,13 +62,6 @@ public abstract class ColumnVector implements AutoCloseable { */ public abstract boolean anyNullsSet(); - /** - * Returns the off heap ptr for the arrays backing the NULLs and values buffer. Only valid - * to call for off heap columns. - */ - public abstract long nullsNativeAddress(); - public abstract long valuesNativeAddress(); - /** * Returns whether the value at rowId is NULL. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 2bf523b7e7198..6b5c783d4fa87 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,6 +19,8 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -73,12 +75,12 @@ public OffHeapColumnVector(int capacity, DataType type) { reset(); } - @Override + @VisibleForTesting public long valuesNativeAddress() { return data; } - @Override + @VisibleForTesting public long nullsNativeAddress() { return nulls; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index d699d292711dc..a7b103a62b17a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -79,15 +79,6 @@ public OnHeapColumnVector(int capacity, DataType type) { reset(); } - @Override - public long valuesNativeAddress() { - throw new RuntimeException("Cannot get native address for on heap column"); - } - @Override - public long nullsNativeAddress() { - throw new RuntimeException("Cannot get native address for on heap column"); - } - @Override public void close() { super.close(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 1331f157363b0..705b26b8c91e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -36,15 +36,6 @@ import org.apache.spark.util.collection.BitSet * Benchmark to low level memory access using different ways to manage buffers. */ object ColumnarBatchBenchmark { - - def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { - if (memMode == MemoryMode.OFF_HEAP) { - new OffHeapColumnVector(capacity, dt) - } else { - new OnHeapColumnVector(capacity, dt) - } - } - // This benchmark reads and writes an array of ints. // TODO: there is a big (2x) penalty for a random access API for off heap. // Note: carefully if modifying this code. It's hard to reason about the JIT. @@ -151,7 +142,7 @@ object ColumnarBatchBenchmark { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -170,7 +161,7 @@ object ColumnarBatchBenchmark { // Access through the column API with off heap memory def columnOffHeap = { i: Int => { - val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -189,7 +180,7 @@ object ColumnarBatchBenchmark { // Access by directly getting the buffer backing the column. 
val columnOffheapDirect = { i: Int => - val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() @@ -255,7 +246,7 @@ object ColumnarBatchBenchmark { // Adding values by appending, instead of putting. val onHeapAppend = { i: Int => - val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -330,7 +321,7 @@ object ColumnarBatchBenchmark { for (n <- 0L until iters) { var i = 0 while (i < count) { - if (i % 2 == 0) b(i) = 1; + if (i % 2 == 0) b(i) = 1 i += 1 } i = 0 @@ -351,7 +342,7 @@ object ColumnarBatchBenchmark { } def stringAccess(iters: Long): Unit = { - val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" val random = new Random(0) def randomString(min: Int, max: Int): String = { @@ -359,10 +350,10 @@ object ColumnarBatchBenchmark { val sb = new StringBuilder(len) var i = 0 while (i < len) { - sb.append(chars.charAt(random.nextInt(chars.length()))); + sb.append(chars.charAt(random.nextInt(chars.length()))) i += 1 } - return sb.toString + sb.toString } val minString = 3 @@ -373,7 +364,12 @@ object ColumnarBatchBenchmark { .map(_.getBytes(StandardCharsets.UTF_8)).toArray def column(memoryMode: MemoryMode) = { i: Int => - val column = allocate(count, BinaryType, memoryMode) + val column = if (memoryMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(count, BinaryType) + } else { + new OnHeapColumnVector(count, BinaryType) + } + var sum = 0L for (n <- 0L until iters) { var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 4a6c8f5521d18..80a50866aa504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -50,11 +50,11 @@ class ColumnarBatchSuite extends SparkFunSuite { name: String, size: Int, dt: DataType)( - block: (WritableColumnVector, MemoryMode) => Unit): Unit = { + block: WritableColumnVector => Unit): Unit = { test(name) { Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { mode => val vector = allocate(size, dt, mode) - try block(vector, mode) finally { + try block(vector) finally { vector.close() } } @@ -62,7 +62,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } testVector("Null APIs", 1024, IntegerType) { - (column, memMode) => + column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 assert(!column.anyNullsSet()) @@ -121,15 +121,11 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.isNullAt(v._2)) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.nullsNativeAddress() - assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) - } } } testVector("Byte APIs", 1024, ByteType) { - (column, memMode) => + column => val reference = mutable.ArrayBuffer.empty[Byte] var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray @@ -173,16 +169,12 @@ class ColumnarBatchSuite extends SparkFunSuite { idx += 3 reference.zipWithIndex.foreach { v => - assert(v._1 == column.getByte(v._2), "MemoryMode" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = 
column.valuesNativeAddress() - assert(v._1 == Platform.getByte(null, addr + v._2)) - } + assert(v._1 == column.getByte(v._2), "VectorType=" + column.getClass.getSimpleName) } } testVector("Short APIs", 1024, ShortType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] @@ -248,16 +240,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getShort(v._2), "Seed = " + seed + " Mem Mode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getShort(null, addr + 2 * v._2)) - } + assert(v._1 == column.getShort(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Int APIs", 1024, IntegerType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] @@ -329,16 +318,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getInt(v._2), "Seed = " + seed + " Mem Mode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) - } + assert(v._1 == column.getInt(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Long APIs", 1024, LongType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] @@ -413,16 +399,12 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.getLong(v._2), "idx=" + v._2 + - " Seed = " + seed + " MemMode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) - } + " Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Float APIs", 1024, FloatType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Float] @@ -500,16 +482,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getFloat(v._2), "Seed = " + seed + " MemMode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) - } + assert(v._1 == column.getFloat(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("Double APIs", 1024, DoubleType) { - (column, memMode) => + column => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] @@ -587,16 +566,13 @@ class ColumnarBatchSuite extends SparkFunSuite { } reference.zipWithIndex.foreach { v => - assert(v._1 == column.getDouble(v._2), "Seed = " + seed + " MemMode=" + memMode) - if (memMode == MemoryMode.OFF_HEAP) { - val addr = column.valuesNativeAddress() - assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) - } + assert(v._1 == column.getDouble(v._2), + "Seed = " + seed + " VectorType=" + column.getClass.getSimpleName) } } testVector("String APIs", 6, StringType) { - (column, memMode) => + column => val reference = mutable.ArrayBuffer.empty[String] assert(column.arrayData().elementsAppended 
== 0) @@ -643,9 +619,9 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => - assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) - assert(v._1 == column.getUTF8String(v._2).toString, - "MemoryMode" + memMode) + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v._1.length == column.getArrayLength(v._2), errMsg) + assert(v._1 == column.getUTF8String(v._2).toString, errMsg) } column.reset() @@ -653,7 +629,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } testVector("Int Array", 10, new ArrayType(IntegerType, true)) { - (column, _) => + column => // Fill the underlying data with all the arrays back to back. val data = column.arrayData() @@ -763,7 +739,7 @@ class ColumnarBatchSuite extends SparkFunSuite { testVector( "Struct Column", 10, - new StructType().add("int", IntegerType).add("double", DoubleType)) { (column, _) => + new StructType().add("int", IntegerType).add("double", DoubleType)) { column => val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) assert(c1.dataType() == IntegerType) @@ -789,7 +765,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } testVector("Nest Array in Array", 10, new ArrayType(new ArrayType(IntegerType, true), true)) { - (column, _) => + column => val childColumn = column.arrayData() val data = column.arrayData().arrayData() (0 until 6).foreach { @@ -822,7 +798,7 @@ class ColumnarBatchSuite extends SparkFunSuite { testVector( "Nest Struct in Array", 10, - new ArrayType(structType, true)) { (column, _) => + new ArrayType(structType, true)) { column => val data = column.arrayData() val c0 = data.getChildColumn(0) val c1 = data.getChildColumn(1) @@ -851,7 +827,7 @@ class ColumnarBatchSuite extends SparkFunSuite { 10, new StructType() .add("int", IntegerType) - .add("array", new ArrayType(IntegerType, true))) { (column, _) => + .add("array", new ArrayType(IntegerType, true))) { column => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -880,7 +856,7 @@ class ColumnarBatchSuite extends SparkFunSuite { testVector( "Nest Struct in Struct", 10, - new StructType().add("int", IntegerType).add("struct", subSchema)) { (column, _) => + new StructType().add("int", IntegerType).add("struct", subSchema)) { column => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) From 4d8ace48698c9a5e45cfc896b170914f1c21dc7b Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Sat, 25 Nov 2017 07:32:28 -0600 Subject: [PATCH 297/712] [SPARK-22583] First delegation token renewal time is not 75% of renewal time in Mesos The first scheduled renewal time is is set to the exact expiration time, and all subsequent renewal times are 75% of the renewal time. This makes it so that the inital renewal time is also 75%. ## What changes were proposed in this pull request? Set the initial renewal time to be 75% of renewal time. ## How was this patch tested? Tested locally in a test HDFS cluster, checking various renewal times. Author: Kalvin Chau Closes #19798 from kalvinnchau/fix-inital-renewal-time. 
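For illustration only: `nextUpdate` below is a hypothetical standalone helper, not the Mesos token manager or `SparkHadoopUtil` itself. It shows the scheduling rule the one-line change below applies, placing the next renewal at 75% of the remaining token lifetime instead of at the exact expiration time.

```
object RenewalTimeSketch {
  // Next update time: "now" plus the given fraction of the remaining token lifetime.
  def nextUpdate(expirationTimeMs: Long, nowMs: Long, fraction: Double = 0.75): Long =
    nowMs + ((expirationTimeMs - nowMs) * fraction).toLong

  def main(args: Array[String]): Unit = {
    val now = 0L             // treat "now" as t = 0 for readability
    val expiration = 100000L // token expires 100 seconds from now
    // Before: the first renewal fired at the expiration itself (t = 100s), while later
    // renewals already used the 75% rule. After the fix the first renewal also lands at:
    println(nextUpdate(expiration, now)) // 75000, i.e. t = 75s
  }
}
```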
--- .../cluster/mesos/MesosHadoopDelegationTokenManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala index 325dc179d63ea..7165bfae18a5e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -63,7 +63,7 @@ private[spark] class MesosHadoopDelegationTokenManager( val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), rt) + (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75)) } catch { case e: Exception => logError(s"Failed to fetch Hadoop delegation tokens $e") From fba63c1a7bc5c907b909bfd9247b85b36efba469 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 26 Nov 2017 07:42:44 -0600 Subject: [PATCH 298/712] [SPARK-22607][BUILD] Set large stack size consistently for tests to avoid StackOverflowError ## What changes were proposed in this pull request? Set `-ea` and `-Xss4m` consistently for tests, to fix in particular: ``` OrderingSuite: ... - GenerateOrdering with ShortType *** RUN ABORTED *** java.lang.StackOverflowError: at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:370) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:541) ... ``` ## How was this patch tested? Existing tests. Manually verified it resolves the StackOverflowError this intends to resolve. Author: Sean Owen Closes #19820 from srowen/SPARK-22607. 
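For illustration only: the snippet below is not the actual Spark build definition; it sketches how the same options can be applied to a forked sbt test JVM, mirroring the surefire `argLine` change in the pom.xml hunk below (project/SparkBuild.scala carries the corresponding sbt-side change).

```
// Illustrative sbt settings; values other than the flags shown in the diff are omitted.
import sbt._
import sbt.Keys._

object TestJvmOptions {
  lazy val settings: Seq[Setting[_]] = Seq(
    fork in Test := true,      // JVM options only take effect when tests run in a forked JVM
    javaOptions in Test ++= Seq(
      "-ea",                   // enable assertions during tests
      "-Xmx3g",                // test heap size
      "-Xss4m"                 // larger thread stack, so deeply recursive code such as
                               // Janino's flowAnalysis no longer hits StackOverflowError
    )
  )
}
```

Keeping the sbt and Maven option lists consistent matters because the StackOverflowError above only shows up under the smaller default stack, so a mismatch would make the failure appear in one build system but not the other.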
--- pom.xml | 4 ++-- project/SparkBuild.scala | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index 0297311dd6e64..3b2c629f8ec30 100644 --- a/pom.xml +++ b/pom.xml @@ -2101,7 +2101,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -Xss4096k -XX:ReservedCodeCacheSize=${CodeCacheSize} + -ea -Xmx3g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes_2.11 + jar + Spark Project Kubernetes + + kubernetes + 3.0.0 + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + + io.fabric8 + kubernetes-client + ${kubernetes.client.version} + + + com.fasterxml.jackson.core + * + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + + + + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${fasterxml.jackson.version} + + + + + com.google.guava + guava + + + + + org.mockito + mockito-core + test + + + + com.squareup.okhttp3 + okhttp + 3.8.1 + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala new file mode 100644 index 0000000000000..f0742b91987b6 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.network.util.ByteUnit + +private[spark] object Config extends Logging { + + val KUBERNETES_NAMESPACE = + ConfigBuilder("spark.kubernetes.namespace") + .doc("The namespace that will be used for running the driver and executor pods. When using " + + "spark-submit in cluster mode, this can also be passed to spark-submit via the " + + "--kubernetes-namespace command line argument.") + .stringConf + .createWithDefault("default") + + val EXECUTOR_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.executor.docker.image") + .doc("Docker image to use for the executors. Specify this using the standard Docker tag " + + "format.") + .stringConf + .createOptional + + val DOCKER_IMAGE_PULL_POLICY = + ConfigBuilder("spark.kubernetes.docker.image.pullPolicy") + .doc("Kubernetes image pull policy. 
Valid values are Always, Never, and IfNotPresent.") + .stringConf + .checkValues(Set("Always", "Never", "IfNotPresent")) + .createWithDefault("IfNotPresent") + + val APISERVER_AUTH_DRIVER_CONF_PREFIX = + "spark.kubernetes.authenticate.driver" + val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX = + "spark.kubernetes.authenticate.driver.mounted" + val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" + val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" + val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" + val CLIENT_CERT_FILE_CONF_SUFFIX = "clientCertFile" + val CA_CERT_FILE_CONF_SUFFIX = "caCertFile" + + val KUBERNETES_SERVICE_ACCOUNT_NAME = + ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") + .doc("Service account that is used when running the driver pod. The driver pod uses " + + "this service account when requesting executor pods from the API server. If specific " + + "credentials are given for the driver pod to use, the driver will favor " + + "using those credentials instead.") + .stringConf + .createOptional + + // Note that while we set a default for this when we start up the + // scheduler, the specific default value is dynamically determined + // based on the executor memory. + val KUBERNETES_EXECUTOR_MEMORY_OVERHEAD = + ConfigBuilder("spark.kubernetes.executor.memoryOverhead") + .doc("The amount of off-heap memory (in megabytes) to be allocated per executor. This " + + "is memory that accounts for things like VM overheads, interned strings, other native " + + "overheads, etc. This tends to grow with the executor size. (typically 6-10%).") + .bytesConf(ByteUnit.MiB) + .createOptional + + val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." + val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + + val KUBERNETES_DRIVER_POD_NAME = + ConfigBuilder("spark.kubernetes.driver.pod.name") + .doc("Name of the driver pod.") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_POD_NAME_PREFIX = + ConfigBuilder("spark.kubernetes.executor.podNamePrefix") + .doc("Prefix to use in front of the executor pod names.") + .internal() + .stringConf + .createWithDefault("spark") + + val KUBERNETES_ALLOCATION_BATCH_SIZE = + ConfigBuilder("spark.kubernetes.allocation.batch.size") + .doc("Number of pods to launch at once in each round of executor allocation.") + .intConf + .checkValue(value => value > 0, "Allocation batch size should be a positive integer") + .createWithDefault(5) + + val KUBERNETES_ALLOCATION_BATCH_DELAY = + ConfigBuilder("spark.kubernetes.allocation.batch.delay") + .doc("Number of seconds to wait between each round of executor allocation.") + .longConf + .checkValue(value => value > 0, "Allocation batch delay should be a positive integer") + .createWithDefault(1) + + val KUBERNETES_EXECUTOR_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.executor.limit.cores") + .doc("Specify the hard cpu limit for a single executor pod") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS = + ConfigBuilder("spark.kubernetes.executor.lostCheck.maxAttempts") + .doc("Maximum number of attempts allowed for checking the reason of an executor loss " + + "before it is assumed that the executor failed.") + .intConf + .checkValue(value => value > 0, "Maximum attempts of checks of executor lost reason " + + "must be a positive integer") + .createWithDefault(10) + + val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector." 
+} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala new file mode 100644 index 0000000000000..01717479fddd9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import org.apache.spark.SparkConf + +private[spark] object ConfigurationUtils { + + /** + * Extract and parse Spark configuration properties with a given name prefix and + * return the result as a Map. Keys must not have more than one value. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing the configuration property keys and values + */ + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String): Map[String, String] = { + sparkConf.getAllWithPrefix(prefix).toMap + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala new file mode 100644 index 0000000000000..4ddeefb15a89d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s + +private[spark] object Constants { + + // Labels + val SPARK_APP_ID_LABEL = "spark-app-selector" + val SPARK_EXECUTOR_ID_LABEL = "spark-exec-id" + val SPARK_ROLE_LABEL = "spark-role" + val SPARK_POD_DRIVER_ROLE = "driver" + val SPARK_POD_EXECUTOR_ROLE = "executor" + + // Default and fixed ports + val DEFAULT_DRIVER_PORT = 7078 + val DEFAULT_BLOCKMANAGER_PORT = 7079 + val BLOCK_MANAGER_PORT_NAME = "blockmanager" + val EXECUTOR_PORT_NAME = "executor" + + // Environment Variables + val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" + val ENV_DRIVER_URL = "SPARK_DRIVER_URL" + val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" + val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" + val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" + val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" + val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" + val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" + val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" + val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" + + // Miscellaneous + val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" + val MEMORY_OVERHEAD_FACTOR = 0.10 + val MEMORY_OVERHEAD_MIN_MIB = 384L +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala new file mode 100644 index 0000000000000..1e3f055e05766 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.utils.HttpClientUtils +import okhttp3.Dispatcher + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.util.ThreadUtils + +/** + * Spark-opinionated builder for Kubernetes clients. It uses a prefix plus common suffixes to + * parse configuration keys, similar to the manner in which Spark's SecurityManager parses SSL + * options for different components. 
+ */ +private[spark] object SparkKubernetesClientFactory { + + def createKubernetesClient( + master: String, + namespace: Option[String], + kubernetesAuthConfPrefix: String, + sparkConf: SparkConf, + defaultServiceAccountToken: Option[File], + defaultServiceAccountCaCert: Option[File]): KubernetesClient = { + val oauthTokenFileConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX" + val oauthTokenConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX" + val oauthTokenFile = sparkConf.getOption(oauthTokenFileConf) + .map(new File(_)) + .orElse(defaultServiceAccountToken) + val oauthTokenValue = sparkConf.getOption(oauthTokenConf) + ConfigurationUtils.requireNandDefined( + oauthTokenFile, + oauthTokenValue, + s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a " + + s"value $oauthTokenConf.") + + val caCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CA_CERT_FILE_CONF_SUFFIX") + .orElse(defaultServiceAccountCaCert.map(_.getAbsolutePath)) + val clientKeyFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX") + val clientCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX") + val dispatcher = new Dispatcher( + ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher")) + val config = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(master) + .withWebsocketPingInterval(0) + .withOption(oauthTokenValue) { + (token, configBuilder) => configBuilder.withOauthToken(token) + }.withOption(oauthTokenFile) { + (file, configBuilder) => + configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8)) + }.withOption(caCertFile) { + (file, configBuilder) => configBuilder.withCaCertFile(file) + }.withOption(clientKeyFile) { + (file, configBuilder) => configBuilder.withClientKeyFile(file) + }.withOption(clientCertFile) { + (file, configBuilder) => configBuilder.withClientCertFile(file) + }.withOption(namespace) { + (ns, configBuilder) => configBuilder.withNamespace(ns) + }.build() + val baseHttpClient = HttpClientUtils.createHttpClient(config) + val httpClientWithCustomDispatcher = baseHttpClient.newBuilder() + .dispatcher(dispatcher) + .build() + new DefaultKubernetesClient(httpClientWithCustomDispatcher, config) + } + + private implicit class OptionConfigurableConfigBuilder(val configBuilder: ConfigBuilder) + extends AnyVal { + + def withOption[T] + (option: Option[T]) + (configurator: ((T, ConfigBuilder) => ConfigBuilder)): ConfigBuilder = { + option.map { opt => + configurator(opt, configBuilder) + }.getOrElse(configBuilder) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala new file mode 100644 index 0000000000000..f79155b117b67 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.ConfigurationUtils +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Utils + +/** + * A factory class for configuring and creating executor pods. + */ +private[spark] trait ExecutorPodFactory { + + /** + * Configure and construct an executor pod with the given parameters. + */ + def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod +} + +private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) + extends ExecutorPodFactory { + + private val executorExtraClasspath = + sparkConf.get(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH) + + private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + + private val executorAnnotations = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + private val nodeSelector = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_NODE_SELECTOR_PREFIX) + + private val executorDockerImage = sparkConf + .get(EXECUTOR_DOCKER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor Docker image")) + private val dockerImagePullPolicy = sparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val blockManagerPort = sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) + + private val executorMemoryMiB = sparkConf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY) + private val executorMemoryString = sparkConf.get( + org.apache.spark.internal.config.EXECUTOR_MEMORY.key, + org.apache.spark.internal.config.EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = sparkConf + .get(KUBERNETES_EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: 
Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod = { + val name = s"$executorPodNamePrefix-exec-$executorId" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val resolvedExecutorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> applicationId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorLabels + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryMiB}Mi") + .build() + val executorMemoryLimitQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCores.toString) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = sparkConf + .get(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + // Executor backend expects integral value for executor cores, so round it up to an int. + (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, applicationId), + (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder() + .withName("executor") + .withImage(executorDockerImage) + .withImagePullPolicy(dockerImagePullPolicy) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryLimitQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .build() + + val executorPod = new PodBuilder() + .withNewMetadata() + .withName(name) + .withLabels(resolvedExecutorLabels.asJava) + .withAnnotations(executorAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(nodeSelector.asJava) + .endSpec() + .build() + + val containerWithExecutorLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + 
.editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + + new PodBuilder(executorPod) + .editSpec() + .addToContainers(containerWithExecutorLimitCores) + .endSpec() + .build() + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala new file mode 100644 index 0000000000000..68ca6a7622171 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.File + +import io.fabric8.kubernetes.client.Config + +import org.apache.spark.SparkContext +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.util.ThreadUtils + +private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { + + override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") + + override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { + new TaskSchedulerImpl(sc) + } + + override def createSchedulerBackend( + sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend = { + val sparkConf = sc.getConf + + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( + KUBERNETES_MASTER_INTERNAL_URL, + Some(sparkConf.get(KUBERNETES_NAMESPACE)), + APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + sparkConf, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + + val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val allocatorExecutor = ThreadUtils + .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") + val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( + "kubernetes-executor-requests") + new KubernetesClusterSchedulerBackend( + scheduler.asInstanceOf[TaskSchedulerImpl], + sc.env.rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) + } + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + } +} diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala new file mode 100644 index 0000000000000..e79c987852db2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable +import java.net.InetAddress +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} +import javax.annotation.concurrent.GuardedBy + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} +import org.apache.spark.util.Utils + +private[spark] class KubernetesClusterSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + executorPodFactory: ExecutorPodFactory, + kubernetesClient: KubernetesClient, + allocatorExecutor: ScheduledExecutorService, + requestExecutorsService: ExecutorService) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + import KubernetesClusterSchedulerBackend._ + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + private val RUNNING_EXECUTOR_PODS_LOCK = new Object + @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") + private val runningExecutorsToPods = new mutable.HashMap[String, Pod] + private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() + private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() + private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() + + private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( + 
requestExecutorsService) + + private val driverPod = kubernetesClient.pods() + .inNamespace(kubernetesNamespace) + .withName(kubernetesDriverPodName) + .get() + + protected override val minRegisteredRatio = + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + 0.8 + } else { + super.minRegisteredRatio + } + + private val executorWatchResource = new AtomicReference[Closeable] + private val totalExpectedExecutors = new AtomicInteger(0) + + private val driverUrl = RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + + private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) + + private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val executorLostReasonCheckMaxAttempts = conf.get( + KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) + + private val allocatorRunnable = new Runnable { + + // Maintains a map of executor id to count of checks performed to learn the loss reason + // for an executor. + private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] + + override def run(): Unit = { + handleDisconnectedExecutors() + + val executorsToAllocate = mutable.Map[String, Pod]() + val currentTotalRegisteredExecutors = totalRegisteredExecutors.get + val currentTotalExpectedExecutors = totalExpectedExecutors.get + val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { + logDebug("Waiting for pending executors before scaling") + } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { + logDebug("Maximum allowed executor limit reached. Not scaling up further.") + } else { + for (_ <- 0 until math.min( + currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { + val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString + val executorPod = executorPodFactory.createExecutorPod( + executorId, + applicationId(), + driverUrl, + conf.getExecutorEnv, + driverPod, + currentNodeToLocalTaskCount) + executorsToAllocate(executorId) = executorPod + logInfo( + s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") + } + } + } + + val allocatedExecutors = executorsToAllocate.mapValues { pod => + Utils.tryLog { + kubernetesClient.pods().create(pod) + } + } + + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + allocatedExecutors.map { + case (executorId, attemptedAllocatedExecutor) => + attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => + runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) + } + } + } + } + + def handleDisconnectedExecutors(): Unit = { + // For each disconnected executor, synchronize with the loss reasons that may have been found + // by the executor pod watcher. If the loss reason was discovered by the watcher, + // inform the parent class with removeExecutor. 
+ disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { + case (executorId, executorPod) => + val knownExitReason = Option(podsWithKnownExitReasons.remove( + executorPod.getMetadata.getName)) + knownExitReason.fold { + removeExecutorOrIncrementLossReasonCheckCount(executorId) + } { executorExited => + logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) + removeExecutor(executorId, executorExited) + // We don't delete the pod running the executor that has an exit condition caused by + // the application from the Kubernetes API server. This allows users to debug later on + // through commands such as "kubectl logs " and + // "kubectl describe pod ". Note that exited containers have terminated and + // therefore won't take CPU and memory resources. + // Otherwise, the executor pod is marked to be deleted from the API server. + if (executorExited.exitCausedByApp) { + logInfo(s"Executor $executorId exited because of the application.") + deleteExecutorFromDataStructures(executorId) + } else { + logInfo(s"Executor $executorId failed because of a framework error.") + deleteExecutorFromClusterAndDataStructures(executorId) + } + } + } + } + + def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { + val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) + if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { + removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) + deleteExecutorFromClusterAndDataStructures(executorId) + } else { + executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) + } + } + + def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { + deleteExecutorFromDataStructures(executorId).foreach { pod => + kubernetesClient.pods().delete(pod) + } + } + + def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { + disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) + executorReasonCheckAttemptCounts -= executorId + podsWithKnownExitReasons.remove(executorId) + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.remove(executorId).orElse { + logWarning(s"Unable to remove pod for unknown executor $executorId") + None + } + } + } + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + } + + override def start(): Unit = { + super.start() + executorWatchResource.set( + kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .watch(new ExecutorPodsWatcher())) + + allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS) + + if (!Utils.isDynamicAllocationEnabled(conf)) { + doRequestTotalExecutors(initialExecutors) + } + } + + override def stop(): Unit = { + // stop allocation of new resources and caches. 
+ allocatorExecutor.shutdown() + allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) + + // send stop message to executors so they shut down cleanly + super.stop() + + try { + val resource = executorWatchResource.getAndSet(null) + if (resource != null) { + resource.close() + } + } catch { + case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + } + + // then delete the executor pods + Utils.tryLogNonFatalError { + deleteExecutorPodsOnStop() + executorPodsByIPs.clear() + } + Utils.tryLogNonFatalError { + logInfo("Closing kubernetes client") + kubernetesClient.close() + } + } + + /** + * @return A map of K8s cluster nodes to the number of tasks that could benefit from data + * locality if an executor launches on the cluster node. + */ + private def getNodesWithLocalTaskCounts() : Map[String, Int] = { + val nodeToLocalTaskCount = synchronized { + mutable.Map[String, Int]() ++ hostToLocalTaskCount + } + + for (pod <- executorPodsByIPs.values().asScala) { + // Remove cluster nodes that are running our executors already. + // TODO: This prefers spreading out executors across nodes. In case users want + // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut + // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html + nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || + nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || + nodeToLocalTaskCount.remove( + InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + } + nodeToLocalTaskCount.toMap[String, Int] + } + + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { + totalExpectedExecutors.set(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { + executorIds.flatMap { executorId => + runningExecutorsToPods.remove(executorId) match { + case Some(pod) => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + Some(pod) + + case None => + logWarning(s"Unable to remove pod for unknown executor $executorId") + None + } + } + } + + kubernetesClient.pods().delete(podsToDelete: _*) + true + } + + private def deleteExecutorPodsOnStop(): Unit = { + val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { + val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) + runningExecutorsToPods.clear() + runningExecutorPodsCopy + } + kubernetesClient.pods().delete(executorPodsToDelete: _*) + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + + private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 + + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + val podIP = pod.getStatus.getPodIP + + action match { + case Action.MODIFIED if (pod.getStatus.getPhase == "Running" + && pod.getMetadata.getDeletionTimestamp == null) => + val clusterNodeName = pod.getSpec.getNodeName + logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") + executorPodsByIPs.put(podIP, pod) + + case Action.DELETED | Action.ERROR => + val executorId = getExecutorId(pod) + logDebug(s"Executor pod $podName at IP $podIP was at $action.") + if (podIP != null) { + executorPodsByIPs.remove(podIP) + } + + val executorExitReason = if (action == Action.ERROR) { + logWarning(s"Received error event of executor pod $podName. 
Reason: " + + pod.getStatus.getReason) + executorExitReasonOnError(pod) + } else if (action == Action.DELETED) { + logWarning(s"Received delete event of executor pod $podName. Reason: " + + pod.getStatus.getReason) + executorExitReasonOnDelete(pod) + } else { + throw new IllegalStateException( + s"Unknown action that should only be DELETED or ERROR: $action") + } + podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) + + if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { + log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + + s"watch received an event of type $action for this executor. The executor may " + + "have failed to start in the first place and never registered with the driver.") + } + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + + case _ => logDebug(s"Received event of executor pod $podName: " + action) + } + } + + override def onClose(cause: KubernetesClientException): Unit = { + logDebug("Executor pod watch closed.", cause) + } + + private def getExecutorExitStatus(pod: Pod): Int = { + val containerStatuses = pod.getStatus.getContainerStatuses + if (!containerStatuses.isEmpty) { + // we assume the first container represents the pod status. This assumption may not hold + // true in the future. Revisit this if side-car containers start running inside executor + // pods. + getExecutorExitStatus(containerStatuses.get(0)) + } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS + } + + private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { + Option(containerStatus.getState).map { containerState => + Option(containerState.getTerminated).map { containerStateTerminated => + containerStateTerminated.getExitCode.intValue() + }.getOrElse(UNKNOWN_EXIT_CODE) + }.getOrElse(UNKNOWN_EXIT_CODE) + } + + private def isPodAlreadyReleased(pod: Pod): Boolean = { + val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + !runningExecutorsToPods.contains(executorId) + } + } + + private def executorExitReasonOnError(pod: Pod): ExecutorExited = { + val containerExitStatus = getExecutorExitStatus(pod) + // container was probably actively killed by the driver. + if (isPodAlreadyReleased(pod)) { + ExecutorExited(containerExitStatus, exitCausedByApp = false, + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + + "request.") + } else { + val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + + s"exited with exit status code $containerExitStatus." + ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) + } + } + + private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { + val exitMessage = if (isPodAlreadyReleased(pod)) { + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." + } else { + s"Pod ${pod.getMetadata.getName} deleted or lost." 
+ } + ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) + } + + private def getExecutorId(pod: Pod): String = { + val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) + require(executorId != null, "Unexpected pod metadata; expected all executor pods " + + s"to have label $SPARK_EXECUTOR_ID_LABEL.") + executorId + } + } + + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new KubernetesDriverEndpoint(rpcEnv, properties) + } + + private class KubernetesDriverEndpoint( + rpcEnv: RpcEnv, + sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + if (disableExecutor(executorId)) { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.get(executorId).foreach { pod => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + } + } + } + } + } + } +} + +private object KubernetesClusterSchedulerBackend { + private val UNKNOWN_EXIT_CODE = -1 +} diff --git a/resource-managers/kubernetes/core/src/test/resources/log4j.properties b/resource-managers/kubernetes/core/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..ad95fadb7c0c0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala new file mode 100644 index 0000000000000..1c7717c238096 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Pod, _} +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + private val driverPodName: String = "driver-pod" + private val driverPodUid: String = "driver-uid" + private val executorPrefix: String = "base" + private val executorImage: String = "executor-image" + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .withUid(driverPodUid) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) + .set(EXECUTOR_DOCKER_IMAGE, executorImage) + } + + test("basic executor pod has reasonable defaults") { + val factory = new ExecutorPodFactoryImpl(baseConf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + // The executor pod name and default labels. + assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") + assert(executor.getMetadata.getLabels.size() === 3) + assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getImage === executorImage) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) + assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. 
+ assert(executor.getSpec.getNodeSelector.isEmpty) + assert(executor.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor, driverPodUid) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, + "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor, driverPodUid) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> "dummy", + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> "dummy", + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) + val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..3febb2f47cfd4 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{any, eq => mockitoEq} +import org.mockito.Mockito.{doNothing, never, times, verify, when} +import org.scalatest.BeforeAndAfter +import org.scalatest.mockito.MockitoSugar._ +import scala.collection.JavaConverters._ +import scala.concurrent.Future + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc._ +import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.ThreadUtils + +class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { + + private val APP_ID = "test-spark-app" + private val DRIVER_POD_NAME = "spark-driver-pod" + private val NAMESPACE = "test-namespace" + private val SPARK_DRIVER_HOST = "localhost" + private val SPARK_DRIVER_PORT = 7077 + private val POD_ALLOCATION_INTERVAL = 60L + private val DRIVER_URL = RpcEndpointAddress( + SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val FIRST_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod1") + .endMetadata() + .withNewSpec() + .withNodeName("node1") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private val SECOND_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod2") + .endMetadata() + .withNewSpec() + .withNodeName("node2") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.101") + .endStatus() + .build() + + private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + private type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + private type IN_NAMESPACE_PODS = NonNamespaceOperation[ + Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + + @Mock + private var sparkContext: SparkContext = _ + + @Mock + private var listenerBus: LiveListenerBus = _ + + @Mock + private var taskSchedulerImpl: TaskSchedulerImpl = _ + + @Mock + private var allocatorExecutor: ScheduledExecutorService = _ + + @Mock + private var requestExecutorsService: ExecutorService = _ + + @Mock + private var executorPodFactory: ExecutorPodFactory = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: 
PODS = _ + + @Mock + private var podsWithLabelOperations: LABELED_PODS = _ + + @Mock + private var podsInNamespace: IN_NAMESPACE_PODS = _ + + @Mock + private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + + @Mock + private var rpcEnv: RpcEnv = _ + + @Mock + private var driverEndpointRef: RpcEndpointRef = _ + + @Mock + private var executorPodsWatch: Watch = _ + + @Mock + private var successFuture: Future[Boolean] = _ + + private var sparkConf: SparkConf = _ + private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ + private var allocatorRunnable: ArgumentCaptor[Runnable] = _ + private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ + private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .addToLabels(SPARK_APP_ID_LABEL, APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .endMetadata() + .build() + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_NAMESPACE, NAMESPACE) + .set("spark.driver.host", SPARK_DRIVER_HOST) + .set("spark.driver.port", SPARK_DRIVER_PORT.toString) + .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL) + executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) + when(sparkContext.conf).thenReturn(sparkConf) + when(sparkContext.listenerBus).thenReturn(listenerBus) + when(taskSchedulerImpl.sc).thenReturn(sparkContext) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) + when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) + .thenReturn(executorPodsWatch) + when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) + when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) + when(podsWithDriverName.get()).thenReturn(driverPod) + when(allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable.capture(), + mockitoEq(0L), + mockitoEq(POD_ALLOCATION_INTERVAL), + mockitoEq(TimeUnit.SECONDS))).thenReturn(null) + // Creating Futures in Scala backed by a Java executor service resolves to running + // ExecutorService#execute (as opposed to submit) + doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) + when(rpcEnv.setupEndpoint( + mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) + .thenReturn(driverEndpointRef) + + // Used by the CoarseGrainedSchedulerBackend when making RPC calls. + when(driverEndpointRef.ask[Boolean] + (any(classOf[Any])) + (any())).thenReturn(successFuture) + when(successFuture.failed).thenReturn(Future[Throwable] { + // emulate behavior of the Future.failed method. 
+ throw new NoSuchElementException() + }(ThreadUtils.sameThread)) + } + + test("Basic lifecycle expectations when starting and stopping the scheduler.") { + val scheduler = newSchedulerBackend() + scheduler.start() + assert(executorPodsWatcherArgument.getValue != null) + assert(allocatorRunnable.getValue != null) + scheduler.stop() + verify(executorPodsWatch).close() + } + + test("Static allocation should request executors upon first allocator run.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + verify(podOperations).create(firstResolvedPod) + verify(podOperations).create(secondResolvedPod) + } + + test("Killing executors deletes the executor pods") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + scheduler.doKillExecutors(Seq("2")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations).delete(secondResolvedPod) + verify(podOperations, never()).delete(firstResolvedPod) + } + + test("Executors should be requested in batches.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(firstResolvedPod) + verify(podOperations, never()).create(secondResolvedPod) + val registerFirstExecutorMessage = RegisterExecutor( + "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + allocatorRunnable.getValue.run() + verify(podOperations).create(secondResolvedPod) + } + + test("Scaled down executors should be cleaned up") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + + // The scheduler backend spins up one executor pod. 
+ requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + + // Request that there are 0 executors and trigger deletion from driver. + scheduler.doRequestTotalExecutors(0) + requestExecutorRunnable.getAllValues.asScala.last.run() + scheduler.doKillExecutors(Seq("1")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations, times(1)).delete(resolvedPod) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + + val exitedPod = exitPod(resolvedPod, 0) + executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) + allocatorRunnable.getValue.run() + + // No more deletion attempts of the executors. + // This is graceful termination and should not be detected as a failure. + verify(podOperations, times(1)).delete(resolvedPod) + verify(driverEndpointRef, times(1)).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 0, + exitCausedByApp = false, + s"Container in pod ${exitedPod.getMetadata.getName} exited from" + + s" explicit termination request."))) + } + + test("Executors that fail should not be deleted.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + executorPodsWatcherArgument.getValue.eventReceived( + Action.ERROR, exitPod(firstResolvedPod, 1)) + + // A replacement executor should be created but the error pod should persist. 
+ val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + scheduler.doRequestTotalExecutors(1) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getAllValues.asScala.last.run() + verify(podOperations, never()).delete(firstResolvedPod) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 1, + exitCausedByApp = true, + s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + + " exit status code 1."))) + } + + test("Executors disconnected due to unknown reasons are deleted and replaced.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val executorLostReasonCheckMaxAttempts = sparkConf.get( + KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) + + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + 1 to executorLostReasonCheckMaxAttempts foreach { _ => + allocatorRunnable.getValue.run() + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + + val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).delete(firstResolvedPod) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) + } + + test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(firstResolvedPod)) + .thenThrow(new RuntimeException("test")) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).create(firstResolvedPod) + val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(recreatedResolvedPod) + } + + test("Executors that are initially created but the watch notices them fail are rebuilt" + + " in the next batch.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).create(firstResolvedPod) + executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) + val recreatedResolvedPod = expectPodCreationWithId(2, 
FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(recreatedResolvedPod) + } + + private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { + new KubernetesClusterSchedulerBackend( + taskSchedulerImpl, + rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) { + + override def applicationId(): String = APP_ID + } + } + + private def exitPod(basePod: Pod, exitCode: Int): Pod = { + new PodBuilder(basePod) + .editStatus() + .addNewContainerStatus() + .withNewState() + .withNewTerminated() + .withExitCode(exitCode) + .endTerminated() + .endState() + .endContainerStatus() + .endStatus() + .build() + } + + private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { + val resolvedPod = new PodBuilder(expectedPod) + .editMetadata() + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + when(executorPodFactory.createExecutorPod( + executorId.toString, + APP_ID, + DRIVER_URL, + sparkConf.getExecutorEnv, + driverPod, + Map.empty)).thenReturn(resolvedPod) + resolvedPod + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7052fb347106b..506adb363aa90 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -41,6 +41,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId +import org.apache.spark.scheduler.cluster.SchedulerBackendUtils import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} /** @@ -109,7 +110,7 @@ private[yarn] class YarnAllocator( sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) @volatile private var targetNumExecutors = - YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) private var currentNodeBlacklist = Set.empty[String] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 3d9f99f57bed7..9c1472cb50e3a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -133,8 +133,6 @@ object YarnSparkHadoopUtil { val ANY_HOST = "*" - val DEFAULT_NUMBER_EXECUTORS = 2 - // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = Priority.newInstance(1) @@ -279,27 +277,5 @@ object YarnSparkHadoopUtil { securityMgr.getModifyAclsGroups) ) } - - /** - * Getting the initial target number of executors depends on whether dynamic allocation is - * enabled. - * If not using dynamic allocation it gets the number of executors requested by the user. 
- */ - def getInitialTargetExecutorNumber( - conf: SparkConf, - numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { - if (Utils.isDynamicAllocationEnabled(conf)) { - val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS) - val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf) - val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) - require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, - s"initial executor number $initialNumExecutors must between min executor number " + - s"$minNumExecutors and max executor number $maxNumExecutors") - - initialNumExecutors - } else { - conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors) - } - } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d482376d14dd7..b722cc401bb73 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -52,7 +52,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) val args = new ClientArguments(argsArrayBuf.toArray) - totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf) + totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) client = new Client(args, conf) bindToYarn(client.submitApplication(), None) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 4f3d5ebf403e0..e2d477be329c3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -34,7 +34,7 @@ private[spark] class YarnClusterSchedulerBackend( val attemptId = ApplicationMaster.getAttemptId bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() - totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) + totalExpectedExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sc.conf) } override def getDriverLogUrls: Option[Map[String, String]] = { From 20b239845b695fe6a893ebfe97b49ef05fae773d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 29 Nov 2017 19:18:47 +0800 Subject: [PATCH 312/712] [SPARK-22605][SQL] SQL write job should also set Spark task output metrics ## What changes were proposed in this pull request? For SQL write jobs, we only set metrics for the SQL listener and display them in the SQL plan UI. We should also set the Spark task output metrics, which are shown in the Spark job UI. ## How was this patch tested? Tested manually. For a simple write job ``` spark.range(1000).write.parquet("/tmp/p1") ``` now the Spark job UI looks like ![ui](https://user-images.githubusercontent.com/3182036/33326478-05a25b7c-d490-11e7-96ef-806117774356.jpg) Author: Wenchen Fan Closes #19833 from cloud-fan/ui.
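A quick way to observe the new task-level metrics outside the UI is to attach a `SparkListener` and sum the output metrics over task-end events for the same toy write job. The sketch below is only an illustration and not part of this patch: the object name, the local master, the `/tmp` output path, and the crude sleep that waits out the asynchronous listener bus are all assumptions.
```
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

object OutputMetricsProbe {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("output-metrics-probe").getOrCreate()

    var bytesWritten = 0L
    var recordsWritten = 0L

    // Sum the task-level output metrics as task-end events arrive on the listener bus.
    spark.sparkContext.addSparkListener(new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        Option(taskEnd.taskMetrics).foreach { metrics =>
          bytesWritten += metrics.outputMetrics.bytesWritten
          recordsWritten += metrics.outputMetrics.recordsWritten
        }
      }
    })

    // The same toy write job as above.
    spark.range(1000).write.mode("overwrite").parquet("/tmp/p1")

    // Listener events are delivered asynchronously; a short wait is a crude but
    // sufficient way to let the remaining task-end events drain in this sketch.
    Thread.sleep(2000)
    println(s"bytesWritten=$bytesWritten recordsWritten=$recordsWritten")

    spark.stop()
  }
}
```
With the tracker change in place, the probe should report 1000 records written for the 1000-row job; before the change, SQL writes left these task output metrics unset.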
--- .../execution/datasources/BasicWriteStatsTracker.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 11af0aaa7b206..9dbbe9946ee99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -22,7 +22,7 @@ import java.io.FileNotFoundException import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SQLExecution @@ -44,7 +44,6 @@ case class BasicWriteTaskStats( /** * Simple [[WriteTaskStatsTracker]] implementation that produces [[BasicWriteTaskStats]]. - * @param hadoopConf */ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) extends WriteTaskStatsTracker with Logging { @@ -106,6 +105,13 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) override def getFinalStats(): WriteTaskStats = { statCurrentFile() + + // Reports bytesWritten and recordsWritten to the Spark output metrics. + Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics => + outputMetrics.setBytesWritten(numBytes) + outputMetrics.setRecordsWritten(numRows) + } + if (submittedFiles != numFiles) { logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + "This could be due to the output format not writing empty files, " + From 57687280d4171db98d4d9404c7bd3374f51deac0 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 29 Nov 2017 09:17:39 -0800 Subject: [PATCH 313/712] [SPARK-22615][SQL] Handle more cases in PropagateEmptyRelation ## What changes were proposed in this pull request? Currently, the optimizer rule `PropagateEmptyRelation` does not handle the following cases: 1. empty relation as right child in left outer join 2. empty relation as left child in right outer join 3. empty relation as right child in left semi join 4. empty relation as right child in left anti join 5. only one empty relation in full outer join Cases 1 / 2 / 5 can be treated as a **Cartesian product** and cause an exception. See the new test cases. ## How was this patch tested? Unit tests. Author: Wang Gengliang Closes #19825 from gengliangwang/SPARK-22615.
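For a concrete end-to-end view of two of the newly handled cases, the sketch below mirrors the added `join-empty-relation.sql` tests from the Dataset API. The object name and local master are illustrative assumptions; the expected results in the comments follow the new golden output file.
```
import org.apache.spark.sql.SparkSession

object EmptyRelationJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("empty-relation-joins").getOrCreate()

    spark.range(1).createOrReplaceTempView("t1")
    // An empty relation with the same schema as t1.
    spark.sql("CREATE OR REPLACE TEMPORARY VIEW empty_table AS SELECT id FROM t1 WHERE false")

    // Case 1: empty relation as the right child of a LEFT OUTER JOIN. The join now folds to a
    // Project over t1 with NULL for the right-side column instead of tripping the
    // Cartesian-product check; expected output: 0, NULL.
    spark.sql("SELECT * FROM t1 LEFT OUTER JOIN empty_table").show()

    // An empty relation as the right child of a LEFT ANTI JOIN now simplifies to the left
    // child itself; expected output: 0.
    spark.sql("SELECT * FROM t1 LEFT ANTI JOIN empty_table").show()

    spark.stop()
  }
}
```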
--- .../optimizer/PropagateEmptyRelation.scala | 36 +++- .../PropagateEmptyRelationSuite.scala | 16 +- .../sql-tests/inputs/join-empty-relation.sql | 28 +++ .../results/join-empty-relation.sql.out | 194 ++++++++++++++++++ 4 files changed, 257 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 52fbb4df2f58e..a6e5aa6daca65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -41,6 +41,10 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming) + // Construct a project list from plan's output, while the value is always NULL. + private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = + plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => empty(p) @@ -49,16 +53,28 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // as stateful streaming joins need to perform other state management operations other than // just processing the input data. case p @ Join(_, _, joinType, _) - if !p.children.exists(_.isStreaming) && p.children.exists(isEmptyLocalRelation) => - joinType match { - case _: InnerLike => empty(p) - // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule. - // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. - case LeftOuter | LeftSemi | LeftAnti if isEmptyLocalRelation(p.left) => empty(p) - case RightOuter if isEmptyLocalRelation(p.right) => empty(p) - case FullOuter if p.children.forall(isEmptyLocalRelation) => empty(p) - case _ => p - } + if !p.children.exists(_.isStreaming) => + val isLeftEmpty = isEmptyLocalRelation(p.left) + val isRightEmpty = isEmptyLocalRelation(p.right) + if (isLeftEmpty || isRightEmpty) { + joinType match { + case _: InnerLike => empty(p) + // Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule. + // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. 
+ case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p) + case LeftSemi if isRightEmpty => empty(p) + case LeftAnti if isRightEmpty => p.left + case FullOuter if isLeftEmpty && isRightEmpty => empty(p) + case LeftOuter | FullOuter if isRightEmpty => + Project(p.left.output ++ nullValueProjectList(p.right), p.left) + case RightOuter if isRightEmpty => empty(p) + case RightOuter | FullOuter if isLeftEmpty => + Project(nullValueProjectList(p.left) ++ p.right.output, p.right) + case _ => p + } + } else { + p + } case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmptyLocalRelation) => p match { case _: Project => empty(p) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index bc1c48b99c295..3964508e3a55e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -21,8 +21,9 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.StructType @@ -78,17 +79,18 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, None), + (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, None), - (true, false, LeftAnti, None), - (true, false, LeftSemi, None), + (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftAnti, Some(testRelation1)), + (true, false, LeftSemi, Some(LocalRelation('a.int))), (false, true, Inner, Some(LocalRelation('a.int, 'b.int))), (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), - (false, true, RightOuter, None), - (false, true, FullOuter, None), + (false, true, RightOuter, + Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql new file mode 100644 index 0000000000000..8afa3270f4de4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -0,0 +1,28 @@ +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); + +CREATE TEMPORARY VIEW empty_table as SELECT a FROM t2 WHERE false; + +SELECT * FROM t1 INNER JOIN empty_table; +SELECT * 
FROM t1 CROSS JOIN empty_table; +SELECT * FROM t1 LEFT OUTER JOIN empty_table; +SELECT * FROM t1 RIGHT OUTER JOIN empty_table; +SELECT * FROM t1 FULL OUTER JOIN empty_table; +SELECT * FROM t1 LEFT SEMI JOIN empty_table; +SELECT * FROM t1 LEFT ANTI JOIN empty_table; + +SELECT * FROM empty_table INNER JOIN t1; +SELECT * FROM empty_table CROSS JOIN t1; +SELECT * FROM empty_table LEFT OUTER JOIN t1; +SELECT * FROM empty_table RIGHT OUTER JOIN t1; +SELECT * FROM empty_table FULL OUTER JOIN t1; +SELECT * FROM empty_table LEFT SEMI JOIN t1; +SELECT * FROM empty_table LEFT ANTI JOIN t1; + +SELECT * FROM empty_table INNER JOIN empty_table; +SELECT * FROM empty_table CROSS JOIN empty_table; +SELECT * FROM empty_table LEFT OUTER JOIN empty_table; +SELECT * FROM empty_table RIGHT OUTER JOIN empty_table; +SELECT * FROM empty_table FULL OUTER JOIN empty_table; +SELECT * FROM empty_table LEFT SEMI JOIN empty_table; +SELECT * FROM empty_table LEFT ANTI JOIN empty_table; diff --git a/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out b/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out new file mode 100644 index 0000000000000..857073a827f24 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/join-empty-relation.sql.out @@ -0,0 +1,194 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 24 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW empty_table as SELECT a FROM t2 WHERE false +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT * FROM t1 INNER JOIN empty_table +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +SELECT * FROM t1 CROSS JOIN empty_table +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +SELECT * FROM t1 LEFT OUTER JOIN empty_table +-- !query 5 schema +struct +-- !query 5 output +1 NULL + + +-- !query 6 +SELECT * FROM t1 RIGHT OUTER JOIN empty_table +-- !query 6 schema +struct +-- !query 6 output + + + +-- !query 7 +SELECT * FROM t1 FULL OUTER JOIN empty_table +-- !query 7 schema +struct +-- !query 7 output +1 NULL + + +-- !query 8 +SELECT * FROM t1 LEFT SEMI JOIN empty_table +-- !query 8 schema +struct +-- !query 8 output + + + +-- !query 9 +SELECT * FROM t1 LEFT ANTI JOIN empty_table +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT * FROM empty_table INNER JOIN t1 +-- !query 10 schema +struct +-- !query 10 output + + + +-- !query 11 +SELECT * FROM empty_table CROSS JOIN t1 +-- !query 11 schema +struct +-- !query 11 output + + + +-- !query 12 +SELECT * FROM empty_table LEFT OUTER JOIN t1 +-- !query 12 schema +struct +-- !query 12 output + + + +-- !query 13 +SELECT * FROM empty_table RIGHT OUTER JOIN t1 +-- !query 13 schema +struct +-- !query 13 output +NULL 1 + + +-- !query 14 +SELECT * FROM empty_table FULL OUTER JOIN t1 +-- !query 14 schema +struct +-- !query 14 output +NULL 1 + + +-- !query 15 +SELECT * FROM empty_table LEFT SEMI JOIN t1 +-- !query 15 schema +struct +-- !query 15 output + + + +-- !query 16 +SELECT * FROM empty_table LEFT ANTI JOIN t1 +-- !query 16 schema +struct +-- !query 16 output + + + +-- !query 17 +SELECT * FROM empty_table INNER JOIN empty_table +-- !query 17 schema +struct +-- !query 17 output + + + +-- !query 18 +SELECT * FROM 
empty_table CROSS JOIN empty_table +-- !query 18 schema +struct +-- !query 18 output + + + +-- !query 19 +SELECT * FROM empty_table LEFT OUTER JOIN empty_table +-- !query 19 schema +struct +-- !query 19 output + + + +-- !query 20 +SELECT * FROM empty_table RIGHT OUTER JOIN empty_table +-- !query 20 schema +struct +-- !query 20 output + + + +-- !query 21 +SELECT * FROM empty_table FULL OUTER JOIN empty_table +-- !query 21 schema +struct +-- !query 21 output + + + +-- !query 22 +SELECT * FROM empty_table LEFT SEMI JOIN empty_table +-- !query 22 schema +struct +-- !query 22 output + + + +-- !query 23 +SELECT * FROM empty_table LEFT ANTI JOIN empty_table +-- !query 23 schema +struct +-- !query 23 output + From 284836862b2312aea5d7555c8e3c9d3c4dbc8eaf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Nov 2017 01:19:37 +0800 Subject: [PATCH 314/712] [SPARK-22608][SQL] add new API to CodeGeneration.splitExpressions() ## What changes were proposed in this pull request? This PR adds a new API to `CodeGenerator.splitExpressions`, since several call sites use `CodeGenerator.splitExpressions` together with `ctx.INPUT_ROW`; the new overload avoids that code duplication. ## How was this patch tested? Used existing test suites. Author: Kazuaki Ishizaki Closes #19821 from kiszk/SPARK-22608. --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 24 ++++++++++++-- .../sql/catalyst/expressions/predicates.scala | 10 +++--- .../expressions/stringExpressions.scala | 32 ++++++------------- 4 files changed, 37 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12baddf1bf7ac..8cafaef61c7d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1040,7 +1040,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ } - val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + val fieldsEvalCodes = if (ctx.currentVars == null) { ctx.splitExpressions( expressions = fieldsEvalCode, funcName = "castStruct", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 668c816b3fd8d..1645db12c53f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -788,11 +788,31 @@ class CodegenContext { * @param expressions the codes to evaluate expressions. */ def splitExpressions(expressions: Seq[String]): String = { + splitExpressions(expressions, funcName = "apply", extraArguments = Nil) + } + + /** + * Similar to [[splitExpressions(expressions: Seq[String])]], but has customized function name + * and extra arguments. + * + * @param expressions the codes to evaluate expressions. + * @param funcName the split function name base.
+ * @param extraArguments the list of (type, name) of the arguments of the split function + * except for ctx.INPUT_ROW + */ + def splitExpressions( + expressions: Seq[String], + funcName: String, + extraArguments: Seq[(String, String)]): String = { // TODO: support whole stage codegen if (INPUT_ROW == null || currentVars != null) { - return expressions.mkString("\n") + expressions.mkString("\n") + } else { + splitExpressions( + expressions, + funcName, + arguments = ("InternalRow", INPUT_ROW) +: extraArguments) } - splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index eb7475354b104..1aaaaf1db48d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -251,12 +251,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } """) - val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil - ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args) - } else { - listCode.mkString("\n") - } + val listCodes = ctx.splitExpressions( + expressions = listCode, + funcName = "valueIn", + extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil) ev.copy(code = s""" ${valueGen.code} ${ev.value} = false; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ee5cf925d3cef..34917ace001fa 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -73,14 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } - val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( - expressions = inputs, - funcName = "valueConcat", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) - } else { - inputs.mkString("\n") - } + val codes = ctx.splitExpressions( + expressions = inputs, + funcName = "valueConcat", + extraArguments = ("UTF8String[]", args) :: Nil) ev.copy(s""" UTF8String[] $args = new UTF8String[${evals.length}]; $codes @@ -156,14 +152,10 @@ case class ConcatWs(children: Seq[Expression]) "" } } - val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( + val codes = ctx.splitExpressions( expressions = inputs, funcName = "valueConcatWs", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) - } else { - inputs.mkString("\n") - } + extraArguments = ("UTF8String[]", args) :: Nil) ev.copy(s""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} @@ -1388,14 +1380,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC $argList[$index] = $value; """ } - val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( - expressions = argListCode, - funcName = "valueFormatString", - arguments = ("InternalRow", ctx.INPUT_ROW) :: 
("Object[]", argList) :: Nil) - } else { - argListCode.mkString("\n") - } + val argListCodes = ctx.splitExpressions( + expressions = argListCode, + funcName = "valueFormatString", + extraArguments = ("Object[]", argList) :: Nil) val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName From 193555f79cc73873613674a09a7c371688b6dbc7 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Wed, 29 Nov 2017 14:15:35 -0800 Subject: [PATCH 315/712] [SPARK-18935][MESOS] Fix dynamic reservations on mesos ## What changes were proposed in this pull request? - Solves the issue described in the ticket by preserving reservation and allocation info in all cases (port handling included). - upgrades to 1.4 - Adds extra debug level logging to make debugging easier in the future, for example we add reservation info when applicable. ``` 17/09/29 14:53:07 DEBUG MesosCoarseGrainedSchedulerBackend: Accepting offer: f20de49b-dee3-45dd-a3c1-73418b7de891-O32 with attributes: Map() allocation info: role: "spark-prive" reservation info: name: "ports" type: RANGES ranges { range { begin: 31000 end: 32000 } } role: "spark-prive" reservation { principal: "test" } allocation_info { role: "spark-prive" } ``` - Some style cleanup. ## How was this patch tested? Manually by running the example in the ticket with and without a principal. Specifically I tested it on a dc/os 1.10 cluster with 7 nodes and played with reservations. From the master node in order to reserve resources I executed: ```for i in 0 1 2 3 4 5 6 do curl -i \ -d slaveId=90ec65ea-1f7b-479f-a824-35d2527d6d26-S$i \ -d resources='[ { "name": "cpus", "type": "SCALAR", "scalar": { "value": 2 }, "role": "spark-role", "reservation": { "principal": "" } }, { "name": "mem", "type": "SCALAR", "scalar": { "value": 8026 }, "role": "spark-role", "reservation": { "principal": "" } } ]' \ -X POST http://master.mesos:5050/master/reserve done ``` Nodes had 4 cpus (m3.xlarge instances) and I reserved either 2 or 4 cpus (all for a role). I verified it launches tasks on nodes with reserved resources under `spark-role` role only if a) there are remaining resources for (*) default role and the spark driver has no role assigned to it. b) the spark driver has a role assigned to it and it is the same role used in reservations. I also tested this locally on my machine. Author: Stavros Kontopoulos Closes #19390 from skonto/fix_dynamic_reservation. 
--- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- resource-managers/mesos/pom.xml | 2 +- .../cluster/mesos/MesosClusterScheduler.scala | 1 - .../MesosCoarseGrainedSchedulerBackend.scala | 17 +++- .../cluster/mesos/MesosSchedulerUtils.scala | 99 ++++++++++++------- 6 files changed, 80 insertions(+), 43 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 21c8a75796387..50ac6d139bbd4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -138,7 +138,7 @@ lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar -mesos-1.3.0-shaded-protobuf.jar +mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 7173426c7bf74..1b1e3166d53db 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -139,7 +139,7 @@ lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar -mesos-1.3.0-shaded-protobuf.jar +mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index de8f1c913651d..70d0c1750b14e 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -29,7 +29,7 @@ Spark Project Mesos mesos - 1.3.0 + 1.4.0 shaded-protobuf diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index c41283e4a3e39..d224a7325820a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -36,7 +36,6 @@ import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionRes import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils - /** * Tracks the current state of a Mesos Task that runs a Spark driver. * @param driverDescription Submitted driver description from diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index c392061fdb358..191415a2578b2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -400,13 +400,20 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val offerMem = getResource(offer.getResourcesList, "mem") val offerCpus = getResource(offer.getResourcesList, "cpus") val offerPorts = getRangeResource(offer.getResourcesList, "ports") + val offerReservationInfo = offer + .getResourcesList + .asScala + .find { r => r.getReservation != null } val id = offer.getId.getValue if (tasks.contains(offer.getId)) { // accept val offerTasks = tasks(offer.getId) logDebug(s"Accepting offer: $id with attributes: $offerAttributes " + - s"mem: $offerMem cpu: $offerCpus ports: $offerPorts." 
+ + offerReservationInfo.map(resInfo => + s"reservation info: ${resInfo.getReservation.toString}").getOrElse("") + + s"mem: $offerMem cpu: $offerCpus ports: $offerPorts " + + s"resources: ${offer.getResourcesList.asScala.mkString(",")}." + s" Launching ${offerTasks.size} Mesos tasks.") for (task <- offerTasks) { @@ -416,7 +423,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val ports = getRangeResource(task.getResourcesList, "ports").mkString(",") logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus" + - s" ports: $ports") + s" ports: $ports" + s" on slave with slave id: ${task.getSlaveId.getValue} ") } driver.launchTasks( @@ -431,7 +438,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } else { declineOffer( driver, - offer) + offer, + Some("Offer was declined due to unmet task launch constraints.")) } } } @@ -513,6 +521,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( totalGpusAcquired += taskGPUs gpusByTaskId(taskId) = taskGPUs } + } else { + logDebug(s"Cannot launch a task for offer with id: $offerId on slave " + + s"with id: $slaveId. Requirements were not met for this offer.") } } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 6fcb30af8a733..e75450369ad85 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -28,7 +28,8 @@ import com.google.common.base.Splitter import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability -import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} +import org.apache.mesos.Protos.Resource.ReservationInfo +import org.apache.mesos.protobuf.{ByteString, GeneratedMessageV3} import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.TaskState @@ -36,8 +37,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.util.Utils - - /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper * methods and Mesos scheduler will use. @@ -46,6 +45,8 @@ trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered private final val registerLatch = new CountDownLatch(1) + private final val ANY_ROLE = "*" + /** * Creates a new MesosSchedulerDriver that communicates to the Mesos master. 
* @@ -175,17 +176,36 @@ trait MesosSchedulerUtils extends Logging { registerLatch.countDown() } - def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { + private def setReservationInfo( + reservationInfo: Option[ReservationInfo], + role: Option[String], + builder: Resource.Builder): Unit = { + if (!role.contains(ANY_ROLE)) { + reservationInfo.foreach { res => builder.setReservation(res) } + } + } + + def createResource( + name: String, + amount: Double, + role: Option[String] = None, + reservationInfo: Option[ReservationInfo] = None): Resource = { val builder = Resource.newBuilder() .setName(name) .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder().setValue(amount).build()) - role.foreach { r => builder.setRole(r) } - + setReservationInfo(reservationInfo, role, builder) builder.build() } + private def getReservation(resource: Resource): Option[ReservationInfo] = { + if (resource.hasReservation) { + Some(resource.getReservation) + } else { + None + } + } /** * Partition the existing set of resources into two groups, those remaining to be * scheduled and those requested to be used for a new task. @@ -203,14 +223,17 @@ trait MesosSchedulerUtils extends Logging { var requestedResources = new ArrayBuffer[Resource] val remainingResources = resources.asScala.map { case r => + val reservation = getReservation(r) if (remain > 0 && r.getType == Value.Type.SCALAR && r.getScalar.getValue > 0.0 && r.getName == resourceName) { val usage = Math.min(remain, r.getScalar.getValue) - requestedResources += createResource(resourceName, usage, Some(r.getRole)) + requestedResources += createResource(resourceName, usage, + Option(r.getRole), reservation) remain -= usage - createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole)) + createResource(resourceName, r.getScalar.getValue - usage, + Option(r.getRole), reservation) } else { r } @@ -228,16 +251,6 @@ trait MesosSchedulerUtils extends Logging { (attr.getName, attr.getText.getValue.split(',').toSet) } - - /** Build a Mesos resource protobuf object */ - protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } - /** * Converts the attributes from the resource offer into a Map of name to Attribute Value * The attribute values are the mesos attribute types and they are @@ -245,7 +258,8 @@ trait MesosSchedulerUtils extends Logging { * @param offerAttributes the attributes offered * @return */ - protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + protected def toAttributeMap(offerAttributes: JList[Attribute]) + : Map[String, GeneratedMessageV3] = { offerAttributes.asScala.map { attr => val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar @@ -266,7 +280,7 @@ trait MesosSchedulerUtils extends Logging { */ def matchesAttributeRequirements( slaveOfferConstraints: Map[String, Set[String]], - offerAttributes: Map[String, GeneratedMessage]): Boolean = { + offerAttributes: Map[String, GeneratedMessageV3]): Boolean = { slaveOfferConstraints.forall { // offer has the required attribute and subsumes the required values for that attribute case (name, requiredValues) => @@ -427,10 +441,10 @@ trait MesosSchedulerUtils extends Logging { // partition port offers val (resourcesWithoutPorts, portResources) = filterPortResources(offeredResources) - val 
portsAndRoles = requestedPorts. - map(x => (x, findPortAndGetAssignedRangeRole(x, portResources))) + val portsAndResourceInfo = requestedPorts. + map { x => (x, findPortAndGetAssignedResourceInfo(x, portResources)) } - val assignedPortResources = createResourcesFromPorts(portsAndRoles) + val assignedPortResources = createResourcesFromPorts(portsAndResourceInfo) // ignore non-assigned port resources, they will be declined implicitly by mesos // no need for splitting port resources. @@ -450,16 +464,25 @@ trait MesosSchedulerUtils extends Logging { managedPortNames.map(conf.getLong(_, 0)).filter( _ != 0) } + private case class RoleResourceInfo( + role: String, + resInfo: Option[ReservationInfo]) + /** Creates a mesos resource for a specific port number. */ - private def createResourcesFromPorts(portsAndRoles: List[(Long, String)]) : List[Resource] = { - portsAndRoles.flatMap{ case (port, role) => - createMesosPortResource(List((port, port)), Some(role))} + private def createResourcesFromPorts( + portsAndResourcesInfo: List[(Long, RoleResourceInfo)]) + : List[Resource] = { + portsAndResourcesInfo.flatMap { case (port, rInfo) => + createMesosPortResource(List((port, port)), Option(rInfo.role), rInfo.resInfo)} } /** Helper to create mesos resources for specific port ranges. */ private def createMesosPortResource( ranges: List[(Long, Long)], - role: Option[String] = None): List[Resource] = { + role: Option[String] = None, + reservationInfo: Option[ReservationInfo] = None): List[Resource] = { + // for ranges we are going to use (user defined ports fall in there) create mesos resources + // for each range there is a role associated with it. ranges.map { case (rangeStart, rangeEnd) => val rangeValue = Value.Range.newBuilder() .setBegin(rangeStart) @@ -468,7 +491,8 @@ trait MesosSchedulerUtils extends Logging { .setName("ports") .setType(Value.Type.RANGES) .setRanges(Value.Ranges.newBuilder().addRange(rangeValue)) - role.foreach(r => builder.setRole(r)) + role.foreach { r => builder.setRole(r) } + setReservationInfo(reservationInfo, role, builder) builder.build() } } @@ -477,19 +501,21 @@ trait MesosSchedulerUtils extends Logging { * Helper to assign a port to an offered range and get the latter's role * info to use it later on. */ - private def findPortAndGetAssignedRangeRole(port: Long, portResources: List[Resource]) - : String = { + private def findPortAndGetAssignedResourceInfo(port: Long, portResources: List[Resource]) + : RoleResourceInfo = { val ranges = portResources. 
- map(resource => - (resource.getRole, resource.getRanges.getRangeList.asScala - .map(r => (r.getBegin, r.getEnd)).toList)) + map { resource => + val reservation = getReservation(resource) + (RoleResourceInfo(resource.getRole, reservation), + resource.getRanges.getRangeList.asScala.map(r => (r.getBegin, r.getEnd)).toList) + } - val rangePortRole = ranges - .find { case (role, rangeList) => rangeList + val rangePortResourceInfo = ranges + .find { case (resourceInfo, rangeList) => rangeList .exists{ case (rangeStart, rangeEnd) => rangeStart <= port & rangeEnd >= port}} // this is safe since we have previously checked about the ranges (see checkPorts method) - rangePortRole.map{ case (role, rangeList) => role}.get + rangePortResourceInfo.map{ case (resourceInfo, rangeList) => resourceInfo}.get } /** Retrieves the port resources from a list of mesos offered resources */ @@ -564,3 +590,4 @@ trait MesosSchedulerUtils extends Logging { } } } + From 8ff474f6e543203fac5d49af7fbe98a8a98da567 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 29 Nov 2017 14:34:41 -0800 Subject: [PATCH 316/712] [SPARK-20650][CORE] Remove JobProgressListener. The only remaining use of this class was the SparkStatusTracker, which was modified to use the new status store. The test code to wait for executors was moved to TestUtils and now uses the SparkStatusTracker API. Indirectly, ConsoleProgressBar also uses this data. Because it has some lower latency requirements, a shortcut to efficiently get the active stages from the active listener was added to the AppStatusStore. Now that all UI code goes through the status store to get its data, the FsHistoryProvider can be cleaned up to only replay event logs when needed - that is, when there is no pre-existing disk store for the application. As part of this change I also modified the streaming UI to read the needed data from the store, which was missed in the previous patch that made JobProgressListener redundant. Author: Marcelo Vanzin Closes #19750 from vanzin/SPARK-20650.
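To make the remaining public monitoring path concrete, here is a small self-contained sketch (the local master, toy job, and app name are assumptions, not part of this patch): callers poll `SparkStatusTracker`, which is now backed by the status store, instead of reaching into the removed `JobProgressListener`.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object StatusTrackerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("status-tracker-demo").setMaster("local[2]"))

    // Kick off a job asynchronously so there is something to monitor.
    sc.parallelize(1 to 100000, 8).map(_ * 2).countAsync()
    Thread.sleep(500) // give the job a moment to show up in the status store

    // Poll job/stage progress through the public API backed by the status store.
    for {
      jobId <- sc.statusTracker.getActiveJobIds()
      job <- sc.statusTracker.getJobInfo(jobId)
      stageId <- job.stageIds()
      stage <- sc.statusTracker.getStageInfo(stageId)
    } {
      println(s"job $jobId / stage $stageId: " +
        s"${stage.numCompletedTasks()} of ${stage.numTasks()} tasks complete")
    }

    sc.stop()
  }
}
```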
--- .../scala/org/apache/spark/SparkContext.scala | 11 +- .../org/apache/spark/SparkStatusTracker.scala | 76 +-- .../scala/org/apache/spark/TestUtils.scala | 26 +- .../deploy/history/FsHistoryProvider.scala | 65 +- .../spark/status/AppStatusListener.scala | 51 +- .../apache/spark/status/AppStatusStore.scala | 17 +- .../org/apache/spark/status/LiveEntity.scala | 8 +- .../spark/status/api/v1/StagesResource.scala | 1 - .../apache/spark/ui/ConsoleProgressBar.scala | 18 +- .../spark/ui/jobs/JobProgressListener.scala | 612 ------------------ .../org/apache/spark/ui/jobs/StagePage.scala | 1 - .../org/apache/spark/ui/jobs/UIData.scala | 311 --------- .../org/apache/spark/DistributedSuite.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../org/apache/spark/StatusTrackerSuite.scala | 6 +- .../spark/broadcast/BroadcastSuite.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 4 +- .../SparkListenerWithClusterSuite.scala | 4 +- .../ui/jobs/JobProgressListenerSuite.scala | 442 ------------- project/MimaExcludes.scala | 2 + .../apache/spark/streaming/ui/BatchPage.scala | 75 ++- 21 files changed, 208 insertions(+), 1528 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala delete mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala delete mode 100644 core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 23fd54f59268a..984dd0a6629a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -58,7 +58,6 @@ import org.apache.spark.status.{AppStatusPlugin, AppStatusStore} import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} -import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ /** @@ -195,7 +194,6 @@ class SparkContext(config: SparkConf) extends Logging { private var _eventLogCodec: Option[String] = None private var _listenerBus: LiveListenerBus = _ private var _env: SparkEnv = _ - private var _jobProgressListener: JobProgressListener = _ private var _statusTracker: SparkStatusTracker = _ private var _progressBar: Option[ConsoleProgressBar] = None private var _ui: Option[SparkUI] = None @@ -270,8 +268,6 @@ class SparkContext(config: SparkConf) extends Logging { val map: ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]() map.asScala } - private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener - def statusTracker: SparkStatusTracker = _statusTracker private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar @@ -421,11 +417,6 @@ class SparkContext(config: SparkConf) extends Logging { _listenerBus = new LiveListenerBus(_conf) - // "_jobProgressListener" should be set up before creating SparkEnv because when creating - // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. - _jobProgressListener = new JobProgressListener(_conf) - listenerBus.addToStatusQueue(jobProgressListener) - // Initialize the app status store and listener before SparkEnv is created so that it gets // all events. 
_statusStore = AppStatusStore.createLiveStore(conf, l => listenerBus.addToStatusQueue(l)) @@ -440,7 +431,7 @@ class SparkContext(config: SparkConf) extends Logging { _conf.set("spark.repl.class.uri", replUri) } - _statusTracker = new SparkStatusTracker(this) + _statusTracker = new SparkStatusTracker(this, _statusStore) _progressBar = if (_conf.get(UI_SHOW_CONSOLE_PROGRESS) && !log.isInfoEnabled) { diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 22a553e68439a..70865cb58c571 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -17,7 +17,10 @@ package org.apache.spark -import org.apache.spark.scheduler.TaskSchedulerImpl +import java.util.Arrays + +import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.StageStatus /** * Low-level status reporting APIs for monitoring job and stage progress. @@ -33,9 +36,7 @@ import org.apache.spark.scheduler.TaskSchedulerImpl * * NOTE: this class's constructor should be considered private and may be subject to change. */ -class SparkStatusTracker private[spark] (sc: SparkContext) { - - private val jobProgressListener = sc.jobProgressListener +class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore) { /** * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then @@ -46,9 +47,8 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * its result. */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { - jobProgressListener.synchronized { - jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray - } + val expected = Option(jobGroup) + store.jobsList(null).filter(_.jobGroup == expected).map(_.jobId).toArray } /** @@ -57,9 +57,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * This method does not guarantee the order of the elements in its result. */ def getActiveStageIds(): Array[Int] = { - jobProgressListener.synchronized { - jobProgressListener.activeStages.values.map(_.stageId).toArray - } + store.stageList(Arrays.asList(StageStatus.ACTIVE)).map(_.stageId).toArray } /** @@ -68,19 +66,15 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * This method does not guarantee the order of the elements in its result. */ def getActiveJobIds(): Array[Int] = { - jobProgressListener.synchronized { - jobProgressListener.activeJobs.values.map(_.jobId).toArray - } + store.jobsList(Arrays.asList(JobExecutionStatus.RUNNING)).map(_.jobId).toArray } /** * Returns job information, or `None` if the job info could not be found or was garbage collected. */ def getJobInfo(jobId: Int): Option[SparkJobInfo] = { - jobProgressListener.synchronized { - jobProgressListener.jobIdToData.get(jobId).map { data => - new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) - } + store.asOption(store.job(jobId)).map { job => + new SparkJobInfoImpl(jobId, job.stageIds.toArray, job.status) } } @@ -89,21 +83,16 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * garbage collected. 
*/ def getStageInfo(stageId: Int): Option[SparkStageInfo] = { - jobProgressListener.synchronized { - for ( - info <- jobProgressListener.stageIdToInfo.get(stageId); - data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) - ) yield { - new SparkStageInfoImpl( - stageId, - info.attemptId, - info.submissionTime.getOrElse(0), - info.name, - info.numTasks, - data.numActiveTasks, - data.numCompleteTasks, - data.numFailedTasks) - } + store.asOption(store.lastStageAttempt(stageId)).map { stage => + new SparkStageInfoImpl( + stageId, + stage.attemptId, + stage.submissionTime.map(_.getTime()).getOrElse(0L), + stage.name, + stage.numTasks, + stage.numActiveTasks, + stage.numCompleteTasks, + stage.numFailedTasks) } } @@ -111,17 +100,20 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. */ def getExecutorInfos: Array[SparkExecutorInfo] = { - val executorIdToRunningTasks: Map[String, Int] = - sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors + store.executorList(true).map { exec => + val (host, port) = exec.hostPort.split(":", 2) match { + case Array(h, p) => (h, p.toInt) + case Array(h) => (h, -1) + } + val cachedMem = exec.memoryMetrics.map { mem => + mem.usedOnHeapStorageMemory + mem.usedOffHeapStorageMemory + }.getOrElse(0L) - sc.getExecutorStorageStatus.map { status => - val bmId = status.blockManagerId new SparkExecutorInfoImpl( - bmId.host, - bmId.port, - status.cacheSize, - executorIdToRunningTasks.getOrElse(bmId.executorId, 0) - ) - } + host, + port, + cachedMem, + exec.activeTasks) + }.toArray } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index a80016dd22fc5..93e7ee3d2a404 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate import java.util.Arrays -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -232,6 +232,30 @@ private[spark] object TestUtils { } } + /** + * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting + * time elapsed before `numExecutors` executors up. Exposed for testing. + * + * @param numExecutors the number of executors to wait at least + * @param timeout time to wait in milliseconds + */ + private[spark] def waitUntilExecutorsUp( + sc: SparkContext, + numExecutors: Int, + timeout: Long): Unit = { + val finishTime = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeout) + while (System.nanoTime() < finishTime) { + if (sc.statusTracker.getExecutorInfos.length > numExecutors) { + return + } + // Sleep rather than using wait/notify, because this is used only for testing and wait/notify + // add overhead in the general case. 
+ Thread.sleep(10) + } + throw new TimeoutException( + s"Can't find $numExecutors executors before $timeout milliseconds elapsed") + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 69ccde3a8149d..6a83c106f6d84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -299,8 +299,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) attempt.adminAclsGroups.getOrElse("")) secManager.setViewAclsGroups(attempt.viewAclsGroups.getOrElse("")) - val replayBus = new ReplayListenerBus() - val uiStorePath = storePath.map { path => getStorePath(path, appId, attemptId) } val (kvstore, needReplay) = uiStorePath match { @@ -320,48 +318,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) (new InMemoryStore(), true) } - val listener = if (needReplay) { - val _listener = new AppStatusListener(kvstore, conf, false, + if (needReplay) { + val replayBus = new ReplayListenerBus() + val listener = new AppStatusListener(kvstore, conf, false, lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) - replayBus.addListener(_listener) + replayBus.addListener(listener) AppStatusPlugin.loadPlugins().foreach { plugin => plugin.setupListeners(conf, kvstore, l => replayBus.addListener(l), false) } - Some(_listener) - } else { - None + try { + val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) + replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) + listener.flush() + } catch { + case e: Exception => + try { + kvstore.close() + } catch { + case _e: Exception => logInfo("Error closing store.", _e) + } + uiStorePath.foreach(Utils.deleteRecursively) + if (e.isInstanceOf[FileNotFoundException]) { + return None + } else { + throw e + } + } } - val loadedUI = { - val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, secManager, app.info.name, - HistoryServer.getAttemptURI(appId, attempt.info.attemptId), - attempt.info.startTime.getTime(), - attempt.info.appSparkVersion) - LoadedAppUI(ui) + val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, secManager, app.info.name, + HistoryServer.getAttemptURI(appId, attempt.info.attemptId), + attempt.info.startTime.getTime(), + attempt.info.appSparkVersion) + AppStatusPlugin.loadPlugins().foreach { plugin => + plugin.setupUI(ui) } - try { - AppStatusPlugin.loadPlugins().foreach { plugin => - plugin.setupUI(loadedUI.ui) - } - - val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) - replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) - listener.foreach(_.flush()) - } catch { - case e: Exception => - try { - kvstore.close() - } catch { - case _e: Exception => logInfo("Error closing store.", _e) - } - uiStorePath.foreach(Utils.deleteRecursively) - if (e.isInstanceOf[FileNotFoundException]) { - return None - } else { - throw e - } - } + val loadedUI = LoadedAppUI(ui) synchronized { activeUIs((appId, attemptId)) = loadedUI diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index f2d8e0a5480ba..9c23d9d8c923a 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -18,7 +18,10 @@ package org.apache.spark.status import 
java.util.Date +import java.util.concurrent.ConcurrentHashMap +import java.util.function.Function +import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import org.apache.spark._ @@ -59,7 +62,7 @@ private[spark] class AppStatusListener( // Keep track of live entities, so that task metrics can be efficiently updated (without // causing too many writes to the underlying store, and other expensive operations). - private val liveStages = new HashMap[(Int, Int), LiveStage]() + private val liveStages = new ConcurrentHashMap[(Int, Int), LiveStage]() private val liveJobs = new HashMap[Int, LiveJob]() private val liveExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() @@ -268,13 +271,15 @@ private[spark] class AppStatusListener( val now = System.nanoTime() // Check if there are any pending stages that match this job; mark those as skipped. - job.stageIds.foreach { sid => - val pending = liveStages.filter { case ((id, _), _) => id == sid } - pending.foreach { case (key, stage) => + val it = liveStages.entrySet.iterator() + while (it.hasNext()) { + val e = it.next() + if (job.stageIds.contains(e.getKey()._1)) { + val stage = e.getValue() stage.status = v1.StageStatus.SKIPPED job.skippedStages += stage.info.stageId job.skippedTasks += stage.info.numTasks - liveStages.remove(key) + it.remove() update(stage, now) } } @@ -336,7 +341,7 @@ private[spark] class AppStatusListener( liveTasks.put(event.taskInfo.taskId, task) liveUpdate(task, now) - liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) maybeUpdate(stage, now) @@ -403,7 +408,7 @@ private[spark] class AppStatusListener( (0, 1, 0) } - liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { stage.metrics.update(metricsDelta) } @@ -466,12 +471,19 @@ private[spark] class AppStatusListener( } } - maybeUpdate(exec, now) + // Force an update on live applications when the number of active tasks reaches 0. This is + // checked in some tests (e.g. SQLTestUtilsBase) so it needs to be reliably up to date. + if (exec.activeTasks == 0) { + liveUpdate(exec, now) + } else { + maybeUpdate(exec, now) + } } } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { - liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + val maybeStage = Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId))) + maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -540,7 +552,7 @@ private[spark] class AppStatusListener( val delta = task.updateMetrics(metrics) maybeUpdate(task, now) - liveStages.get((sid, sAttempt)).foreach { stage => + Option(liveStages.get((sid, sAttempt))).foreach { stage => stage.metrics.update(delta) maybeUpdate(stage, now) @@ -563,7 +575,7 @@ private[spark] class AppStatusListener( /** Flush all live entities' data to the underlying store. 
*/ def flush(): Unit = { val now = System.nanoTime() - liveStages.values.foreach { stage => + liveStages.values.asScala.foreach { stage => update(stage, now) stage.executorSummaries.values.foreach(update(_, now)) } @@ -574,6 +586,18 @@ private[spark] class AppStatusListener( pools.values.foreach(update(_, now)) } + /** + * Shortcut to get active stages quickly in a live application, for use by the console + * progress bar. + */ + def activeStages(): Seq[v1.StageData] = { + liveStages.values.asScala + .filter(_.info.submissionTime.isDefined) + .map(_.toApi()) + .toList + .sortBy(_.stageId) + } + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { val now = System.nanoTime() val executorId = event.blockUpdatedInfo.blockManagerId.executorId @@ -708,7 +732,10 @@ private[spark] class AppStatusListener( } private def getOrCreateStage(info: StageInfo): LiveStage = { - val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) + val stage = liveStages.computeIfAbsent((info.stageId, info.attemptId), + new Function[(Int, Int), LiveStage]() { + override def apply(key: (Int, Int)): LiveStage = new LiveStage() + }) stage.info = info stage } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index d0615e5dd0223..22d768b3cb990 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -32,7 +32,9 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** * A wrapper around a KVStore that provides methods for accessing the API data stored within. */ -private[spark] class AppStatusStore(val store: KVStore) { +private[spark] class AppStatusStore( + val store: KVStore, + listener: Option[AppStatusListener] = None) { def applicationInfo(): v1.ApplicationInfo = { store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info @@ -70,6 +72,14 @@ private[spark] class AppStatusStore(val store: KVStore) { store.read(classOf[ExecutorSummaryWrapper], executorId).info } + /** + * This is used by ConsoleProgressBar to quickly fetch active stages for drawing the progress + * bar. It will only return anything useful when called from a live application. 
+ */ + def activeStages(): Seq[v1.StageData] = { + listener.map(_.activeStages()).getOrElse(Nil) + } + def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { val it = store.view(classOf[StageDataWrapper]).reverse().asScala.map(_.info) if (statuses != null && !statuses.isEmpty()) { @@ -338,11 +348,12 @@ private[spark] object AppStatusStore { */ def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { val store = new InMemoryStore() - addListenerFn(new AppStatusListener(store, conf, true)) + val listener = new AppStatusListener(store, conf, true) + addListenerFn(listener) AppStatusPlugin.loadPlugins().foreach { p => p.setupListeners(conf, store, addListenerFn, true) } - new AppStatusStore(store) + new AppStatusStore(store, listener = Some(listener)) } } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index ef2936c9b69a4..983c58a607aa8 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -408,8 +408,8 @@ private class LiveStage extends LiveEntity { new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) } - override protected def doUpdate(): Any = { - val update = new v1.StageData( + def toApi(): v1.StageData = { + new v1.StageData( status, info.stageId, info.attemptId, @@ -449,8 +449,10 @@ private class LiveStage extends LiveEntity { None, None, killedSummary) + } - new StageDataWrapper(update, jobIds) + override protected def doUpdate(): Any = { + new StageDataWrapper(toApi(), jobIds) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index bd4dfe3c68885..b3561109bc636 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -25,7 +25,6 @@ import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData.StageUIData @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class StagesResource extends BaseAppResource { diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 3ae80ecfd22e6..3c4ee4eb6bbb9 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -21,10 +21,11 @@ import java.util.{Timer, TimerTask} import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.status.api.v1.StageData /** * ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the - * status of active stages from `sc.statusTracker` periodically, the progress bar will be showed + * status of active stages from the app state store periodically, the progress bar will be showed * up after the stage has ran at least 500ms. If multiple stages run in the same time, the status * of them will be combined together, showed in one line. 
*/ @@ -64,9 +65,8 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { if (now - lastFinishTime < firstDelayMSec) { return } - val stageIds = sc.statusTracker.getActiveStageIds() - val stages = stageIds.flatMap(sc.statusTracker.getStageInfo).filter(_.numTasks() > 1) - .filter(now - _.submissionTime() > firstDelayMSec).sortBy(_.stageId()) + val stages = sc.statusStore.activeStages() + .filter { s => now - s.submissionTime.get.getTime() > firstDelayMSec } if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } @@ -77,15 +77,15 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { * after your last output, keeps overwriting itself to hold in one line. The logging will follow * the progress bar, then progress bar will be showed in next line without overwrite logs. */ - private def show(now: Long, stages: Seq[SparkStageInfo]) { + private def show(now: Long, stages: Seq[StageData]) { val width = TerminalWidth / stages.size val bar = stages.map { s => - val total = s.numTasks() - val header = s"[Stage ${s.stageId()}:" - val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" + val total = s.numTasks + val header = s"[Stage ${s.stageId}:" + val tailer = s"(${s.numCompleteTasks} + ${s.numActiveTasks}) / $total]" val w = width - header.length - tailer.length val bar = if (w > 0) { - val percent = w * s.numCompletedTasks() / total + val percent = w * s.numCompleteTasks / total (0 until w).map { i => if (i < percent) "=" else if (i == percent) ">" else " " }.mkString("") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala deleted file mode 100644 index a18e86ec0a73b..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ /dev/null @@ -1,612 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import java.util.concurrent.TimeoutException - -import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap, ListBuffer} - -import org.apache.spark._ -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.ui.SparkUI -import org.apache.spark.ui.jobs.UIData._ - -/** - * :: DeveloperApi :: - * Tracks task-level information to be displayed in the UI. 
- * - * All access to the data structures in this class must be synchronized on the - * class, since the UI thread and the EventBus loop may otherwise be reading and - * updating the internal data structures concurrently. - */ -@DeveloperApi -@deprecated("This class will be removed in a future release.", "2.2.0") -class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { - - // Define a handful of type aliases so that data structures' types can serve as documentation. - // These type aliases are public because they're used in the types of public fields: - - type JobId = Int - type JobGroupId = String - type StageId = Int - type StageAttemptId = Int - type PoolName = String - type ExecutorId = String - - // Application: - @volatile var startTime = -1L - @volatile var endTime = -1L - - // Jobs: - val activeJobs = new HashMap[JobId, JobUIData] - val completedJobs = ListBuffer[JobUIData]() - val failedJobs = ListBuffer[JobUIData]() - val jobIdToData = new HashMap[JobId, JobUIData] - val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] - - // Stages: - val pendingStages = new HashMap[StageId, StageInfo] - val activeStages = new HashMap[StageId, StageInfo] - val completedStages = ListBuffer[StageInfo]() - val skippedStages = ListBuffer[StageInfo]() - val failedStages = ListBuffer[StageInfo]() - val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] - val stageIdToInfo = new HashMap[StageId, StageInfo] - val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]] - val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]() - // Total of completed and failed stages that have ever been run. These may be greater than - // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than - // JobProgressListener's retention limits. - var numCompletedStages = 0 - var numFailedStages = 0 - var numCompletedJobs = 0 - var numFailedJobs = 0 - - // Misc: - val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() - - def blockManagerIds: Seq[BlockManagerId] = executorIdToBlockManagerId.values.toSeq - - var schedulingMode: Option[SchedulingMode] = None - - // To limit the total memory usage of JobProgressListener, we only track information for a fixed - // number of non-active jobs and stages (there is no limit for active jobs and stages): - - val retainedStages = conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) - val retainedJobs = conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS) - val retainedTasks = conf.get(UI_RETAINED_TASKS) - - // We can test for memory leaks by ensuring that collections that track non-active jobs and - // stages do not grow without bound and that collections for active jobs/stages eventually become - // empty once Spark is idle. Let's partition our collections into ones that should be empty - // once Spark is idle and ones that should have a hard- or soft-limited sizes. - // These methods are used by unit tests, but they're defined here so that people don't forget to - // update the tests when adding new collections. 
Some collections have multiple levels of - // nesting, etc, so this lets us customize our notion of "size" for each structure: - - // These collections should all be empty once Spark is idle (no active stages / jobs): - private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = { - Map( - "activeStages" -> activeStages.size, - "activeJobs" -> activeJobs.size, - "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum, - "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum - ) - } - - // These collections should stop growing once we have run at least `spark.ui.retainedStages` - // stages and `spark.ui.retainedJobs` jobs: - private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = { - Map( - "completedJobs" -> completedJobs.size, - "failedJobs" -> failedJobs.size, - "completedStages" -> completedStages.size, - "skippedStages" -> skippedStages.size, - "failedStages" -> failedStages.size - ) - } - - // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to - // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: - private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { - Map( - "jobIdToData" -> jobIdToData.size, - "stageIdToData" -> stageIdToData.size, - "stageIdToStageInfo" -> stageIdToInfo.size, - "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum, - // Since jobGroupToJobIds is map of sets, check that we don't leak keys with empty values: - "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size - ) - } - - /** If stages is too large, remove and garbage collect old stages */ - private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { - if (stages.size > retainedStages) { - val toRemove = calculateNumberToRemove(stages.size, retainedStages) - stages.take(toRemove).foreach { s => - stageIdToData.remove((s.stageId, s.attemptId)) - stageIdToInfo.remove(s.stageId) - } - stages.trimStart(toRemove) - } - } - - /** If jobs is too large, remove and garbage collect old jobs */ - private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized { - if (jobs.size > retainedJobs) { - val toRemove = calculateNumberToRemove(jobs.size, retainedJobs) - jobs.take(toRemove).foreach { job => - // Remove the job's UI data, if it exists - jobIdToData.remove(job.jobId).foreach { removedJob => - // A null jobGroupId is used for jobs that are run without a job group - val jobGroupId = removedJob.jobGroup.orNull - // Remove the job group -> job mapping entry, if it exists - jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => - jobsInGroup.remove(job.jobId) - // If this was the last job in this job group, remove the map entry for the job group - if (jobsInGroup.isEmpty) { - jobGroupToJobIds.remove(jobGroupId) - } - } - } - } - jobs.trimStart(toRemove) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { - val jobGroup = for ( - props <- Option(jobStart.properties); - group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) - ) yield group - val jobData: JobUIData = - new JobUIData( - jobId = jobStart.jobId, - submissionTime = Option(jobStart.time).filter(_ >= 0), - stageIds = jobStart.stageIds, - jobGroup = jobGroup, - status = JobExecutionStatus.RUNNING) - // A null jobGroupId is used for jobs that are run without a job group - jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) - jobStart.stageInfos.foreach(x => 
pendingStages(x.stageId) = x) - // Compute (a potential underestimate of) the number of tasks that will be run by this job. - // This may be an underestimate because the job start event references all of the result - // stages' transitive stage dependencies, but some of these stages might be skipped if their - // output is available from earlier runs. - // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. - jobData.numTasks = { - val allStages = jobStart.stageInfos - val missingStages = allStages.filter(_.completionTime.isEmpty) - missingStages.map(_.numTasks).sum - } - jobIdToData(jobStart.jobId) = jobData - activeJobs(jobStart.jobId) = jobData - for (stageId <- jobStart.stageIds) { - stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId) - } - // If there's no information for a stage, store the StageInfo received from the scheduler - // so that we can display stage descriptions for pending stages: - for (stageInfo <- jobStart.stageInfos) { - stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo) - stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData) - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { - val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { - logWarning(s"Job completed for unknown job ${jobEnd.jobId}") - new JobUIData(jobId = jobEnd.jobId) - } - jobData.completionTime = Option(jobEnd.time).filter(_ >= 0) - - jobData.stageIds.foreach(pendingStages.remove) - jobEnd.jobResult match { - case JobSucceeded => - completedJobs += jobData - trimJobsIfNecessary(completedJobs) - jobData.status = JobExecutionStatus.SUCCEEDED - numCompletedJobs += 1 - case JobFailed(_) => - failedJobs += jobData - trimJobsIfNecessary(failedJobs) - jobData.status = JobExecutionStatus.FAILED - numFailedJobs += 1 - } - for (stageId <- jobData.stageIds) { - stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => - jobsUsingStage.remove(jobEnd.jobId) - if (jobsUsingStage.isEmpty) { - stageIdToActiveJobIds.remove(stageId) - } - stageIdToInfo.get(stageId).foreach { stageInfo => - if (stageInfo.submissionTime.isEmpty) { - // if this stage is pending, it won't complete, so mark it as "skipped": - skippedStages += stageInfo - trimStagesIfNecessary(skippedStages) - jobData.numSkippedStages += 1 - jobData.numSkippedTasks += stageInfo.numTasks - } - } - } - } - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { - val stage = stageCompleted.stageInfo - stageIdToInfo(stage.stageId) = stage - val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { - logWarning("Stage completed for unknown stage " + stage.stageId) - new StageUIData - }) - - for ((id, info) <- stageCompleted.stageInfo.accumulables) { - stageData.accumulables(id) = info - } - - poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap => - hashMap.remove(stage.stageId) - } - activeStages.remove(stage.stageId) - if (stage.failureReason.isEmpty) { - completedStages += stage - numCompletedStages += 1 - trimStagesIfNecessary(completedStages) - } else { - failedStages += stage - numFailedStages += 1 - trimStagesIfNecessary(failedStages) - } - - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveStages -= 1 - if (stage.failureReason.isEmpty) { - if (stage.submissionTime.isDefined) { - 
jobData.completedStageIndices.add(stage.stageId) - } - } else { - jobData.numFailedStages += 1 - } - } - } - - /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - val stage = stageSubmitted.stageInfo - activeStages(stage.stageId) = stage - pendingStages.remove(stage.stageId) - val poolName = Option(stageSubmitted.properties).map { - p => p.getProperty("spark.scheduler.pool", SparkUI.DEFAULT_POOL_NAME) - }.getOrElse(SparkUI.DEFAULT_POOL_NAME) - - stageIdToInfo(stage.stageId) = stage - val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData) - stageData.schedulingPool = poolName - - stageData.description = Option(stageSubmitted.properties).flatMap { - p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) - } - - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) - stages(stage.stageId) = stage - - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveStages += 1 - - // If a stage retries again, it should be removed from completedStageIndices set - jobData.completedStageIndices.remove(stage.stageId) - } - } - - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { - val taskInfo = taskStart.taskInfo - if (taskInfo != null) { - val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { - logWarning("Task start for unknown stage " + taskStart.stageId) - new StageUIData - }) - stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo)) - } - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveTasks += 1 - } - } - - override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { - // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in - // stageToTaskInfos already has the updated status. - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - val info = taskEnd.taskInfo - // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task - // completion event is for. Let's just drop it here. This means we might have some speculation - // tasks on the web ui that's never marked as complete. 
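One small detail of `onStageSubmitted` above that is easy to miss: both a missing `Properties` object and an unset `spark.scheduler.pool` entry fall back to the default pool. A standalone illustration (the object and helper names are illustrative, and the default pool name is assumed here to be `"default"`, matching the FIFO comment above):

```scala
import java.util.Properties

object PoolNameLookup extends App {
  val defaultPool = "default"  // stand-in for SparkUI.DEFAULT_POOL_NAME

  // Null properties and a missing key both resolve to the default pool.
  def poolNameOf(props: Properties): String =
    Option(props)
      .map(_.getProperty("spark.scheduler.pool", defaultPool))
      .getOrElse(defaultPool)

  assert(poolNameOf(null) == defaultPool)             // no properties at all
  assert(poolNameOf(new Properties()) == defaultPool) // key not set

  val props = new Properties()
  props.setProperty("spark.scheduler.pool", "etl")
  assert(poolNameOf(props) == "etl")
}
```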
- if (info != null && taskEnd.stageAttemptId != -1) { - val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { - logWarning("Task end for unknown stage " + taskEnd.stageId) - new StageUIData - }) - - for (accumulableInfo <- info.accumulables) { - stageData.accumulables(accumulableInfo.id) = accumulableInfo - } - - val execSummaryMap = stageData.executorSummary - val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) - - taskEnd.reason match { - case Success => - execSummary.succeededTasks += 1 - case kill: TaskKilled => - execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( - kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) - case commitDenied: TaskCommitDenied => - execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( - commitDenied.toErrorString, execSummary.reasonToNumKilled.getOrElse( - commitDenied.toErrorString, 0) + 1) - case _ => - execSummary.failedTasks += 1 - } - execSummary.taskTime += info.duration - stageData.numActiveTasks -= 1 - - val errorMessage: Option[String] = - taskEnd.reason match { - case org.apache.spark.Success => - stageData.completedIndices.add(info.index) - stageData.numCompleteTasks += 1 - None - case kill: TaskKilled => - stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( - kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) - Some(kill.toErrorString) - case commitDenied: TaskCommitDenied => - stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( - commitDenied.toErrorString, stageData.reasonToNumKilled.getOrElse( - commitDenied.toErrorString, 0) + 1) - Some(commitDenied.toErrorString) - case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates - stageData.numFailedTasks += 1 - Some(e.toErrorString) - case e: TaskFailedReason => // All other failure cases - stageData.numFailedTasks += 1 - Some(e.toErrorString) - } - - val taskMetrics = Option(taskEnd.taskMetrics) - taskMetrics.foreach { m => - val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics) - updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) - } - - val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info)) - taskData.updateTaskInfo(info) - taskData.updateTaskMetrics(taskMetrics) - taskData.errorMessage = errorMessage - - // If Tasks is too large, remove and garbage collect old tasks - if (stageData.taskData.size > retainedTasks) { - stageData.taskData = stageData.taskData.drop( - calculateNumberToRemove(stageData.taskData.size, retainedTasks)) - } - - for ( - activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); - jobId <- activeJobsDependentOnStage; - jobData <- jobIdToData.get(jobId) - ) { - jobData.numActiveTasks -= 1 - taskEnd.reason match { - case Success => - jobData.completedIndices.add((taskEnd.stageId, info.index)) - jobData.numCompletedTasks += 1 - case kill: TaskKilled => - jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( - kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) - case commitDenied: TaskCommitDenied => - jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( - commitDenied.toErrorString, jobData.reasonToNumKilled.getOrElse( - commitDenied.toErrorString, 0) + 1) - case _ => - jobData.numFailedTasks += 1 - } - } - } - } - - /** - * Remove at least (maxRetained / 10) items to reduce friction. 
- */ - private def calculateNumberToRemove(dataSize: Int, retainedSize: Int): Int = { - math.max(retainedSize / 10, dataSize - retainedSize) - } - - /** - * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage - * aggregate metrics by calculating deltas between the currently recorded metrics and the new - * metrics. - */ - def updateAggregateMetrics( - stageData: StageUIData, - execId: String, - taskMetrics: TaskMetrics, - oldMetrics: Option[TaskMetricsUIData]) { - val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) - - val shuffleWriteDelta = - taskMetrics.shuffleWriteMetrics.bytesWritten - - oldMetrics.map(_.shuffleWriteMetrics.bytesWritten).getOrElse(0L) - stageData.shuffleWriteBytes += shuffleWriteDelta - execSummary.shuffleWrite += shuffleWriteDelta - - val shuffleWriteRecordsDelta = - taskMetrics.shuffleWriteMetrics.recordsWritten - - oldMetrics.map(_.shuffleWriteMetrics.recordsWritten).getOrElse(0L) - stageData.shuffleWriteRecords += shuffleWriteRecordsDelta - execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta - - val shuffleReadDelta = - taskMetrics.shuffleReadMetrics.totalBytesRead - - oldMetrics.map(_.shuffleReadMetrics.totalBytesRead).getOrElse(0L) - stageData.shuffleReadTotalBytes += shuffleReadDelta - execSummary.shuffleRead += shuffleReadDelta - - val shuffleReadRecordsDelta = - taskMetrics.shuffleReadMetrics.recordsRead - - oldMetrics.map(_.shuffleReadMetrics.recordsRead).getOrElse(0L) - stageData.shuffleReadRecords += shuffleReadRecordsDelta - execSummary.shuffleReadRecords += shuffleReadRecordsDelta - - val inputBytesDelta = - taskMetrics.inputMetrics.bytesRead - - oldMetrics.map(_.inputMetrics.bytesRead).getOrElse(0L) - stageData.inputBytes += inputBytesDelta - execSummary.inputBytes += inputBytesDelta - - val inputRecordsDelta = - taskMetrics.inputMetrics.recordsRead - - oldMetrics.map(_.inputMetrics.recordsRead).getOrElse(0L) - stageData.inputRecords += inputRecordsDelta - execSummary.inputRecords += inputRecordsDelta - - val outputBytesDelta = - taskMetrics.outputMetrics.bytesWritten - - oldMetrics.map(_.outputMetrics.bytesWritten).getOrElse(0L) - stageData.outputBytes += outputBytesDelta - execSummary.outputBytes += outputBytesDelta - - val outputRecordsDelta = - taskMetrics.outputMetrics.recordsWritten - - oldMetrics.map(_.outputMetrics.recordsWritten).getOrElse(0L) - stageData.outputRecords += outputRecordsDelta - execSummary.outputRecords += outputRecordsDelta - - val diskSpillDelta = - taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) - stageData.diskBytesSpilled += diskSpillDelta - execSummary.diskBytesSpilled += diskSpillDelta - - val memorySpillDelta = - taskMetrics.memoryBytesSpilled - oldMetrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageData.memoryBytesSpilled += memorySpillDelta - execSummary.memoryBytesSpilled += memorySpillDelta - - val timeDelta = - taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) - stageData.executorRunTime += timeDelta - - val cpuTimeDelta = - taskMetrics.executorCpuTime - oldMetrics.map(_.executorCpuTime).getOrElse(0L) - stageData.executorCpuTime += cpuTimeDelta - } - - override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { - for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) { - val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), { - logWarning("Metrics update for task in unknown stage " + sid) - new 
StageUIData - }) - val taskData = stageData.taskData.get(taskId) - val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) - taskData.foreach { t => - if (!t.taskInfo.finished) { - updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics) - // Overwrite task metrics - t.updateTaskMetrics(Some(metrics)) - } - } - } - } - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - schedulingMode = environmentUpdate - .environmentDetails("Spark Properties").toMap - .get("spark.scheduler.mode") - .map(SchedulingMode.withName) - } - } - - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { - synchronized { - val blockManagerId = blockManagerAdded.blockManagerId - val executorId = blockManagerId.executorId - executorIdToBlockManagerId(executorId) = blockManagerId - } - } - - override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { - synchronized { - val executorId = blockManagerRemoved.blockManagerId.executorId - executorIdToBlockManagerId.remove(executorId) - } - } - - override def onApplicationStart(appStarted: SparkListenerApplicationStart) { - startTime = appStarted.time - } - - override def onApplicationEnd(appEnded: SparkListenerApplicationEnd) { - endTime = appEnded.time - } - - /** - * For testing only. Wait until at least `numExecutors` executors are up, or throw - * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. - * Exposed for testing. - * - * @param numExecutors the number of executors to wait at least - * @param timeout time to wait in milliseconds - */ - private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = { - val finishTime = System.currentTimeMillis() + timeout - while (System.currentTimeMillis() < finishTime) { - val numBlockManagers = synchronized { - blockManagerIds.size - } - if (numBlockManagers >= numExecutors + 1) { - // Need to count the block manager in driver - return - } - // Sleep rather than using wait/notify, because this is used only for testing and wait/notify - // add overhead in the general case. - Thread.sleep(10) - } - throw new TimeoutException( - s"Can't find $numExecutors executors before $timeout milliseconds elapsed") - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 5f93f2ffb412f..11a6a34344976 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -32,7 +32,6 @@ import org.apache.spark.scheduler.TaskLocality import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ -import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala deleted file mode 100644 index 5acec0d0f54c9..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import scala.collection.mutable -import scala.collection.mutable.{HashMap, LinkedHashMap} - -import com.google.common.collect.Interners - -import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor._ -import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} -import org.apache.spark.util.AccumulatorContext -import org.apache.spark.util.collection.OpenHashSet - -private[spark] object UIData { - - class ExecutorSummary { - var taskTime : Long = 0 - var failedTasks : Int = 0 - var succeededTasks : Int = 0 - var reasonToNumKilled : Map[String, Int] = Map.empty - var inputBytes : Long = 0 - var inputRecords : Long = 0 - var outputBytes : Long = 0 - var outputRecords : Long = 0 - var shuffleRead : Long = 0 - var shuffleReadRecords : Long = 0 - var shuffleWrite : Long = 0 - var shuffleWriteRecords : Long = 0 - var memoryBytesSpilled : Long = 0 - var diskBytesSpilled : Long = 0 - var isBlacklisted : Int = 0 - } - - class JobUIData( - var jobId: Int = -1, - var submissionTime: Option[Long] = None, - var completionTime: Option[Long] = None, - var stageIds: Seq[Int] = Seq.empty, - var jobGroup: Option[String] = None, - var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, - /* Tasks */ - // `numTasks` is a potential underestimate of the true number of tasks that this job will run. - // This may be an underestimate because the job start event references all of the result - // stages' transitive stage dependencies, but some of these stages might be skipped if their - // output is available from earlier runs. - // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. 
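The estimate this comment refers to is computed in `JobProgressListener.onJobStart` shown earlier: only stages that do not yet have a completion time contribute their tasks. A minimal sketch of that filtered sum, using a hypothetical `StageLite` stand-in for `StageInfo` (names and fields here are illustrative, not Spark's):

```scala
object NumTasksEstimate extends App {
  // Hypothetical stand-in for StageInfo with just the two fields the estimate uses.
  case class StageLite(numTasks: Int, completionTime: Option[Long])

  // Stages whose output already exists (completionTime set) are expected to be
  // skipped, so only the still-missing stages contribute tasks.
  def estimatedTasks(stages: Seq[StageLite]): Int =
    stages.filter(_.completionTime.isEmpty).map(_.numTasks).sum

  // A three-stage job where one stage was materialised by an earlier job:
  val stages = Seq(StageLite(4, None), StageLite(8, None), StageLite(6, Some(100L)))
  assert(estimatedTasks(stages) == 12)
}
```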
- var numTasks: Int = 0, - var numActiveTasks: Int = 0, - var numCompletedTasks: Int = 0, - var completedIndices: OpenHashSet[(Int, Int)] = new OpenHashSet[(Int, Int)](), - var numSkippedTasks: Int = 0, - var numFailedTasks: Int = 0, - var reasonToNumKilled: Map[String, Int] = Map.empty, - /* Stages */ - var numActiveStages: Int = 0, - // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: - var completedStageIndices: mutable.HashSet[Int] = new mutable.HashSet[Int](), - var numSkippedStages: Int = 0, - var numFailedStages: Int = 0 - ) - - class StageUIData { - var numActiveTasks: Int = _ - var numCompleteTasks: Int = _ - var completedIndices = new OpenHashSet[Int]() - var numFailedTasks: Int = _ - var reasonToNumKilled: Map[String, Int] = Map.empty - - var executorRunTime: Long = _ - var executorCpuTime: Long = _ - - var inputBytes: Long = _ - var inputRecords: Long = _ - var outputBytes: Long = _ - var outputRecords: Long = _ - var shuffleReadTotalBytes: Long = _ - var shuffleReadRecords : Long = _ - var shuffleWriteBytes: Long = _ - var shuffleWriteRecords: Long = _ - var memoryBytesSpilled: Long = _ - var diskBytesSpilled: Long = _ - var isBlacklisted: Int = _ - var lastUpdateTime: Option[Long] = None - - var schedulingPool: String = "" - var description: Option[String] = None - - var accumulables = new HashMap[Long, AccumulableInfo] - var taskData = new LinkedHashMap[Long, TaskUIData] - var executorSummary = new HashMap[String, ExecutorSummary] - - def hasInput: Boolean = inputBytes > 0 - def hasOutput: Boolean = outputBytes > 0 - def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 - def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 - def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 || diskBytesSpilled > 0 - } - - /** - * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. - */ - class TaskUIData private(private var _taskInfo: TaskInfo) { - - private[this] var _metrics: Option[TaskMetricsUIData] = Some(TaskMetricsUIData.EMPTY) - - var errorMessage: Option[String] = None - - def taskInfo: TaskInfo = _taskInfo - - def metrics: Option[TaskMetricsUIData] = _metrics - - def updateTaskInfo(taskInfo: TaskInfo): Unit = { - _taskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) - } - - def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = { - _metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics) - } - - def taskDuration(lastUpdateTime: Option[Long] = None): Option[Long] = { - if (taskInfo.status == "RUNNING") { - Some(_taskInfo.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis))) - } else { - _metrics.map(_.executorRunTime) - } - } - } - - object TaskUIData { - - private val stringInterner = Interners.newWeakInterner[String]() - - /** String interning to reduce the memory usage. */ - private def weakIntern(s: String): String = { - stringInterner.intern(s) - } - - def apply(taskInfo: TaskInfo): TaskUIData = { - new TaskUIData(dropInternalAndSQLAccumulables(taskInfo)) - } - - /** - * We don't need to store internal or SQL accumulables as their values will be shown in other - * places, so drop them to reduce the memory usage. 
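The interning helper in the `TaskUIData` companion above is worth pausing on: executor IDs and host names repeat across thousands of `TaskInfo` objects, and a weak interner keeps a single canonical copy of each distinct string while still letting unused entries be garbage collected. Shown in isolation (object name is illustrative; requires Guava on the classpath, as the original file already did):

```scala
import com.google.common.collect.Interners

object StringInterning extends App {
  private val interner = Interners.newWeakInterner[String]()

  def weakIntern(s: String): String = interner.intern(s)

  // Two distinct String objects with the same content collapse to one instance.
  val a = weakIntern(new String("exec-7"))
  val b = weakIntern(new String("exec-7"))
  assert(a eq b)  // same canonical instance, not merely equal
}
```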
- */ - private[spark] def dropInternalAndSQLAccumulables(taskInfo: TaskInfo): TaskInfo = { - val newTaskInfo = new TaskInfo( - taskId = taskInfo.taskId, - index = taskInfo.index, - attemptNumber = taskInfo.attemptNumber, - launchTime = taskInfo.launchTime, - executorId = weakIntern(taskInfo.executorId), - host = weakIntern(taskInfo.host), - taskLocality = taskInfo.taskLocality, - speculative = taskInfo.speculative - ) - newTaskInfo.gettingResultTime = taskInfo.gettingResultTime - newTaskInfo.setAccumulables(taskInfo.accumulables.filter { - accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) - }) - newTaskInfo.finishTime = taskInfo.finishTime - newTaskInfo.failed = taskInfo.failed - newTaskInfo.killed = taskInfo.killed - newTaskInfo - } - } - - case class TaskMetricsUIData( - executorDeserializeTime: Long, - executorDeserializeCpuTime: Long, - executorRunTime: Long, - executorCpuTime: Long, - resultSize: Long, - jvmGCTime: Long, - resultSerializationTime: Long, - memoryBytesSpilled: Long, - diskBytesSpilled: Long, - peakExecutionMemory: Long, - inputMetrics: InputMetricsUIData, - outputMetrics: OutputMetricsUIData, - shuffleReadMetrics: ShuffleReadMetricsUIData, - shuffleWriteMetrics: ShuffleWriteMetricsUIData) - - object TaskMetricsUIData { - def fromTaskMetrics(m: TaskMetrics): TaskMetricsUIData = { - TaskMetricsUIData( - executorDeserializeTime = m.executorDeserializeTime, - executorDeserializeCpuTime = m.executorDeserializeCpuTime, - executorRunTime = m.executorRunTime, - executorCpuTime = m.executorCpuTime, - resultSize = m.resultSize, - jvmGCTime = m.jvmGCTime, - resultSerializationTime = m.resultSerializationTime, - memoryBytesSpilled = m.memoryBytesSpilled, - diskBytesSpilled = m.diskBytesSpilled, - peakExecutionMemory = m.peakExecutionMemory, - inputMetrics = InputMetricsUIData(m.inputMetrics), - outputMetrics = OutputMetricsUIData(m.outputMetrics), - shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), - shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) - } - - val EMPTY: TaskMetricsUIData = fromTaskMetrics(TaskMetrics.empty) - } - - case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) - object InputMetricsUIData { - def apply(metrics: InputMetrics): InputMetricsUIData = { - if (metrics.bytesRead == 0 && metrics.recordsRead == 0) { - EMPTY - } else { - new InputMetricsUIData( - bytesRead = metrics.bytesRead, - recordsRead = metrics.recordsRead) - } - } - private val EMPTY = InputMetricsUIData(0, 0) - } - - case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long) - object OutputMetricsUIData { - def apply(metrics: OutputMetrics): OutputMetricsUIData = { - if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0) { - EMPTY - } else { - new OutputMetricsUIData( - bytesWritten = metrics.bytesWritten, - recordsWritten = metrics.recordsWritten) - } - } - private val EMPTY = OutputMetricsUIData(0, 0) - } - - case class ShuffleReadMetricsUIData( - remoteBlocksFetched: Long, - localBlocksFetched: Long, - remoteBytesRead: Long, - remoteBytesReadToDisk: Long, - localBytesRead: Long, - fetchWaitTime: Long, - recordsRead: Long, - totalBytesRead: Long, - totalBlocksFetched: Long) - - object ShuffleReadMetricsUIData { - def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = { - if ( - metrics.remoteBlocksFetched == 0 && - metrics.localBlocksFetched == 0 && - metrics.remoteBytesRead == 0 && - metrics.localBytesRead == 0 && - metrics.fetchWaitTime == 0 && - metrics.recordsRead 
== 0 && - metrics.totalBytesRead == 0 && - metrics.totalBlocksFetched == 0) { - EMPTY - } else { - new ShuffleReadMetricsUIData( - remoteBlocksFetched = metrics.remoteBlocksFetched, - localBlocksFetched = metrics.localBlocksFetched, - remoteBytesRead = metrics.remoteBytesRead, - remoteBytesReadToDisk = metrics.remoteBytesReadToDisk, - localBytesRead = metrics.localBytesRead, - fetchWaitTime = metrics.fetchWaitTime, - recordsRead = metrics.recordsRead, - totalBytesRead = metrics.totalBytesRead, - totalBlocksFetched = metrics.totalBlocksFetched - ) - } - } - private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0, 0) - } - - case class ShuffleWriteMetricsUIData( - bytesWritten: Long, - recordsWritten: Long, - writeTime: Long) - - object ShuffleWriteMetricsUIData { - def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = { - if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0 && metrics.writeTime == 0) { - EMPTY - } else { - new ShuffleWriteMetricsUIData( - bytesWritten = metrics.bytesWritten, - recordsWritten = metrics.recordsWritten, - writeTime = metrics.writeTime - ) - } - } - private val EMPTY = ShuffleWriteMetricsUIData(0, 0, 0) - } - -} diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index ea9f6d2fc20f4..e09d5f59817b9 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -156,7 +156,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) - sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) + TestUtils.waitUntilExecutorsUp(sc, 2, 30000) val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) assert(cachedData.count === 1000) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index fe944031bc948..472952addf353 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -66,7 +66,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. 
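The `TestUtils.waitUntilExecutorsUp` calls that replace `sc.jobProgressListener.waitUntilExecutorsUp` in these tests follow the same poll-until-deadline shape as the deleted method above, which counted registered block managers and added one for the driver's own block manager. A generic sketch of that pattern (the helper name and signature here are illustrative, not Spark's `TestUtils` API):

```scala
import java.util.concurrent.TimeoutException

object PollUntil extends App {
  // Poll a condition until it holds or the deadline passes, sleeping briefly
  // between checks, and fail loudly on timeout.
  def waitUntil(timeoutMs: Long, pollMs: Long = 10L)(condition: => Boolean): Unit = {
    val deadline = System.currentTimeMillis() + timeoutMs
    while (System.currentTimeMillis() < deadline) {
      if (condition) return
      Thread.sleep(pollMs)
    }
    throw new TimeoutException(s"Condition not met within $timeoutMs ms")
  }

  // Trivial usage: the condition becomes true after a few polls.
  var ticks = 0
  waitUntil(timeoutMs = 1000L) { ticks += 1; ticks >= 3 }
}
```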
// Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 5483f2b8434aa..a15ae040d43a9 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -44,13 +44,13 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont stageIds.size should be(2) val firstStageInfo = eventually(timeout(10 seconds)) { - sc.statusTracker.getStageInfo(stageIds(0)).get + sc.statusTracker.getStageInfo(stageIds.min).get } - firstStageInfo.stageId() should be(stageIds(0)) + firstStageInfo.stageId() should be(stageIds.min) firstStageInfo.currentAttemptId() should be(0) firstStageInfo.numTasks() should be(2) eventually(timeout(10 seconds)) { - val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds(0)).get + val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds.min).get updatedFirstStageInfo.numCompletedTasks() should be(2) updatedFirstStageInfo.numActiveTasks() should be(0) updatedFirstStageInfo.numFailedTasks() should be(0) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 46f9ac6b0273a..159629825c677 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -224,7 +224,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") // Wait until all salves are up try { - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) + TestUtils.waitUntilExecutorsUp(_sc, numSlaves, 60000) _sc } catch { case e: Throwable => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d395e09969453..feefb6a4d73f0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2406,13 +2406,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // OutputCommitCoordinator requires the task info itself to not be null. 
private def createFakeTaskInfo(): TaskInfo = { val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false) - info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info.finishTime = 1 info } private def createFakeTaskInfoWithId(taskId: Long): TaskInfo = { val info = new TaskInfo(taskId, 0, 0, 0L, "", "", TaskLocality.ANY, false) - info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info.finishTime = 1 info } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index 9fa8859382911..123f7f49d21b5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.scheduler.cluster.ExecutorInfo /** @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala deleted file mode 100644 index 48be3be81755a..0000000000000 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ /dev/null @@ -1,442 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ui.jobs - -import java.util.Properties - -import org.scalatest.Matchers - -import org.apache.spark._ -import org.apache.spark.{LocalSparkContext, SparkConf, Success} -import org.apache.spark.executor._ -import org.apache.spark.scheduler._ -import org.apache.spark.ui.jobs.UIData.TaskUIData -import org.apache.spark.util.{AccumulatorContext, Utils} - -class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers { - - val jobSubmissionTime = 1421191042750L - val jobCompletionTime = 1421191296660L - - private def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, null, "") - SparkListenerStageSubmitted(stageInfo) - } - - private def createStageEndEvent(stageId: Int, failed: Boolean = false) = { - val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, null, "") - if (failed) { - stageInfo.failureReason = Some("Failed!") - } - SparkListenerStageCompleted(stageInfo) - } - - private def createJobStartEvent( - jobId: Int, - stageIds: Seq[Int], - jobGroup: Option[String] = None): SparkListenerJobStart = { - val stageInfos = stageIds.map { stageId => - new StageInfo(stageId, 0, stageId.toString, 0, null, null, "") - } - val properties: Option[Properties] = jobGroup.map { groupId => - val props = new Properties() - props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) - props - } - SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull) - } - - private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { - val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded - SparkListenerJobEnd(jobId, jobCompletionTime, result) - } - - private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { - val stagesThatWontBeRun = jobId * 200 to jobId * 200 + 10 - val stageIds = jobId * 100 to jobId * 100 + 50 - listener.onJobStart(createJobStartEvent(jobId, stageIds ++ stagesThatWontBeRun)) - for (stageId <- stageIds) { - listener.onStageSubmitted(createStageStartEvent(stageId)) - listener.onStageCompleted(createStageEndEvent(stageId, failed = stageId % 2 == 0)) - } - listener.onJobEnd(createJobEndEvent(jobId, shouldFail)) - } - - private def assertActiveJobsStateIsEmpty(listener: JobProgressListener) { - listener.getSizesOfActiveStateTrackingCollections.foreach { case (fieldName, size) => - assert(size === 0, s"$fieldName was not empty") - } - } - - test("test LRU eviction of stages") { - def runWithListener(listener: JobProgressListener) : Unit = { - for (i <- 1 to 50) { - listener.onStageSubmitted(createStageStartEvent(i)) - listener.onStageCompleted(createStageEndEvent(i)) - } - assertActiveJobsStateIsEmpty(listener) - } - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - var listener = new JobProgressListener(conf) - - // Test with 5 retainedStages - runWithListener(listener) - listener.completedStages.size should be (5) - listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) - - // Test with 0 retainedStages - conf.set("spark.ui.retainedStages", 0.toString) - listener = new JobProgressListener(conf) - runWithListener(listener) - listener.completedStages.size should be (0) - } - - test("test clearing of stageIdToActiveJobs") { - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - val listener = new JobProgressListener(conf) - val jobId = 0 - val stageIds = 1 to 50 - // Start a job with 50 stages - 
listener.onJobStart(createJobStartEvent(jobId, stageIds)) - for (stageId <- stageIds) { - listener.onStageSubmitted(createStageStartEvent(stageId)) - } - listener.stageIdToActiveJobIds.size should be > 0 - - // Complete the stages and job - for (stageId <- stageIds) { - listener.onStageCompleted(createStageEndEvent(stageId, failed = false)) - } - listener.onJobEnd(createJobEndEvent(jobId, false)) - assertActiveJobsStateIsEmpty(listener) - listener.stageIdToActiveJobIds.size should be (0) - } - - test("test clearing of jobGroupToJobIds") { - def runWithListener(listener: JobProgressListener): Unit = { - // Run 50 jobs, each with one stage - for (jobId <- 0 to 50) { - listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString))) - listener.onStageSubmitted(createStageStartEvent(0)) - listener.onStageCompleted(createStageEndEvent(0, failed = false)) - listener.onJobEnd(createJobEndEvent(jobId, false)) - } - assertActiveJobsStateIsEmpty(listener) - } - val conf = new SparkConf() - conf.set("spark.ui.retainedJobs", 5.toString) - - var listener = new JobProgressListener(conf) - runWithListener(listener) - // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs - listener.jobGroupToJobIds.size should be (5) - - // Test with 0 jobs - conf.set("spark.ui.retainedJobs", 0.toString) - listener = new JobProgressListener(conf) - runWithListener(listener) - listener.jobGroupToJobIds.size should be (0) - } - - test("test LRU eviction of jobs") { - val conf = new SparkConf() - conf.set("spark.ui.retainedStages", 5.toString) - conf.set("spark.ui.retainedJobs", 5.toString) - val listener = new JobProgressListener(conf) - - // Run a bunch of jobs to get the listener into a state where we've exceeded both the - // job and stage retention limits: - for (jobId <- 1 to 10) { - runJob(listener, jobId, shouldFail = false) - } - for (jobId <- 200 to 210) { - runJob(listener, jobId, shouldFail = true) - } - assertActiveJobsStateIsEmpty(listener) - // Snapshot the sizes of various soft- and hard-size-limited collections: - val softLimitSizes = listener.getSizesOfSoftSizeLimitedCollections - val hardLimitSizes = listener.getSizesOfHardSizeLimitedCollections - // Run some more jobs: - for (jobId <- 11 to 50) { - runJob(listener, jobId, shouldFail = false) - // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: - listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) - listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) - } - - listener.completedJobs.size should be (5) - listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) - - for (jobId <- 51 to 100) { - runJob(listener, jobId, shouldFail = true) - // We shouldn't exceed the hard / soft limit sizes after the jobs have finished: - listener.getSizesOfSoftSizeLimitedCollections should be (softLimitSizes) - listener.getSizesOfHardSizeLimitedCollections should be (hardLimitSizes) - } - assertActiveJobsStateIsEmpty(listener) - - // Completed and failed jobs each their own size limits, so this should still be the same: - listener.completedJobs.size should be (5) - listener.completedJobs.map(_.jobId).toSet should be (Set(50, 49, 48, 47, 46)) - listener.failedJobs.size should be (5) - listener.failedJobs.map(_.jobId).toSet should be (Set(100, 99, 98, 97, 96)) - } - - test("test executor id to summary") { - val conf = new SparkConf() - val listener = new JobProgressListener(conf) - val taskMetrics = TaskMetrics.empty - val 
shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() - assert(listener.stageIdToData.size === 0) - - // finish this task, should get updated shuffleRead - shuffleReadMetrics.incRemoteBytesRead(1000) - taskMetrics.mergeShuffleReadMetrics() - var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - var task = new ShuffleMapTask(0) - val taskType = Utils.getFormattedClassName(task) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse((0, 0), fail()) - .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 1000) - - // finish a task with unknown executor-id, nothing should happen - taskInfo = - new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) - taskInfo.finishTime = 1 - task = new ShuffleMapTask(0) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.size === 1) - - // finish this task, should get updated duration - taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - task = new ShuffleMapTask(0) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse((0, 0), fail()) - .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 2000) - - // finish this task, should get updated duration - taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - task = new ShuffleMapTask(0) - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToData.getOrElse((0, 0), fail()) - .executorSummary.getOrElse("exe-2", fail()).shuffleRead === 1000) - } - - test("test task success vs failure counting for different task end reasons") { - val conf = new SparkConf() - val listener = new JobProgressListener(conf) - val metrics = TaskMetrics.empty - val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - val task = new ShuffleMapTask(0) - val taskType = Utils.getFormattedClassName(task) - - // Go through all the failure cases to make sure we are counting them as failures. - val taskFailedReasons = Seq( - Resubmitted, - new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), - TaskResultLost, - ExecutorLostFailure("0", true, Some("Induced failure")), - UnknownReason) - var failCount = 0 - for (reason <- taskFailedReasons) { - listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, reason, taskInfo, metrics)) - failCount += 1 - assert(listener.stageIdToData((task.stageId, 0)).numCompleteTasks === 0) - assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) - } - - // Make sure killed tasks are accounted for correctly. - listener.onTaskEnd( - SparkListenerTaskEnd( - task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1)) - - // Make sure we count success as success. 
- listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 1)).numCompleteTasks === 1) - assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) - } - - test("test update metrics") { - val conf = new SparkConf() - val listener = new JobProgressListener(conf) - - val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) - val execId = "exe-1" - - def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = TaskMetrics.registered - val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() - val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics - val inputMetrics = taskMetrics.inputMetrics - val outputMetrics = taskMetrics.outputMetrics - shuffleReadMetrics.incRemoteBytesRead(base + 1) - shuffleReadMetrics.incLocalBytesRead(base + 9) - shuffleReadMetrics.incRemoteBlocksFetched(base + 2) - taskMetrics.mergeShuffleReadMetrics() - shuffleWriteMetrics.incBytesWritten(base + 3) - taskMetrics.setExecutorRunTime(base + 4) - taskMetrics.incDiskBytesSpilled(base + 5) - taskMetrics.incMemoryBytesSpilled(base + 6) - inputMetrics.setBytesRead(base + 7) - outputMetrics.setBytesWritten(base + 8) - taskMetrics - } - - def makeTaskInfo(taskId: Long, finishTime: Int = 0): TaskInfo = { - val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, - false) - taskInfo.finishTime = finishTime - taskInfo - } - - listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1234L))) - listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1235L))) - listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1236L))) - listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L))) - - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( - (1234L, 0, 0, makeTaskMetrics(0).accumulators().map(AccumulatorSuite.makeInfo)), - (1235L, 0, 0, makeTaskMetrics(100).accumulators().map(AccumulatorSuite.makeInfo)), - (1236L, 1, 0, makeTaskMetrics(200).accumulators().map(AccumulatorSuite.makeInfo))))) - - var stage0Data = listener.stageIdToData.get((0, 0)).get - var stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadTotalBytes == 220) - assert(stage1Data.shuffleReadTotalBytes == 410) - assert(stage0Data.shuffleWriteBytes == 106) - assert(stage1Data.shuffleWriteBytes == 203) - assert(stage0Data.executorRunTime == 108) - assert(stage1Data.executorRunTime == 204) - assert(stage0Data.diskBytesSpilled == 110) - assert(stage1Data.diskBytesSpilled == 205) - assert(stage0Data.memoryBytesSpilled == 112) - assert(stage1Data.memoryBytesSpilled == 206) - assert(stage0Data.inputBytes == 114) - assert(stage1Data.inputBytes == 207) - assert(stage0Data.outputBytes == 116) - assert(stage1Data.outputBytes == 208) - - assert( - stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 2) - assert( - stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 102) - assert( - stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 202) - - // task that was included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(0, 0, taskType, Success, makeTaskInfo(1234L, 1), - makeTaskMetrics(300))) - // task that wasn't included in a heartbeat - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, taskType, Success, makeTaskInfo(1237L, 1), - makeTaskMetrics(400))) - - stage0Data = listener.stageIdToData.get((0, 0)).get - 
stage1Data = listener.stageIdToData.get((1, 0)).get - // Task 1235 contributed (100+1)+(100+9) = 210 shuffle bytes, and task 1234 contributed - // (300+1)+(300+9) = 610 total shuffle bytes, so the total for the stage is 820. - assert(stage0Data.shuffleReadTotalBytes == 820) - // Task 1236 contributed 410 shuffle bytes, and task 1237 contributed 810 shuffle bytes. - assert(stage1Data.shuffleReadTotalBytes == 1220) - assert(stage0Data.shuffleWriteBytes == 406) - assert(stage1Data.shuffleWriteBytes == 606) - assert(stage0Data.executorRunTime == 408) - assert(stage1Data.executorRunTime == 608) - assert(stage0Data.diskBytesSpilled == 410) - assert(stage1Data.diskBytesSpilled == 610) - assert(stage0Data.memoryBytesSpilled == 412) - assert(stage1Data.memoryBytesSpilled == 612) - assert(stage0Data.inputBytes == 414) - assert(stage1Data.inputBytes == 614) - assert(stage0Data.outputBytes == 416) - assert(stage1Data.outputBytes == 616) - assert( - stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 302) - assert( - stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 402) - } - - test("drop internal and sql accumulators") { - val taskInfo = new TaskInfo(0, 0, 0, 0, "", "", TaskLocality.ANY, false) - val internalAccum = - AccumulableInfo(id = 1, name = Some("internal"), None, None, true, false, None) - val sqlAccum = AccumulableInfo( - id = 2, - name = Some("sql"), - update = None, - value = None, - internal = false, - countFailedValues = false, - metadata = Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) - val userAccum = AccumulableInfo( - id = 3, - name = Some("user"), - update = None, - value = None, - internal = false, - countFailedValues = false, - metadata = None) - taskInfo.setAccumulables(List(internalAccum, sqlAccum, userAccum)) - - val newTaskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) - assert(newTaskInfo.accumulables === Seq(userAccum)) - } - - test("SPARK-19146 drop more elements when stageData.taskData.size > retainedTasks") { - val conf = new SparkConf() - conf.set("spark.ui.retainedTasks", "100") - val taskMetrics = TaskMetrics.empty - taskMetrics.mergeShuffleReadMetrics() - val task = new ShuffleMapTask(0) - val taskType = Utils.getFormattedClassName(task) - - val listener1 = new JobProgressListener(conf) - for (t <- 1 to 101) { - val taskInfo = new TaskInfo(t, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - listener1.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - } - // 101 - math.max(100 / 10, 101 - 100) = 91 - assert(listener1.stageIdToData((task.stageId, task.stageAttemptId)).taskData.size === 91) - - val listener2 = new JobProgressListener(conf) - for (t <- 1 to 150) { - val taskInfo = new TaskInfo(t, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) - taskInfo.finishTime = 1 - listener2.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics)) - } - // 150 - math.max(100 / 10, 150 - 100) = 100 - assert(listener2.stageIdToData((task.stageId, task.stageAttemptId)).taskData.size === 100) - } - -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 915c7e2e2fda3..5b8dcd0338cce 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,6 +45,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorStageSummary.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkStatusTracker.this"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.jobs.JobProgressListener"), // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 69e15655ad790..6748dd4ec48e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -23,16 +23,16 @@ import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.status.api.v1.{JobData, StageData} import org.apache.spark.streaming.Time import org.apache.spark.streaming.ui.StreamingJobProgressListener.SparkJobId import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} -import org.apache.spark.ui.jobs.UIData.JobUIData -private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) +private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobData: Option[JobData]) private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private val streamingListener = parent.listener - private val sparkListener = parent.ssc.sc.jobProgressListener + private val store = parent.parent.store private def columns: Seq[Node] = { Output Op Id @@ -52,13 +52,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, - sparkJob: SparkJobIdWithUIData): Seq[Node] = { - if (sparkJob.jobUIData.isDefined) { + jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { + if (jobIdWithData.jobData.isDefined) { generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, - numSparkJobRowsInOutputOp, isFirstRow, sparkJob.jobUIData.get) + numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, - numSparkJobRowsInOutputOp, isFirstRow, sparkJob.sparkJobId) + numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.sparkJobId) } } @@ -94,15 +94,15 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, - sparkJob: JobUIData): Seq[Node] = { + sparkJob: JobData): Seq[Node] = { val duration: Option[Long] = { sparkJob.submissionTime.map { start => - val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) - end - start + val end = sparkJob.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) + end - start.getTime() } } val lastFailureReason = - sparkJob.stageIds.sorted.reverse.flatMap(sparkListener.stageIdToInfo.get). + sparkJob.stageIds.sorted.reverse.flatMap(getStageData). dropWhile(_.failureReason == None).take(1). 
// get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") @@ -135,7 +135,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { {formattedDuration} - {sparkJob.completedStageIndices.size}/{sparkJob.stageIds.size - sparkJob.numSkippedStages} + {sparkJob.numCompletedStages}/{sparkJob.stageIds.size - sparkJob.numSkippedStages} {if (sparkJob.numFailedStages > 0) s"(${sparkJob.numFailedStages} failed)"} {if (sparkJob.numSkippedStages > 0) s"(${sparkJob.numSkippedStages} skipped)"} @@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { completed = sparkJob.numCompletedTasks, failed = sparkJob.numFailedTasks, skipped = sparkJob.numSkippedTasks, - reasonToNumKilled = sparkJob.reasonToNumKilled, + reasonToNumKilled = sparkJob.killedTasksSummary, total = sparkJob.numTasks - sparkJob.numSkippedTasks) } @@ -246,11 +246,19 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {

    } - private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { - sparkListener.activeJobs.get(sparkJobId).orElse { - sparkListener.completedJobs.find(_.jobId == sparkJobId).orElse { - sparkListener.failedJobs.find(_.jobId == sparkJobId) - } + private def getJobData(sparkJobId: SparkJobId): Option[JobData] = { + try { + Some(store.job(sparkJobId)) + } catch { + case _: NoSuchElementException => None + } + } + + private def getStageData(stageId: Int): Option[StageData] = { + try { + Some(store.lastStageAttempt(stageId)) + } catch { + case _: NoSuchElementException => None } } @@ -282,25 +290,22 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val sparkJobIds = outputOpIdToSparkJobIds.getOrElse(outputOpId, Seq.empty) (outputOperation, sparkJobIds) }.toSeq.sortBy(_._1.id) - sparkListener.synchronized { - val outputOpWithJobs = outputOps.map { case (outputOpData, sparkJobIds) => - (outputOpData, - sparkJobIds.map(sparkJobId => SparkJobIdWithUIData(sparkJobId, getJobData(sparkJobId)))) - } + val outputOpWithJobs = outputOps.map { case (outputOpData, sparkJobIds) => + (outputOpData, sparkJobIds.map { jobId => SparkJobIdWithUIData(jobId, getJobData(jobId)) }) + } - - - {columns} - - - { - outputOpWithJobs.map { case (outputOpData, sparkJobIds) => - generateOutputOpIdRow(outputOpData, sparkJobIds) - } +
    + + {columns} + + + { + outputOpWithJobs.map { case (outputOpData, sparkJobs) => + generateOutputOpIdRow(outputOpData, sparkJobs) } - -
    - } + } + + } def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized { From ab6f60c4d6417cbb0240216a6b492aadcca3043e Mon Sep 17 00:00:00 2001 From: Jakub Dubovsky Date: Thu, 30 Nov 2017 10:24:30 +0900 Subject: [PATCH 317/712] [SPARK-22585][CORE] Path in addJar is not url encoded ## What changes were proposed in this pull request? This updates a behavior of `addJar` method of `sparkContext` class. If path without any scheme is passed as input it is used literally without url encoding/decoding it. ## How was this patch tested? A unit test is added for this. Author: Jakub Dubovsky Closes #19834 from james64/SPARK-22585-encode-add-jar. --- .../main/scala/org/apache/spark/SparkContext.scala | 6 +++++- .../scala/org/apache/spark/SparkContextSuite.scala | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 984dd0a6629a2..c174939ca2e54 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1837,7 +1837,11 @@ class SparkContext(config: SparkConf) extends Logging { Utils.validateURL(uri) uri.getScheme match { // A JAR file which exists only on the driver node - case null | "file" => addJarFile(new File(uri.getPath)) + case null => + // SPARK-22585 path without schema is not url encoded + addJarFile(new File(uri.getRawPath)) + // A JAR file which exists only on the driver node + case "file" => addJarFile(new File(uri.getPath)) // A JAR file which exists locally on every worker node case "local" => "file:" + uri.getPath case _ => path diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 0ed5f26863dad..2bde8757dae5d 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -309,6 +309,17 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(sc.listJars().head.contains(tmpJar.getName)) } + test("SPARK-22585 addJar argument without scheme is interpreted literally without url decoding") { + val tmpDir = new File(Utils.createTempDir(), "host%3A443") + tmpDir.mkdirs() + val tmpJar = File.createTempFile("t%2F", ".jar", tmpDir) + + sc = new SparkContext("local", "test") + + sc.addJar(tmpJar.getAbsolutePath) + assert(sc.listJars().size === 1) + } + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) From 92cfbeeb5ce9e2c618a76b3fe60ce84b9d38605b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 30 Nov 2017 10:26:55 +0900 Subject: [PATCH 318/712] [SPARK-21866][ML][PYTHON][FOLLOWUP] Few cleanups and fix image test failure in Python 3.6.0 / NumPy 1.13.3 ## What changes were proposed in this pull request? Image test seems failed in Python 3.6.0 / NumPy 1.13.3. 
I manually tested as below: ``` ====================================================================== ERROR: test_read_images (pyspark.ml.tests.ImageReaderTest) ---------------------------------------------------------------------- Traceback (most recent call last): File "/.../spark/python/pyspark/ml/tests.py", line 1831, in test_read_images self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) File "/.../spark/python/pyspark/ml/image.py", line 149, in toImage data = bytearray(array.astype(dtype=np.uint8).ravel()) TypeError: only integer scalar arrays can be converted to a scalar index ---------------------------------------------------------------------- Ran 1 test in 7.606s ``` To be clear, I think the error seems from NumPy - https://github.com/numpy/numpy/blob/75b2d5d427afdb1392f2a0b2092e0767e4bab53d/numpy/core/src/multiarray/number.c#L947 For a smaller scope: ```python >>> import numpy as np >>> bytearray(np.array([1]).astype(dtype=np.uint8)) Traceback (most recent call last): File "", line 1, in TypeError: only integer scalar arrays can be converted to a scalar index ``` In Python 2.7 / NumPy 1.13.1, it prints: ``` bytearray(b'\x01') ``` So, here, I simply worked around it by converting it to bytes as below: ```python >>> bytearray(np.array([1]).astype(dtype=np.uint8).tobytes()) bytearray(b'\x01') ``` Also, while looking into it again, I realised few arguments could be quite confusing, for example, `Row` that needs some specific attributes and `numpy.ndarray`. I added few type checking and added some tests accordingly. So, it shows an error message as below: ``` TypeError: array argument should be numpy.ndarray; however, it got []. ``` ## How was this patch tested? Manually tested with `./python/run-tests`. And also: ``` PYSPARK_PYTHON=python3 SPARK_TESTING=1 bin/pyspark pyspark.ml.tests ImageReaderTest ``` Author: hyukjinkwon Closes #19835 from HyukjinKwon/SPARK-21866-followup. --- python/pyspark/ml/image.py | 27 ++++++++++++++++++++++++--- python/pyspark/ml/tests.py | 20 +++++++++++++++++++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 7d14f05295572..2b61aa9c0d9e9 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -108,12 +108,23 @@ def toNDArray(self, image): """ Converts an image to an array with metadata. - :param image: The image to be converted. + :param `Row` image: A row that contains the image to be converted. It should + have the attributes specified in `ImageSchema.imageSchema`. :return: a `numpy.ndarray` that is an image. .. versionadded:: 2.3.0 """ + if not isinstance(image, Row): + raise TypeError( + "image argument should be pyspark.sql.types.Row; however, " + "it got [%s]." % type(image)) + + if any(not hasattr(image, f) for f in self.imageFields): + raise ValueError( + "image argument should have attributes specified in " + "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields)) + height = image.height width = image.width nChannels = image.nChannels @@ -127,15 +138,20 @@ def toImage(self, array, origin=""): """ Converts an array with metadata to a two-dimensional image. - :param array array: The array to convert to image. + :param `numpy.ndarray` array: The array to convert to image. :param str origin: Path to the image, optional. :return: a :class:`Row` that is a two dimensional image. .. 
versionadded:: 2.3.0 """ + if not isinstance(array, np.ndarray): + raise TypeError( + "array argument should be numpy.ndarray; however, it got [%s]." % type(array)) + if array.ndim != 3: raise ValueError("Invalid array shape") + height, width, nChannels = array.shape ocvTypes = ImageSchema.ocvTypes if nChannels == 1: @@ -146,7 +162,12 @@ def toImage(self, array, origin=""): mode = ocvTypes["CV_8UC4"] else: raise ValueError("Invalid number of channels") - data = bytearray(array.astype(dtype=np.uint8).ravel()) + + # Running `bytearray(numpy.array([1]))` fails in specific Python versions + # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3. + # Here, it avoids it by converting it to bytes. + data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) + # Creating new Row with _create_row(), because Row(name = value, ... ) # orders fields by name, which conflicts with expected schema order # when the new DataFrame is created by UDF diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2258d61c95333..89ef555cf3442 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -71,7 +71,7 @@ from pyspark.sql.functions import rand from pyspark.sql.types import DoubleType, IntegerType from pyspark.storagelevel import * -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase ser = PickleSerializer() @@ -1836,6 +1836,24 @@ def test_read_images(self): self.assertEqual(ImageSchema.imageFields, expected) self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "image argument should be pyspark.sql.types.Row; however", + lambda: ImageSchema.toNDArray("a")) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + ValueError, + "image argument should have attributes specified in", + lambda: ImageSchema.toNDArray(Row(a=1))) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "array argument should be numpy.ndarray; however, it got", + lambda: ImageSchema.toImage("a")) + class ALSTest(SparkSessionTestCase): From 444a2bbb67c2548d121152bc922b4c3337ddc8e8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Nov 2017 18:28:58 +0800 Subject: [PATCH 319/712] [SPARK-22652][SQL] remove set methods in ColumnarRow ## What changes were proposed in this pull request? As a step to make `ColumnVector` public, the `ColumnarRow` returned by `ColumnVector#getStruct` should be immutable. However we do need the mutability of `ColumnaRow` for the fast vectorized hashmap in hash aggregate. To solve this, this PR introduces a `MutableColumnarRow` for this use case. ## How was this patch tested? existing test. Author: Wenchen Fan Closes #19847 from cloud-fan/mutable-row. 
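
For illustration only (this sketch is not part of the patch), a minimal Java example of how the read/write split is meant to be used; the buffer schema and capacity are hypothetical, while the `MutableColumnarRow` and `OnHeapColumnVector` calls mirror the code added below:

```java
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutableColumnarRowSketch {
  public static void main(String[] args) {
    // Hypothetical aggregate buffer: a single long column with capacity 16.
    StructType bufferSchema = new StructType().add("sum", DataTypes.LongType);
    OnHeapColumnVector[] vectors = OnHeapColumnVector.allocateColumns(16, bufferSchema);

    // MutableColumnarRow is a movable cursor over writable vectors: point rowId at a
    // slot in the vectorized hash map, then update the aggregate buffer in place.
    MutableColumnarRow aggBufferRow = new MutableColumnarRow(vectors);
    aggBufferRow.rowId = 0;
    aggBufferRow.setLong(0, 0L);                             // initialize the buffer slot
    aggBufferRow.setLong(0, aggBufferRow.getLong(0) + 42L);  // update it during aggregation
    System.out.println(aggBufferRow.getLong(0));             // 42

    // ColumnarRow (what ColumnVector#getStruct returns) stays read-only after this
    // patch: its update()/setNullAt() now throw UnsupportedOperationException.
  }
}
```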
--- .../sql/execution/vectorized/ColumnarRow.java | 102 +------ .../vectorized/MutableColumnarRow.java | 278 ++++++++++++++++++ .../aggregate/HashAggregateExec.scala | 3 +- .../VectorizedHashMapGenerator.scala | 82 +++--- .../vectorized/ColumnVectorSuite.scala | 12 + .../vectorized/ColumnarBatchSuite.scala | 23 -- 6 files changed, 336 insertions(+), 164 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index 98a907322713b..cabb7479525d9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -16,8 +16,6 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; - import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; @@ -32,17 +30,10 @@ public final class ColumnarRow extends InternalRow { protected int rowId; private final ColumnVector[] columns; - private final WritableColumnVector[] writableColumns; // Ctor used if this is a struct. ColumnarRow(ColumnVector[] columns) { this.columns = columns; - this.writableColumns = new WritableColumnVector[this.columns.length]; - for (int i = 0; i < this.columns.length; i++) { - if (this.columns[i] instanceof WritableColumnVector) { - this.writableColumns[i] = (WritableColumnVector) this.columns[i]; - } - } } public ColumnVector[] columns() { return columns; } @@ -205,97 +196,8 @@ public Object get(int ordinal, DataType dataType) { } @Override - public void update(int ordinal, Object value) { - if (value == null) { - setNullAt(ordinal); - } else { - DataType dt = columns[ordinal].dataType(); - if (dt instanceof BooleanType) { - setBoolean(ordinal, (boolean) value); - } else if (dt instanceof IntegerType) { - setInt(ordinal, (int) value); - } else if (dt instanceof ShortType) { - setShort(ordinal, (short) value); - } else if (dt instanceof LongType) { - setLong(ordinal, (long) value); - } else if (dt instanceof FloatType) { - setFloat(ordinal, (float) value); - } else if (dt instanceof DoubleType) { - setDouble(ordinal, (double) value); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType) dt; - setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), - t.precision()); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dt); - } - } - } - - @Override - public void setNullAt(int ordinal) { - getWritableColumn(ordinal).putNull(rowId); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putBoolean(rowId, value); - } + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } @Override - public void setByte(int ordinal, byte value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putByte(rowId, value); - } - - @Override - public void setShort(int ordinal, short value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putShort(rowId, value); - } - - @Override - public void setInt(int ordinal, int value) { - 
WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putInt(rowId, value); - } - - @Override - public void setLong(int ordinal, long value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putLong(rowId, value); - } - - @Override - public void setFloat(int ordinal, float value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putFloat(rowId, value); - } - - @Override - public void setDouble(int ordinal, double value) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDouble(rowId, value); - } - - @Override - public void setDecimal(int ordinal, Decimal value, int precision) { - WritableColumnVector column = getWritableColumn(ordinal); - column.putNotNull(rowId); - column.putDecimal(rowId, value, precision); - } - - private WritableColumnVector getWritableColumn(int ordinal) { - WritableColumnVector column = writableColumns[ordinal]; - assert (!column.isConstant); - return column; - } + public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java new file mode 100644 index 0000000000000..f272cc163611b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash + * aggregate. + * + * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to + * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes. 
+ */ +public final class MutableColumnarRow extends InternalRow { + public int rowId; + private final WritableColumnVector[] columns; + + public MutableColumnarRow(WritableColumnVector[] columns) { + this.columns = columns; + } + + @Override + public int numFields() { return columns.length; } + + @Override + public InternalRow copy() { + GenericInternalRow row = new GenericInternalRow(columns.length); + for (int i = 0; i < numFields(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = columns[i].dataType(); + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); + } else if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i).copy()); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); + } else { + throw new RuntimeException("Not implemented. " + dt); + } + } + } + return row; + } + + @Override + public boolean anyNull() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + + @Override + public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + + @Override + public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + + @Override + public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + + @Override + public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + + @Override + public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + + @Override + public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getBinary(rowId); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getStruct(rowId); + } + + @Override + public ColumnarArray getArray(int ordinal) { + if 
(columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getArray(rowId); + } + + @Override + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + } + + @Override + public void update(int ordinal, Object value) { + if (value == null) { + setNullAt(ordinal); + } else { + DataType dt = columns[ordinal].dataType(); + if (dt instanceof BooleanType) { + setBoolean(ordinal, (boolean) value); + } else if (dt instanceof IntegerType) { + setInt(ordinal, (int) value); + } else if (dt instanceof ShortType) { + setShort(ordinal, (short) value); + } else if (dt instanceof LongType) { + setLong(ordinal, (long) value); + } else if (dt instanceof FloatType) { + setFloat(ordinal, (float) value); + } else if (dt instanceof DoubleType) { + setDouble(ordinal, (double) value); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType) dt; + Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale()); + setDecimal(ordinal, d, t.precision()); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dt); + } + } + } + + @Override + public void setNullAt(int ordinal) { + columns[ordinal].putNull(rowId); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putBoolean(rowId, value); + } + + @Override + public void setByte(int ordinal, byte value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putByte(rowId, value); + } + + @Override + public void setShort(int ordinal, short value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putShort(rowId, value); + } + + @Override + public void setInt(int ordinal, int value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putInt(rowId, value); + } + + @Override + public void setLong(int ordinal, long value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putLong(rowId, value); + } + + @Override + public void setFloat(int ordinal, float value) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putFloat(rowId, value); + } + + @Override + public void setDouble(int ordinal, double value) { + 
columns[ordinal].putNotNull(rowId); + columns[ordinal].putDouble(rowId, value); + } + + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDecimal(rowId, value, precision); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index dc8aecf185a96..913978892cd8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -894,7 +895,7 @@ case class HashAggregateExec( ${ if (isVectorizedHashMapEnabled) { s""" - | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null; + | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null; """.stripMargin } else { s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index fd783d905b776..44ba539ebf7c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ /** @@ -76,10 +77,9 @@ class VectorizedHashMapGenerator( }.mkString("\n").concat(";") s""" - | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors; - | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors; - | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; - | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch; + | private ${classOf[OnHeapColumnVector].getName}[] vectors; + | private ${classOf[ColumnarBatch].getName} batch; + | private ${classOf[MutableColumnarRow].getName} aggBufferRow; | private int[] buckets; | private int capacity = 1 << 16; | private double loadFactor = 0.5; @@ -91,19 +91,16 @@ class VectorizedHashMapGenerator( | $generatedAggBufferSchema | | public $generatedClassName() { - | batchVectors = org.apache.spark.sql.execution.vectorized - | .OnHeapColumnVector.allocateColumns(capacity, schema); - | batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( - | schema, batchVectors, capacity); + | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); + | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity); | - | bufferVectors = new org.apache.spark.sql.execution.vectorized - | 
.OnHeapColumnVector[aggregateBufferSchema.fields().length]; + | // Generates a projection to return the aggregate buffer only. + | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = + | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length]; | for (int i = 0; i < aggregateBufferSchema.fields().length; i++) { - | bufferVectors[i] = batchVectors[i + ${groupingKeys.length}]; + | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}]; | } - | // TODO: Possibly generate this projection in HashAggregate directly - | aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( - | aggregateBufferSchema, bufferVectors, capacity); + | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors); | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); @@ -114,13 +111,13 @@ class VectorizedHashMapGenerator( /** * Generates a method that returns true if the group-by keys exist at a given index in the - * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we - * have 2 long group-by keys, the generated function would be of the form: + * associated [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. For instance, + * if we have 2 long group-by keys, the generated function would be of the form: * * {{{ * private boolean equals(int idx, long agg_key, long agg_key1) { - * return batchVectors[0].getLong(buckets[idx]) == agg_key && - * batchVectors[1].getLong(buckets[idx]) == agg_key1; + * return vectors[0].getLong(buckets[idx]) == agg_key && + * vectors[1].getLong(buckets[idx]) == agg_key1; * } * }}} */ @@ -128,7 +125,7 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]", + s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", key.dataType), key.name)})""" }.mkString(" && ") } @@ -141,29 +138,35 @@ class VectorizedHashMapGenerator( } /** - * Generates a method that returns a mutable - * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the + * Generates a method that returns a + * [[org.apache.spark.sql.execution.vectorized.MutableColumnarRow]] which keeps track of the * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the * generated method adds the corresponding row in the associated - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * [[org.apache.spark.sql.execution.vectorized.OnHeapColumnVector]]. 
For instance, if we * have 2 long group-by keys, the generated function would be of the form: * * {{{ - * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert( - * long agg_key, long agg_key1) { + * public MutableColumnarRow findOrInsert(long agg_key, long agg_key1) { * long h = hash(agg_key, agg_key1); * int step = 0; * int idx = (int) h & (numBuckets - 1); * while (step < maxSteps) { * // Return bucket index if it's either an empty slot or already contains the key * if (buckets[idx] == -1) { - * batchVectors[0].putLong(numRows, agg_key); - * batchVectors[1].putLong(numRows, agg_key1); - * batchVectors[2].putLong(numRows, 0); - * buckets[idx] = numRows++; - * return batch.getRow(buckets[idx]); + * if (numRows < capacity) { + * vectors[0].putLong(numRows, agg_key); + * vectors[1].putLong(numRows, agg_key1); + * vectors[2].putLong(numRows, 0); + * buckets[idx] = numRows++; + * aggBufferRow.rowId = numRows; + * return aggBufferRow; + * } else { + * // No more space + * return null; + * } * } else if (equals(idx, agg_key, agg_key1)) { - * return batch.getRow(buckets[idx]); + * aggBufferRow.rowId = buckets[idx]; + * return aggBufferRow; * } * idx = (idx + 1) & (numBuckets - 1); * step++; @@ -177,20 +180,19 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name) + ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, + ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, buffVars(ordinal), nullable = true) } } s""" - |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${ - groupingKeySignature}) { + |public ${classOf[MutableColumnarRow].getName} findOrInsert($groupingKeySignature) { | long h = hash(${groupingKeys.map(_.name).mkString(", ")}); | int step = 0; | int idx = (int) h & (numBuckets - 1); @@ -208,15 +210,15 @@ class VectorizedHashMapGenerator( | ${genCodeToSetAggBuffers(bufferValues).mkString("\n")} | | buckets[idx] = numRows++; - | batch.setNumRows(numRows); - | aggregateBufferBatch.setNumRows(numRows); - | return aggregateBufferBatch.getRow(buckets[idx]); + | aggBufferRow.rowId = buckets[idx]; + | return aggBufferRow; | } else { | // No more space | return null; | } | } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) { - | return aggregateBufferBatch.getRow(buckets[idx]); + | aggBufferRow.rowId = buckets[idx]; + | return aggBufferRow; | } | idx = (idx + 1) & (numBuckets - 1); | step++; @@ -229,8 +231,8 @@ class VectorizedHashMapGenerator( protected def generateRowIterator(): String = { s""" - |public java.util.Iterator - | rowIterator() { + |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() { + | batch.setNumRows(numRows); | return batch.rowIterator(); |} """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 3c76ca79f5dda..e28ab710f5a99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -163,6 +163,18 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } + testVectors("mutable ColumnarRow", 10, IntegerType) { testVector => + val mutableRow = new MutableColumnarRow(Array(testVector)) + (0 until 10).foreach { i => + mutableRow.rowId = i + mutableRow.setInt(0, 10 - i) + } + (0 until 10).foreach { i => + mutableRow.rowId = i + assert(mutableRow.getInt(0) === (10 - i)) + } + } + val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true) testVectors("array", 10, arrayType) { testVector => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 80a50866aa504..1b4e2bad09a20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1129,29 +1129,6 @@ class ColumnarBatchSuite extends SparkFunSuite { testRandomRows(false, 30) } - test("mutable ColumnarBatch rows") { - val NUM_ITERS = 10 - val types = Array( - BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType, - DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal, - DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2), - new DecimalType(12, 2), new DecimalType(30, 10)) - for (i <- 0 to NUM_ITERS) { - val random = new Random(System.nanoTime()) - val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types) - val oldRow = RandomDataGenerator.randomRow(random, schema) - val newRow = RandomDataGenerator.randomRow(random, schema) - - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava) - val columnarBatchRow = batch.getRow(0) - newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1)) - compareStruct(schema, columnarBatchRow, newRow, 0) - batch.close() - } - } - } - test("exceeding maximum capacity should throw an error") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => val column = allocate(1, ByteType, memMode) From 9c29c557635caf739fde942f53255273aac0d7b1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Nov 2017 18:34:38 +0800 Subject: [PATCH 320/712] [SPARK-22643][SQL] ColumnarArray should be an immutable view ## What changes were proposed in this pull request? To make `ColumnVector` public, `ColumnarArray` need to be public too, and we should not have mutable public fields in a public class. This PR proposes to make `ColumnarArray` an immutable view of the data, and always create a new instance of `ColumnarArray` in `ColumnVector#getArray` ## How was this patch tested? new benchmark in `ColumnarBatchBenchmark` Author: Wenchen Fan Closes #19842 from cloud-fan/column-vector. 
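
As an illustration only (not part of the patch), a small Java sketch of what the immutable view buys callers; the vector setup is hypothetical, while the calls mirror the benchmark and test code below:

```java
import org.apache.spark.sql.execution.vectorized.ColumnarArray;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;

public class ColumnarArrayViewSketch {
  public static void main(String[] args) {
    // Hypothetical vector holding two int arrays: [0, 1] and [2, 3, 4].
    OnHeapColumnVector vector =
        new OnHeapColumnVector(8, new ArrayType(DataTypes.IntegerType, true));
    for (int i = 0; i < 5; i++) {
      vector.arrayData().appendInt(i);
    }
    vector.putArray(0, 0, 2);  // row 0 -> elements [0, 2)
    vector.putArray(1, 2, 3);  // row 1 -> elements [2, 5)

    // getArray now returns a fresh ColumnarArray with its own (data, offset, length)
    // instead of mutating one shared holder, so an earlier view stays valid after
    // later calls.
    ColumnarArray first = vector.getArray(0);
    ColumnarArray second = vector.getArray(1);
    System.out.println(first.numElements());  // 2, unaffected by the second call
    System.out.println(second.getInt(0));     // 2
  }
}
```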
--- .../parquet/VectorizedColumnReader.java | 2 +- .../vectorized/ArrowColumnVector.java | 1 - .../execution/vectorized/ColumnVector.java | 14 +- .../vectorized/ColumnVectorUtils.java | 18 +-- .../execution/vectorized/ColumnarArray.java | 10 +- .../vectorized/OffHeapColumnVector.java | 2 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../vectorized/WritableColumnVector.java | 13 +- .../vectorized/ColumnVectorSuite.scala | 41 +++-- .../vectorized/ColumnarBatchBenchmark.scala | 142 ++++++++++++++---- .../vectorized/ColumnarBatchSuite.scala | 18 +-- 11 files changed, 164 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 0f1f470dc597e..71ca8b1b96a98 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -425,7 +425,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; - if (column.isArray()) { + if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { for (int i = 0; i < num; i++) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 5c502c9d91be4..0071bd66760be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -315,7 +315,6 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - resultArray = new ColumnarArray(childColumns[0]); } else if (vector instanceof MapVector) { MapVector mapVector = (MapVector) vector; accessor = new StructAccessor(mapVector); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 940457f2e3363..cca14911fbb28 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -175,9 +175,7 @@ public ColumnarRow getStruct(int rowId, int size) { * Returns the array at rowid. */ public final ColumnarArray getArray(int rowId) { - resultArray.length = getArrayLength(rowId); - resultArray.offset = getArrayOffset(rowId); - return resultArray; + return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } /** @@ -213,21 +211,11 @@ public MapData getMap(int ordinal) { */ public abstract ColumnVector getChildColumn(int ordinal); - /** - * Returns true if this column is an array. - */ - public final boolean isArray() { return resultArray != null; } - /** * Data type for this column. */ protected DataType type; - /** - * Reusable Array holder for getArray(). 
- */ - protected ColumnarArray resultArray; - /** * Reusable Struct holder for getStruct(). */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index b4b5f0a265934..bc62bc43484e5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -98,21 +98,13 @@ public static void populate(WritableColumnVector col, InternalRow row, int field * For example, an array of IntegerType will return an int[]. * Throws exceptions for unhandled schemas. */ - public static Object toPrimitiveJavaArray(ColumnarArray array) { - DataType dt = array.data.dataType(); - if (dt instanceof IntegerType) { - int[] result = new int[array.length]; - ColumnVector data = array.data; - for (int i = 0; i < result.length; i++) { - if (data.isNullAt(array.offset + i)) { - throw new RuntimeException("Cannot handle NULL values."); - } - result[i] = data.getInt(array.offset + i); + public static int[] toJavaIntArray(ColumnarArray array) { + for (int i = 0; i < array.numElements(); i++) { + if (array.isNullAt(i)) { + throw new RuntimeException("Cannot handle NULL values."); } - return result; - } else { - throw new UnsupportedOperationException(); } + return array.toIntArray(); } private static void appendValue(WritableColumnVector dst, DataType t, Object o) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java index b9da641fc66c8..cbc39d1d0aec2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java @@ -29,12 +29,14 @@ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). - public final ColumnVector data; - public int length; - public int offset; + private final ColumnVector data; + private final int offset; + private final int length; - ColumnarArray(ColumnVector data) { + ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; + this.offset = offset; + this.length = length; } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 1cbaf08569334..806d0291a6c49 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -532,7 +532,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { @Override protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 
0 : capacity; - if (this.resultArray != null) { + if (isArray()) { this.lengthData = Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 85d72295ab9b8..6e7f74ce12f16 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -505,7 +505,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { - if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { + if (isArray()) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index e7653f0c00b9a..0bea4cc97142d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -75,7 +75,6 @@ public void close() { } dictionary = null; resultStruct = null; - resultArray = null; } public void reserve(int requiredCapacity) { @@ -650,6 +649,11 @@ public WritableColumnVector getDictionaryIds() { */ protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type); + protected boolean isArray() { + return type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType || + DecimalType.isByteArrayDecimalType(type); + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. @@ -658,8 +662,7 @@ protected WritableColumnVector(int capacity, DataType type) { super(type); this.capacity = capacity; - if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType - || DecimalType.isByteArrayDecimalType(type)) { + if (isArray()) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { @@ -670,7 +673,6 @@ protected WritableColumnVector(int capacity, DataType type) { } this.childColumns = new WritableColumnVector[1]; this.childColumns[0] = reserveNewColumn(childCapacity, childType); - this.resultArray = new ColumnarArray(this.childColumns[0]); this.resultStruct = null; } else if (type instanceof StructType) { StructType st = (StructType)type; @@ -678,18 +680,15 @@ protected WritableColumnVector(int capacity, DataType type) { for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } - this.resultArray = null; this.resultStruct = new ColumnarRow(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. 
this.childColumns = new WritableColumnVector[2]; this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); - this.resultArray = null; this.resultStruct = new ColumnarRow(this.childColumns); } else { this.childColumns = null; - this.resultArray = null; this.resultStruct = null; } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index e28ab710f5a99..54b31cee031f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -57,7 +57,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendBoolean(i % 2 == 0) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, BooleanType) === (i % 2 == 0)) @@ -69,7 +69,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByte(i.toByte) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, ByteType) === i.toByte) @@ -81,7 +81,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendShort(i.toShort) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, ShortType) === i.toShort) @@ -93,7 +93,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendInt(i) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, IntegerType) === i) @@ -105,7 +105,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendLong(i) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, LongType) === i) @@ -117,7 +117,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendFloat(i.toFloat) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, FloatType) === i.toFloat) @@ -129,7 +129,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendDouble(i.toDouble) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, DoubleType) === i.toDouble) @@ -142,7 +142,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => assert(array.get(i, StringType) === UTF8String.fromString(s"str$i")) @@ -155,7 +155,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.appendByteArray(utf8, 0, utf8.length) } - val array = new ColumnarArray(testVector) + val array = new ColumnarArray(testVector, 0, 10) (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") @@ -191,12 +191,10 @@ class ColumnVectorSuite extends SparkFunSuite with 
BeforeAndAfterEach { testVector.putArray(2, 3, 0) testVector.putArray(3, 3, 3) - val array = new ColumnarArray(testVector) - - assert(array.getArray(0).toIntArray() === Array(0)) - assert(array.getArray(1).toIntArray() === Array(1, 2)) - assert(array.getArray(2).toIntArray() === Array.empty[Int]) - assert(array.getArray(3).toIntArray() === Array(3, 4, 5)) + assert(testVector.getArray(0).toIntArray() === Array(0)) + assert(testVector.getArray(1).toIntArray() === Array(1, 2)) + assert(testVector.getArray(2).toIntArray() === Array.empty[Int]) + assert(testVector.getArray(3).toIntArray() === Array(3, 4, 5)) } val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) @@ -208,12 +206,10 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { c1.putInt(1, 456) c2.putDouble(1, 5.67) - val array = new ColumnarArray(testVector) - - assert(array.getStruct(0, structType.length).get(0, IntegerType) === 123) - assert(array.getStruct(0, structType.length).get(1, DoubleType) === 3.45) - assert(array.getStruct(1, structType.length).get(0, IntegerType) === 456) - assert(array.getStruct(1, structType.length).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123) + assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45) + assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456) + assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { @@ -226,9 +222,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { testVector.reserve(16) // Check that none of the values got lost/overwritten. - val array = new ColumnarArray(testVector) (0 until 8).foreach { i => - assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + assert(testVector.getArray(i).toIntArray() === Array(i)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 705b26b8c91e6..38ea2e47fdef8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -23,11 +23,9 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.vectorized.ColumnVector import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector -import org.apache.spark.sql.execution.vectorized.WritableColumnVector -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -265,20 +263,22 @@ object ColumnarBatchBenchmark { } /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------- - Java Array 248.8 1317.04 1.00 X - ByteBuffer Unsafe 435.6 752.25 0.57 X - ByteBuffer API 1752.0 187.03 0.14 X - DirectByteBuffer 595.4 550.35 0.42 X - Unsafe Buffer 235.2 1393.20 1.06 X - 
Column(on heap) 189.8 1726.45 1.31 X - Column(off heap) 408.4 802.35 0.61 X - Column(off heap direct) 237.6 1379.12 1.05 X - UnsafeRow (on heap) 414.6 790.35 0.60 X - UnsafeRow (off heap) 487.2 672.58 0.51 X - Column On Heap Append 530.1 618.14 0.59 X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Java Array 177 / 181 1856.4 0.5 1.0X + ByteBuffer Unsafe 318 / 322 1032.0 1.0 0.6X + ByteBuffer API 1411 / 1418 232.2 4.3 0.1X + DirectByteBuffer 467 / 474 701.8 1.4 0.4X + Unsafe Buffer 178 / 185 1843.6 0.5 1.0X + Column(on heap) 178 / 184 1840.8 0.5 1.0X + Column(off heap) 341 / 344 961.8 1.0 0.5X + Column(off heap direct) 178 / 184 1845.4 0.5 1.0X + UnsafeRow (on heap) 378 / 389 866.3 1.2 0.5X + UnsafeRow (off heap) 393 / 402 834.0 1.2 0.4X + Column On Heap Append 309 / 318 1059.1 0.9 0.6X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -332,11 +332,13 @@ object ColumnarBatchBenchmark { } }} /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Boolean Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------- - Bitset 895.88 374.54 1.00 X - Byte Array 578.96 579.56 1.55 X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Bitset 726 / 727 462.4 2.2 1.0X + Byte Array 530 / 542 632.7 1.6 1.4X */ benchmark.run() } @@ -387,10 +389,13 @@ object ColumnarBatchBenchmark { } /* - String Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------------- - On Heap 457.0 35.85 1.00 X - Off Heap 1206.0 13.59 0.38 X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + On Heap 332 / 338 49.3 20.3 1.0X + Off Heap 466 / 467 35.2 28.4 0.7X */ val benchmark = new Benchmark("String Read/Write", count * iters) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) @@ -398,9 +403,94 @@ object ColumnarBatchBenchmark { benchmark.run } + def arrayAccess(iters: Int): Unit = { + val random = new Random(0) + val count = 4 * 1000 + + val onHeapVector = new OnHeapColumnVector(count, ArrayType(IntegerType)) + val offHeapVector = new OffHeapColumnVector(count, ArrayType(IntegerType)) + + val minSize = 3 + val maxSize = 32 + var arraysCount = 0 + var elementsCount = 0 + while (arraysCount < count) { + val size = random.nextInt(maxSize - minSize) + minSize + val onHeapArrayData = onHeapVector.arrayData() + val offHeapArrayData = offHeapVector.arrayData() + + var i = 0 + while (i < size) { + val value = random.nextInt() + onHeapArrayData.appendInt(value) + offHeapArrayData.appendInt(value) + i += 1 + } + + onHeapVector.putArray(arraysCount, elementsCount, size) + offHeapVector.putArray(arraysCount, elementsCount, size) + elementsCount += size + arraysCount += 1 + } + + def readArrays(onHeap: Boolean): Unit = { + System.gc() + val 
vector = if (onHeap) onHeapVector else offHeapVector + + var sum = 0L + for (_ <- 0 until iters) { + var i = 0 + while (i < count) { + sum += vector.getArray(i).numElements() + i += 1 + } + } + } + + def readArrayElements(onHeap: Boolean): Unit = { + System.gc() + val vector = if (onHeap) onHeapVector else offHeapVector + + var sum = 0L + for (_ <- 0 until iters) { + var i = 0 + while (i < count) { + val array = vector.getArray(i) + val size = array.numElements() + var j = 0 + while (j < size) { + sum += array.getInt(j) + j += 1 + } + i += 1 + } + } + } + + val benchmark = new Benchmark("Array Vector Read", count * iters) + benchmark.addCase("On Heap Read Size Only") { _ => readArrays(true) } + benchmark.addCase("Off Heap Read Size Only") { _ => readArrays(false) } + benchmark.addCase("On Heap Read Elements") { _ => readArrayElements(true) } + benchmark.addCase("Off Heap Read Elements") { _ => readArrayElements(false) } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + On Heap Read Size Only 415 / 422 394.7 2.5 1.0X + Off Heap Read Size Only 394 / 402 415.9 2.4 1.1X + On Heap Read Elements 2558 / 2593 64.0 15.6 0.2X + Off Heap Read Elements 3316 / 3317 49.4 20.2 0.1X + */ + benchmark.run + } + def main(args: Array[String]): Unit = { intAccess(1024 * 40) booleanAccess(1024 * 40) stringAccess(1024 * 4) + arrayAccess(1024 * 40) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 1b4e2bad09a20..0ae4f2d117609 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -645,26 +645,26 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(2, 2, 0) column.putArray(3, 3, 3) - val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] - val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] - val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + val a1 = ColumnVectorUtils.toJavaIntArray(column.getArray(0)) + val a2 = ColumnVectorUtils.toJavaIntArray(column.getArray(1)) + val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(2)) + val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) // Verify the ArrayData APIs - assert(column.getArray(0).length == 1) + assert(column.getArray(0).numElements() == 1) assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).length == 2) + assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).length == 0) + assert(column.getArray(2).numElements() == 0) - assert(column.getArray(3).length == 3) + assert(column.getArray(3).numElements() == 3) assert(column.getArray(3).getInt(0) == 3) assert(column.getArray(3).getInt(1) == 4) assert(column.getArray(3).getInt(2) == 5) @@ -677,7 +677,7 @@ 
class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) === array) } From 6eb203fae7bbc9940710da40f314b89ffb4dd324 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 1 Dec 2017 01:21:52 +0900 Subject: [PATCH 321/712] [SPARK-22654][TESTS] Retry Spark tarball download if failed in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Adds a simple loop to retry download of Spark tarballs from different mirrors if the download fails. ## How was this patch tested? Existing tests Author: Sean Owen Closes #19851 from srowen/SPARK-22654. --- .../HiveExternalCatalogVersionsSuite.scala | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6859432c406a9..a3d5b941a6761 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import java.io.File import java.nio.file.Files +import scala.sys.process._ + import org.apache.spark.TestUtils import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -50,14 +52,24 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { super.afterAll() } - private def downloadSpark(version: String): Unit = { - import scala.sys.process._ + private def tryDownloadSpark(version: String, path: String): Unit = { + // Try mirrors a few times until one succeeds + for (i <- 0 until 3) { + val preferredMirror = + Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim + val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + logInfo(s"Downloading Spark $version from $url") + if (Seq("wget", url, "-q", "-P", path).! == 0) { + return + } + logWarning(s"Failed to download Spark $version from $url") + } + fail(s"Unable to download Spark $version") + } - val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" - Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! + private def downloadSpark(version: String): Unit = { + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath From 932bd09c80dc2dc113e94f59f4dcb77e77de7c58 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 1 Dec 2017 01:24:15 +0900 Subject: [PATCH 322/712] [SPARK-22635][SQL][ORC] FileNotFoundException while reading ORC files containing special characters ## What changes were proposed in this pull request? SPARK-22146 fix the FileNotFoundException issue only for the `inferSchema` method, ie. only for the schema inference, but it doesn't fix the problem when actually reading the data. Thus nearly the same exception happens when someone tries to use the data. 
This PR covers fixing the problem also there. ## How was this patch tested? enhanced UT Author: Marco Gaido Closes #19844 from mgaido91/SPARK-22635. --- .../org/apache/spark/sql/hive/orc/OrcFileFormat.scala | 11 +++++------ .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 3b33a9ff082f3..95741c7b30289 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -133,10 +133,12 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value + val filePath = new Path(new URI(file.filePath)) + // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty + val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty if (isEmptyFile) { Iterator.empty } else { @@ -146,15 +148,12 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val job = Job.getInstance(conf) FileInputFormat.setInputPaths(job, file.filePath) - val fileSplit = new FileSplit( - new Path(new URI(file.filePath)), file.start, file.length, Array.empty - ) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) // Custom OrcRecordReader is used to get // ObjectInspector during recordReader creation itself and can // avoid NameNode call in unwrapOrcStructs per file. // Specifically would be helpful for partitioned datasets. - val orcReader = OrcFile.createReader( - new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) + val orcReader = OrcFile.createReader(filePath, OrcFile.readerOptions(conf)) new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index a1060476f2211..c8caba83bf365 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1350,7 +1350,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv withTempDir { dir => val tmpFile = s"$dir/$nameWithSpecialChars" spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) - spark.read.format(format).load(tmpFile) + val fileContent = spark.read.format(format).load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) } } } From 999ec137a97844abbbd483dd98c7ded2f8ff356c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 1 Dec 2017 02:28:24 +0800 Subject: [PATCH 323/712] [SPARK-22570][SQL] Avoid to create a lot of global variables by using a local variable with allocation of an object in generated code ## What changes were proposed in this pull request? 
This PR reduces # of global variables in generated code by replacing a global variable with a local variable with an allocation of an object every time. When a lot of global variables were generated, the generated code may meet 64K constant pool limit. This PR reduces # of generated global variables in the following three operations: * `Cast` with String to primitive byte/short/int/long * `RegExpReplace` * `CreateArray` I intentionally leave [this part](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L595-L603). This is because this variable keeps a class that is dynamically generated. In other word, it is not possible to reuse one class. ## How was this patch tested? Added test cases Author: Kazuaki Ishizaki Closes #19797 from kiszk/SPARK-22570. --- .../spark/sql/catalyst/expressions/Cast.scala | 24 ++++++------- .../expressions/complexTypeCreator.scala | 36 +++++++++++-------- .../expressions/regexpExpressions.scala | 8 ++--- .../sql/catalyst/expressions/CastSuite.scala | 8 +++++ .../expressions/RegexpExpressionsSuite.scala | 11 +++++- .../optimizer/complexTypesSuite.scala | 7 ++++ 6 files changed, 61 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8cafaef61c7d1..f4ecbdb8393ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -799,16 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" @@ -826,16 +826,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 
(short) 1 : (short) 0;" @@ -851,16 +851,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" @@ -876,17 +876,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.LongWrapper", wrapper, - s"$wrapper = new UTF8String.LongWrapper();") + val wrapper = ctx.freshName("longWrapper") (c, evPrim, evNull) => s""" + UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 57a7f2e207738..fc68bf478e1c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + ctx.splitExpressions(assigns) + postprocess, + code = preprocess + assigns + postprocess, value = arrayData, isNull = "false") } @@ -77,24 +77,22 @@ private [sql] object GenArrayData { * * @param ctx a [[CodegenContext]] * @param elementType data type of underlying array elements - * @param elementsCode a set of [[ExprCode]] for each element of an underlying array + * @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array * @param isMapKey if true, throw an exception when the element is null - * @return (code pre-assignments, assignments to each array elements, code post-assignments, - * arrayData name) + * @return (code pre-assignments, concatenated assignments to each array elements, + * code post-assignments, arrayData name) */ def genCodeToCreateArrayData( ctx: CodegenContext, elementType: DataType, elementsCode: Seq[ExprCode], - isMapKey: Boolean): (String, Seq[String], String, String) = { - val arrayName = ctx.freshName("array") + isMapKey: Boolean): (String, String, String, String) = { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length if (!ctx.isPrimitiveType(elementType)) { + val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName - ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[$numElements];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -110,17 +108,21 @@ 
private [sql] object GenArrayData { } """ } + val assignmentString = ctx.splitExpressions( + expressions = assignments, + funcName = "apply", + extraArguments = ("Object[]", arrayDataName) :: Nil) - ("", - assignments, + (s"Object[] $arrayName = new Object[$numElements];", + assignmentString, s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", arrayDataName) } else { + val arrayName = ctx.freshName("array") val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName) val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -137,14 +139,18 @@ private [sql] object GenArrayData { } """ } + val assignmentString = ctx.splitExpressions( + expressions = assignments, + funcName = "apply", + extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil) (s""" byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - $arrayDataName = new UnsafeArrayData(); + UnsafeArrayData $arrayDataName = new UnsafeArrayData(); Platform.putLong($arrayName, $baseOffset, $numElements); $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); """, - assignments, + assignmentString, "", arrayDataName) } @@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { s""" final boolean ${ev.isNull} = false; $preprocessKeyData - ${ctx.splitExpressions(assignKeys)} + $assignKeys $postprocessKeyData $preprocessValueData - ${ctx.splitExpressions(assignValues)} + $assignValues $postprocessValueData final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index d0d663f63f5db..53d7096dd87d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -321,8 +321,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termLastReplacement = ctx.freshName("lastReplacement") val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") - - val termResult = ctx.freshName("result") + val termResult = ctx.freshName("termResult") val classNamePattern = classOf[Pattern].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName @@ -334,8 +333,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -355,7 +352,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ${termLastReplacementInUTF8} = $rep.clone(); ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); } - ${termResult}.delete(0, ${termResult}.length()); + $classNameStringBuffer ${termResult} = new $classNameStringBuffer(); java.util.regex.Matcher ${matcher} = 
${termPattern}.matcher($subject.toString()); while (${matcher}.find()) { @@ -363,6 +360,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } ${matcher}.appendTail(${termResult}); ${ev.value} = UTF8String.fromString(${termResult}.toString()); + ${termResult} = null; $setEvNotNull """ }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 7837d6529d12b..65617be05a434 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -845,4 +846,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner)) checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter) } + + test("SPARK-22570: Cast should not create a lot of global variables") { + val ctx = new CodegenContext + cast("1", IntegerType).genCode(ctx) + cast("2", LongType).genCode(ctx) + assert(ctx.mutableStates.length == 0) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 1ce150e091981..4fa61fbaf66c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.StringType /** * Unit tests for regular expression (regexp) related SQL expressions. 
@@ -178,6 +179,14 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(nonNullExpr, "num-num", row1) } + test("SPARK-22570: RegExpReplace should not create a lot of global variables") { + val ctx = new CodegenContext + RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx) + // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8) + // are always required + assert(ctx.mutableStates.length == 4) + } + test("RegexExtract") { val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 3634accf1ec21..e3675367d78e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{ comparePlans(Optimizer execute query, expected) } + test("SPARK-22570: CreateArray should not create a lot of global variables") { + val ctx = new CodegenContext + CreateArray(Seq(Literal(1))).genCode(ctx) + assert(ctx.mutableStates.length == 0) + } + test("simplify map ops") { val rel = relation .select( From 6ac57fd0d1c82b834eb4bf0dd57596b92a99d6de Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Thu, 30 Nov 2017 14:25:10 -0800 Subject: [PATCH 324/712] [SPARK-21417][SQL] Infer join conditions using propagated constraints ## What changes were proposed in this pull request? This PR adds an optimization rule that infers join conditions using propagated constraints. For instance, if there is a join, where the left relation has 'a = 1' and the right relation has 'b = 1', then the rule infers 'a = b' as a join predicate. Only semantically new predicates are appended to the existing join condition. Refer to the corresponding ticket and tests for more details. ## How was this patch tested? This patch comes with a new test suite to cover the implemented logic. Author: aokolnychyi Closes #18692 from aokolnychyi/spark-21417. 
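To make the described inference concrete, here is a small, hedged sketch (not part of this patch): the view names `t1`/`t2` and columns `a`/`b` are hypothetical, and it assumes an active `SparkSession` on a build that includes this rule, with constraint propagation left at its default (enabled).

```scala
// Both sides of the CROSS join are constrained to the same literal (a = 1, b = 1),
// so the new EliminateCrossJoin rule can infer a = b and plan an inner join
// instead of a Cartesian product.
spark.range(10).selectExpr("id AS a").createOrReplaceTempView("t1")
spark.range(10).selectExpr("id AS b").createOrReplaceTempView("t2")

val df = spark.sql(
  """
    |SELECT *
    |FROM t1 CROSS JOIN t2
    |WHERE t1.a = 1 AND t2.b = 1
  """.stripMargin)

// The optimized plan should report an Inner join with condition (a = b)
// rather than a Cross join.
df.explain(true)
```
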
--- .../expressions/EquivalentExpressionMap.scala | 66 +++++ .../catalyst/expressions/ExpressionSet.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../spark/sql/catalyst/optimizer/joins.scala | 60 +++++ .../EquivalentExpressionMapSuite.scala | 56 +++++ .../optimizer/EliminateCrossJoinSuite.scala | 238 ++++++++++++++++++ 6 files changed, 423 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala new file mode 100644 index 0000000000000..cf1614afb1a76 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.EquivalentExpressionMap.SemanticallyEqualExpr + +/** + * A class that allows you to map an expression into a set of equivalent expressions. The keys are + * handled based on their semantic meaning and ignoring cosmetic differences. The values are + * represented as [[ExpressionSet]]s. + * + * The underlying representation of keys depends on the [[Expression.semanticHash]] and + * [[Expression.semanticEquals]] methods. 
+ * + * {{{ + * val map = new EquivalentExpressionMap() + * + * map.put(1 + 2, a) + * map.put(rand(), b) + * + * map.get(2 + 1) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent + * map.get(1 + 2) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent + * map.get(rand()) => Set() // non-deterministic expressions are not equivalent + * }}} + */ +class EquivalentExpressionMap { + + private val equivalenceMap = mutable.HashMap.empty[SemanticallyEqualExpr, ExpressionSet] + + def put(expression: Expression, equivalentExpression: Expression): Unit = { + val equivalentExpressions = equivalenceMap.getOrElseUpdate(expression, ExpressionSet.empty) + equivalenceMap(expression) = equivalentExpressions + equivalentExpression + } + + def get(expression: Expression): Set[Expression] = + equivalenceMap.getOrElse(expression, ExpressionSet.empty) +} + +object EquivalentExpressionMap { + + private implicit class SemanticallyEqualExpr(val expr: Expression) { + override def equals(obj: Any): Boolean = obj match { + case other: SemanticallyEqualExpr => expr.semanticEquals(other.expr) + case _ => false + } + + override def hashCode: Int = expr.semanticHash() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index 7e8e7b8cd5f18..e9890837af07d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -27,6 +27,8 @@ object ExpressionSet { expressions.foreach(set.add) set } + + val empty: ExpressionSet = ExpressionSet(Nil) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d961bf2e6e5e..8a5c486912abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -87,6 +87,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PushProjectionThroughUnion, ReorderJoin, EliminateOuterJoin, + EliminateCrossJoin, InferFiltersFromConstraints, BooleanSimplification, PushPredicateThroughJoin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index edbeaf273fd6f..29a3a7f109b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec +import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins @@ -152,3 +153,62 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } + +/** + * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints. + * + * The optimization is applicable only to CROSS joins. 
For other join types, adding inferred join + * conditions would potentially shuffle children as child node's partitioning won't satisfy the JOIN + * node's requirements which otherwise could have. + * + * For instance, given a CROSS join with the constraint 'a = 1' from the left child and the + * constraint 'b = 1' from the right child, this rule infers a new join predicate 'a = b' and + * converts it to an Inner join. + */ +object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.constraintPropagationEnabled) { + eliminateCrossJoin(plan) + } else { + plan + } + } + + private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform { + case join @ Join(leftPlan, rightPlan, Cross, None) => + val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet)) + val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet)) + val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints) + val joinConditionOpt = inferredJoinPredicates.reduceOption(And) + if (joinConditionOpt.isDefined) Join(leftPlan, rightPlan, Inner, joinConditionOpt) else join + } + + private def inferJoinPredicates( + leftConstraints: Set[Expression], + rightConstraints: Set[Expression]): mutable.Set[EqualTo] = { + + val equivalentExpressionMap = new EquivalentExpressionMap() + + leftConstraints.foreach { + case EqualTo(attr: Attribute, expr: Expression) => + equivalentExpressionMap.put(expr, attr) + case EqualTo(expr: Expression, attr: Attribute) => + equivalentExpressionMap.put(expr, attr) + case _ => + } + + val joinConditions = mutable.Set.empty[EqualTo] + + rightConstraints.foreach { + case EqualTo(attr: Attribute, expr: Expression) => + joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _)) + case EqualTo(expr: Expression, attr: Attribute) => + joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _)) + case _ => + } + + joinConditions + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala new file mode 100644 index 0000000000000..bad7e17bb6cf2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class EquivalentExpressionMapSuite extends SparkFunSuite { + + private val onePlusTwo = Literal(1) + Literal(2) + private val twoPlusOne = Literal(2) + Literal(1) + private val rand = Rand(10) + + test("behaviour of the equivalent expression map") { + val equivalentExpressionMap = new EquivalentExpressionMap() + equivalentExpressionMap.put(onePlusTwo, 'a) + equivalentExpressionMap.put(Literal(1) + Literal(3), 'b) + equivalentExpressionMap.put(rand, 'c) + + // 1 + 2 should be equivalent to 2 + 1 + assertResult(ExpressionSet(Seq('a)))(equivalentExpressionMap.get(twoPlusOne)) + // non-deterministic expressions should not be equivalent + assertResult(ExpressionSet.empty)(equivalentExpressionMap.get(rand)) + + // if the same (key, value) is added several times, the map still returns only one entry + equivalentExpressionMap.put(onePlusTwo, 'a) + equivalentExpressionMap.put(twoPlusOne, 'a) + assertResult(ExpressionSet(Seq('a)))(equivalentExpressionMap.get(twoPlusOne)) + + // get several equivalent attributes + equivalentExpressionMap.put(onePlusTwo, 'e) + assertResult(ExpressionSet(Seq('a, 'e)))(equivalentExpressionMap.get(onePlusTwo)) + assertResult(2)(equivalentExpressionMap.get(onePlusTwo).size) + + // several non-deterministic expressions should not be equivalent + equivalentExpressionMap.put(rand, 'd) + assertResult(ExpressionSet.empty)(equivalentExpressionMap.get(rand)) + assertResult(0)(equivalentExpressionMap.get(rand).size) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala new file mode 100644 index 0000000000000..e04dd28ee36a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, Not, Rand} +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED +import org.apache.spark.sql.types.IntegerType + +class EliminateCrossJoinSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Eliminate cross joins", FixedPoint(10), + EliminateCrossJoin, + PushPredicateThroughJoin) :: Nil + } + + val testRelation1 = LocalRelation('a.int, 'b.int) + val testRelation2 = LocalRelation('c.int, 'd.int) + + test("successful elimination of cross joins (1)") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'c === 1 && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1, + expectedRightRelationFilter = 'c === 1 && 'd === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c && 'a === 'd)) + } + + test("successful elimination of cross joins (2)") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'b === 2 && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1 && 'b === 2, + expectedRightRelationFilter = 'd === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'd)) + } + + test("successful elimination of cross joins (3)") { + // PushPredicateThroughJoin will push 'd === 'a into the join condition + // EliminateCrossJoin will NOT apply because the condition will be already present + // therefore, the join type will stay the same (i.e., CROSS) + checkJoinOptimization( + originalFilter = 'a === 1 && Literal(1) === 'd && 'd === 'a, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1, + expectedRightRelationFilter = Literal(1) === 'd, + expectedJoinType = Cross, + expectedJoinCondition = Some('a === 'd)) + } + + test("successful elimination of cross joins (4)") { + // Literal(1) * Literal(2) and Literal(2) * Literal(1) are semantically equal + checkJoinOptimization( + originalFilter = 'a === Literal(1) * Literal(2) && Literal(2) * Literal(1) === 'c, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === Literal(1) * Literal(2), + expectedRightRelationFilter = Literal(2) * Literal(1) === 'c, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("successful elimination of cross joins (5)") { + checkJoinOptimization( + originalFilter = 'a === 1 && Literal(1) === 'a && 'c === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, + expectedRightRelationFilter = 'c === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("successful elimination of cross joins (6)") { + checkJoinOptimization( + originalFilter = 'a === Cast("1", IntegerType) && 'c === Cast("1", IntegerType) && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + 
expectedLeftRelationFilter = 'a === Cast("1", IntegerType), + expectedRightRelationFilter = 'c === Cast("1", IntegerType) && 'd === 1, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("successful elimination of cross joins (7)") { + // The join condition appears due to PushPredicateThroughJoin + checkJoinOptimization( + originalFilter = (('a >= 1 && 'c === 1) || 'd === 10) && 'b === 10 && 'c === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'b === 10, + expectedRightRelationFilter = 'c === 1, + expectedJoinType = Cross, + expectedJoinCondition = Some(('a >= 1 && 'c === 1) || 'd === 10)) + } + + test("successful elimination of cross joins (8)") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'c === 1 && Literal(1) === 'a && Literal(1) === 'c, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, + expectedRightRelationFilter = 'c === 1 && Literal(1) === 'c, + expectedJoinType = Inner, + expectedJoinCondition = Some('a === 'c)) + } + + test("inability to detect join conditions when constant propagation is disabled") { + withSQLConf(CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { + checkJoinOptimization( + originalFilter = 'a === 1 && 'c === 1 && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a === 1, + expectedRightRelationFilter = 'c === 1 && 'd === 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + } + + test("inability to detect join conditions (1)") { + checkJoinOptimization( + originalFilter = 'a >= 1 && 'c === 1 && 'd >= 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = 'a >= 1, + expectedRightRelationFilter = 'c === 1 && 'd >= 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + + test("inability to detect join conditions (2)") { + checkJoinOptimization( + originalFilter = Literal(1) === 'b && ('c === 1 || 'd === 1), + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = Literal(1) === 'b, + expectedRightRelationFilter = 'c === 1 || 'd === 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + + test("inability to detect join conditions (3)") { + checkJoinOptimization( + originalFilter = Literal(1) === 'b && 'c === 1, + originalJoinType = Cross, + originalJoinCondition = Some('c === 'b), + expectedFilter = None, + expectedLeftRelationFilter = Literal(1) === 'b, + expectedRightRelationFilter = 'c === 1, + expectedJoinType = Cross, + expectedJoinCondition = Some('c === 'b)) + } + + test("inability to detect join conditions (4)") { + checkJoinOptimization( + originalFilter = Not('a === 1) && 'd === 1, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = None, + expectedLeftRelationFilter = Not('a === 1), + expectedRightRelationFilter = 'd === 1, + expectedJoinType = Cross, + expectedJoinCondition = None) + } + + test("inability to detect join conditions (5)") { + checkJoinOptimization( + originalFilter = 'a === Rand(10) && 'b === 1 && 'd === Rand(10) && 'c === 3, + originalJoinType = Cross, + originalJoinCondition = None, + expectedFilter = Some('a === Rand(10) && 'd === Rand(10)), + expectedLeftRelationFilter = 'b === 1, + expectedRightRelationFilter = 'c === 3, + expectedJoinType = Cross, + 
expectedJoinCondition = None) + } + + private def checkJoinOptimization( + originalFilter: Expression, + originalJoinType: JoinType, + originalJoinCondition: Option[Expression], + expectedFilter: Option[Expression], + expectedLeftRelationFilter: Expression, + expectedRightRelationFilter: Expression, + expectedJoinType: JoinType, + expectedJoinCondition: Option[Expression]): Unit = { + + val originalQuery = testRelation1 + .join(testRelation2, originalJoinType, originalJoinCondition) + .where(originalFilter) + val optimizedQuery = Optimize.execute(originalQuery.analyze) + + val left = testRelation1.where(expectedLeftRelationFilter) + val right = testRelation2.where(expectedRightRelationFilter) + val join = left.join(right, expectedJoinType, expectedJoinCondition) + val expectedQuery = expectedFilter.fold(join)(join.where(_)).analyze + + comparePlans(optimizedQuery, expectedQuery) + } +} From bcceab649510a45f4c4b8e44b157c9987adff6f4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 30 Nov 2017 15:36:26 -0800 Subject: [PATCH 325/712] [SPARK-22489][SQL] Shouldn't change broadcast join buildSide if user clearly specified ## What changes were proposed in this pull request? How to reproduce: ```scala import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("table1") spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value").createTempView("table2") val bl = sql("SELECT /*+ MAPJOIN(t1) */ * FROM table1 t1 JOIN table2 t2 ON t1.key = t2.key").queryExecution.executedPlan println(bl.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide) ``` The result is `BuildRight`, but should be `BuildLeft`. This PR fix this issue. ## How was this patch tested? unit tests Author: Yuming Wang Closes #19714 from wangyum/SPARK-22489. --- docs/sql-programming-guide.md | 58 ++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 67 +++++++++++++----- .../execution/joins/BroadcastJoinSuite.scala | 69 ++++++++++++++++++- 3 files changed, 177 insertions(+), 17 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 983770d506836..a1b9c3bbfd059 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1492,6 +1492,64 @@ that these options will be deprecated in future release as more optimizations ar +## Broadcast Hint for SQL Queries + +The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. +When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, +even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. +When both sides of a join are specified, Spark broadcasts the one having the lower statistics. +Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) +support BHJ. When the broadcast nested loop join is selected, we still respect the hint. + +
+<div class="codetabs">
+
+<div data-lang="scala"  markdown="1">
+
+{% highlight scala %}
+import org.apache.spark.sql.functions.broadcast
+broadcast(spark.table("src")).join(spark.table("records"), "key").show()
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java"  markdown="1">
+
+{% highlight java %}
+import static org.apache.spark.sql.functions.broadcast;
+broadcast(spark.table("src")).join(spark.table("records"), "key").show();
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python"  markdown="1">
+
+{% highlight python %}
+from pyspark.sql.functions import broadcast
+broadcast(spark.table("src")).join(spark.table("records"), "key").show()
+{% endhighlight %}
+
+</div>
+
+<div data-lang="r"  markdown="1">
+
+{% highlight r %}
+src <- sql("SELECT * FROM src")
+records <- sql("SELECT * FROM records")
+head(join(broadcast(src), records, src$key == records$key))
+{% endhighlight %}
+
+</div>
+
+<div data-lang="sql"  markdown="1">
+
+{% highlight sql %}
+-- We accept BROADCAST, BROADCASTJOIN and MAPJOIN for broadcast hint
+SELECT /*+ BROADCAST(r) */ * FROM records r JOIN src s ON r.key = s.key
+{% endhighlight %}
+
+</div>
+</div>
+
    + # Distributed SQL Engine Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 19b858faba6ea..1fe3cb1c8750a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery @@ -91,12 +91,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * predicates can be evaluated by matching join keys. If found, Join implementations are chosen * with the following precedence: * - * - Broadcast: if one side of the join has an estimated physical size that is smaller than the - * user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold - * or if that side has an explicit broadcast hint (e.g. the user applied the - * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side - * of the join will be broadcasted and the other side will be streamed, with no shuffling - * performed. If both sides of the join are eligible to be broadcasted then the + * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the + * user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame). + * If both sides have the broadcast hint, we prefer to broadcast the side with a smaller + * estimated physical size. If neither one of the sides has the broadcast hint, + * we only broadcast the join side if its estimated physical size that is smaller than + * the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold. * - Shuffle hash join: if the average size of a single partition is small enough to build a hash * table. * - Sort merge: if the matching join keys are sortable. @@ -112,9 +112,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. 
*/ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats.hints.broadcast || - (plan.stats.sizeInBytes >= 0 && - plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold) + plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold } /** @@ -149,10 +147,45 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => false } + private def broadcastSide( + canBuildLeft: Boolean, + canBuildRight: Boolean, + left: LogicalPlan, + right: LogicalPlan): BuildSide = { + + def smallerSide = + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + + val buildRight = canBuildRight && right.stats.hints.broadcast + val buildLeft = canBuildLeft && left.stats.hints.broadcast + + if (buildRight && buildLeft) { + // Broadcast smaller side base on its estimated physical size + // if both sides have broadcast hint + smallerSide + } else if (buildRight) { + BuildRight + } else if (buildLeft) { + BuildLeft + } else if (canBuildRight && canBuildLeft) { + // for the last default broadcast nested loop join + smallerSide + } else { + throw new AnalysisException("Can not decide which side to broadcast for this join") + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- BroadcastHashJoin -------------------------------------------------------------------- + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if (canBuildRight(joinType) && right.stats.hints.broadcast) || + (canBuildLeft(joinType) && left.stats.hints.broadcast) => + val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + Seq(joins.BroadcastHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if canBuildRight(joinType) && canBroadcast(right) => Seq(joins.BroadcastHashJoinExec( @@ -189,6 +222,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ // Pick BroadcastNestedLoopJoin if one side could be broadcasted + case j @ logical.Join(left, right, joinType, condition) + if (canBuildRight(joinType) && right.stats.hints.broadcast) || + (canBuildLeft(joinType) && left.stats.hints.broadcast) => + val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case j @ logical.Join(left, right, joinType, condition) if canBuildRight(joinType) && canBroadcast(right) => joins.BroadcastNestedLoopJoinExec( @@ -203,12 +243,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) { - BuildRight - } else { - BuildLeft - } + val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right) // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a0fad862b44c7..67e2cdc7394bb 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -223,4 +223,71 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil) assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil) } + + test("Shouldn't change broadcast join buildSide if user clearly specified") { + def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { + val executedPlan = sql(sqlStr).queryExecution.executedPlan + executedPlan match { + case b: BroadcastNestedLoopJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case w: WholeStageCodegenExec => + assert(w.children.head.getClass.getSimpleName === joinMethod) + assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) + } + } + + withTempView("t1", "t2") { + spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") + spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") + .createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + val bh = BroadcastHashJoinExec.toString + val bl = BroadcastNestedLoopJoinExec.toString + + // INNER JOIN && t1Size < t2Size => BuildLeft + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + // LEFT JOIN => BuildRight + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight) + // RIGHT JOIN => BuildLeft + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + // INNER JOIN && broadcast(t1) => BuildLeft + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t1) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + // INNER JOIN && broadcast(t2) => BuildRight + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight) + + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // INNER JOIN && t1Size < t2Size => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft) + // FULL JOIN && t1Size < t2Size => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 FULL JOIN t2", bl, BuildLeft) + // LEFT JOIN => BuildRight + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildRight) + // RIGHT JOIN => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2", bl, BuildLeft) + // INNER JOIN && broadcast(t1) => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 
JOIN t2", bl, BuildLeft) + // INNER JOIN && broadcast(t2) => BuildRight + assertJoinBuildSide("SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2", bl, BuildRight) + // FULL OUTER && broadcast(t1) => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) + // FULL OUTER && broadcast(t2) => BuildRight + assertJoinBuildSide( + "SELECT /*+ MAPJOIN(t2) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildRight) + // FULL OUTER && t1Size < t2Size => BuildLeft + assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) + } + } + } } From f5f8e84d9d35751dad51490b6ae22931aa88db7b Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Thu, 30 Nov 2017 15:41:34 -0800 Subject: [PATCH 326/712] [SPARK-22614] Dataset API: repartitionByRange(...) ## What changes were proposed in this pull request? This PR introduces a way to explicitly range-partition a Dataset. So far, only round-robin and hash partitioning were possible via `df.repartition(...)`, but sometimes range partitioning might be desirable: e.g. when writing to disk, for better compression without the cost of global sort. The current implementation piggybacks on the existing `RepartitionByExpression` `LogicalPlan` and simply adds the following logic: If its expressions are of type `SortOrder`, then it will do `RangePartitioning`; otherwise `HashPartitioning`. This was by far the least intrusive solution I could come up with. ## How was this patch tested? Unit test for `RepartitionByExpression` changes, a test to ensure we're not changing the behavior of existing `.repartition()` and a few end-to-end tests in `DataFrameSuite`. Author: Adrian Ionescu Closes #19828 from adrian-ionescu/repartitionByRange. --- .../plans/logical/basicLogicalOperators.scala | 20 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 26 +++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 57 +++++++++++++++++-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../org/apache/spark/sql/DataFrameSuite.scala | 57 +++++++++++++++++++ 5 files changed, 157 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c2750c3079814..93de7c1daf5c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -838,6 +839,25 @@ case class RepartitionByExpression( require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") + val partitioning: Partitioning = { + val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder]) + + require(sortOrder.isEmpty || nonSortOrder.isEmpty, + s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " + + "`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, 
which " + + "means `HashPartitioning`. In this case we have:" + + s""" + |SortOrder: ${sortOrder} + |NonSortOrder: ${nonSortOrder} + """.stripMargin) + + if (sortOrder.nonEmpty) { + RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions) + } else { + HashPartitioning(nonSortOrder, numPartitions) + } + } + override def maxRows: Option[Long] = child.maxRows override def shuffle: Boolean = true } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e56a5d6368318..0e2e706a31a05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} import org.apache.spark.sql.types._ @@ -514,4 +515,29 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq("Number of column aliases does not match number of columns. " + "Number of column aliases: 5; number of columns: 4.")) } + + test("SPARK-22614 RepartitionByExpression partitioning") { + def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = { + val partitioning = RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning + assert(partitioning.isInstanceOf[T]) + } + + checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20)) + checkPartitioning[HashPartitioning](numPartitions = 10, exprs = 'a.attr, 'b.attr) + + checkPartitioning[RangePartitioning](numPartitions = 10, + exprs = SortOrder(Literal(10), Ascending)) + checkPartitioning[RangePartitioning](numPartitions = 10, + exprs = SortOrder('a.attr, Ascending), SortOrder('b.attr, Descending)) + + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = 0, exprs = Literal(20)) + } + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = -1, exprs = Literal(20)) + } + intercept[IllegalArgumentException] { + checkPartitioning(numPartitions = 10, exprs = SortOrder('a.attr, Ascending), 'b.attr) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1620ab3aa2094..167c9d050c3c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2732,8 +2732,18 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. + // However, we don't want to complicate the semantics of this API method. + // Instead, let's give users a friendly error message, pointing them to the new method. 
+ val sortOrders = partitionExprs.filter(_.expr.isInstanceOf[SortOrder]) + if (sortOrders.nonEmpty) throw new IllegalArgumentException( + s"""Invalid partitionExprs specified: $sortOrders + |For range partitioning use repartitionByRange(...) instead. + """.stripMargin) + withTypedPlan { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + } } /** @@ -2747,9 +2757,46 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression( - partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) + def repartition(partitionExprs: Column*): Dataset[T] = { + repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) + } + + /** + * Returns a new Dataset partitioned by the given partitioning expressions into + * `numPartitions`. The resulting Dataset is range partitioned. + * + * At least one partition-by expression must be specified. + * When no explicit sort order is specified, "ascending nulls first" is assumed. + * + * @group typedrel + * @since 2.3.0 + */ + @scala.annotation.varargs + def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") + val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { + case expr: SortOrder => expr + case expr: Expression => SortOrder(expr, Ascending) + }) + withTypedPlan { + RepartitionByExpression(sortOrder, logicalPlan, numPartitions) + } + } + + /** + * Returns a new Dataset partitioned by the given partitioning expressions, using + * `spark.sql.shuffle.partitions` as number of partitions. + * The resulting Dataset is range partitioned. + * + * At least one partition-by expression must be specified. + * When no explicit sort order is specified, "ascending nulls first" is assumed. 
+ * + * @group typedrel + * @since 2.3.0 + */ + @scala.annotation.varargs + def repartitionByRange(partitionExprs: Column*): Dataset[T] = { + repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1fe3cb1c8750a..9e713cd7bbe2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -482,9 +482,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => execution.RangeExec(r) :: Nil - case logical.RepartitionByExpression(expressions, child, numPartitions) => - exchange.ShuffleExchangeExec(HashPartitioning( - expressions, numPartitions), planLater(child)) :: Nil + case r: logical.RepartitionByExpression => + exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 72a5cc98fbec3..5e4c1a6a484fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -358,6 +358,63 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select('key).collect().toSeq) } + test("repartition with SortOrder") { + // passing SortOrder expressions to .repartition() should result in an informative error + + def checkSortOrderErrorMsg[T](data: => Dataset[T]): Unit = { + val ex = intercept[IllegalArgumentException](data) + assert(ex.getMessage.contains("repartitionByRange")) + } + + checkSortOrderErrorMsg { + Seq(0).toDF("a").repartition(2, $"a".asc) + } + + checkSortOrderErrorMsg { + Seq((0, 0)).toDF("a", "b").repartition(2, $"a".asc, $"b") + } + } + + test("repartitionByRange") { + val data1d = Random.shuffle(0.to(9)) + val data2d = data1d.map(i => (i, data1d.size - i)) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, $"val".asc) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, $"val".desc) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, data1d.size - 1 - i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, lit(42)) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(0, i))) + + checkAnswer( + data1d.toDF("val").repartitionByRange(data1d.size, lit(null), $"val".asc, rand()) + .select(spark_partition_id().as("id"), $"val"), + data1d.map(i => Row(i, i))) + + // .repartitionByRange() assumes .asc by default if no explicit sort order is specified + checkAnswer( + data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b") + .select(spark_partition_id().as("id"), $"a", $"b"), + data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b".asc) + .select(spark_partition_id().as("id"), $"a", $"b")) + + // at least one partition-by expression must be specified + intercept[IllegalArgumentException] { + 
data1d.toDF("val").repartitionByRange(data1d.size) + } + intercept[IllegalArgumentException] { + data1d.toDF("val").repartitionByRange(data1d.size, Seq.empty: _*) + } + } + test("coalesce") { intercept[IllegalArgumentException] { testData.select('key).coalesce(0) From 7e5f669eb684629c88218f8ec26c01a41a6fef32 Mon Sep 17 00:00:00 2001 From: gaborgsomogyi Date: Thu, 30 Nov 2017 19:20:32 -0600 Subject: [PATCH 327/712] =?UTF-8?q?[SPARK-22428][DOC]=20Add=20spark=20appl?= =?UTF-8?q?ication=20garbage=20collector=20configurat=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The spark properties for configuring the ContextCleaner are not documented in the official documentation at https://spark.apache.org/docs/latest/configuration.html#available-properties. This PR adds the doc. ## How was this patch tested? Manual. ``` cd docs jekyll build open _site/configuration.html ``` Author: gaborgsomogyi Closes #19826 from gaborgsomogyi/SPARK-22428. --- docs/configuration.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index e42f866c40566..ef061dd39dcba 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1132,6 +1132,46 @@ Apart from these, the following properties are also available, and may be useful to get the replication level of the block to the initial number. + + spark.cleaner.periodicGC.interval + 30min + + Controls how often to trigger a garbage collection.

    + This context cleaner triggers cleanups only when weak references are garbage collected. + In long-running applications with large driver JVMs, where there is little memory pressure + on the driver, this may happen very occasionally or not at all. Not cleaning at all may + lead to executors running out of disk space after a while. + + + + spark.cleaner.referenceTracking + true + + Enables or disables context cleaning. + + + + spark.cleaner.referenceTracking.blocking + true + + Controls whether the cleaning thread should block on cleanup tasks (other than shuffle, which is controlled by + spark.cleaner.referenceTracking.blocking.shuffle Spark property). + + + + spark.cleaner.referenceTracking.blocking.shuffle + false + + Controls whether the cleaning thread should block on shuffle cleanup tasks. + + + + spark.cleaner.referenceTracking.cleanCheckpoints + false + + Controls whether to clean checkpoint files if the reference is out of scope. + + ### Execution Behavior From 7da1f5708cc96c18ddb3acd09542621275e71d83 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Thu, 30 Nov 2017 19:24:44 -0600 Subject: [PATCH 328/712] =?UTF-8?q?[SPARK-22373]=20Bump=20Janino=20depende?= =?UTF-8?q?ncy=20version=20to=20fix=20thread=20safety=20issue=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … with Janino when compiling generated code. ## What changes were proposed in this pull request? Bump up Janino dependency version to fix thread safety issue during compiling generated code ## How was this patch tested? Check https://issues.apache.org/jira/browse/SPARK-22373 for details. Converted part of the code in CodeGenerator into a standalone application, so the issue can be consistently reproduced locally. Verified that changing Janino dependency version resolved this issue. Author: Min Shen Closes #19839 from Victsm/SPARK-22373. 
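For illustration only, here is a rough, hypothetical sketch of the kind of standalone reproducer described above (not the code attached to the JIRA): several threads concurrently compiling small class bodies with Janino. Whether it actually trips the race depends on Janino internals; with 3.0.7 it is expected to complete cleanly.

```scala
import org.codehaus.janino.ClassBodyEvaluator

// Hypothetical sketch: compile trivial class bodies from several threads at once.
// With janino 3.0.0, concurrent compilation of generated code could hit the
// thread-safety issue; with 3.0.7 this is expected to finish without errors.
object ConcurrentCompileCheck {
  def main(args: Array[String]): Unit = {
    val threads = (1 to 8).map { i =>
      new Thread(new Runnable {
        override def run(): Unit = {
          val evaluator = new ClassBodyEvaluator()
          evaluator.setClassName(s"GeneratedClass$i")
          evaluator.cook(s"public int compute() { return $i * 2; }")
          assert(evaluator.getClazz() != null)
        }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```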
--- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 50ac6d139bbd4..8f508219c2ded 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.0.jar +commons-compiler-3.0.7.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.0.jar +janino-3.0.7.jar java-xmlbuilder-1.0.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1b1e3166d53db..68e937f50b391 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.0.jar +commons-compiler-3.0.7.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.0.jar +janino-3.0.7.jar java-xmlbuilder-1.0.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar diff --git a/pom.xml b/pom.xml index 7bc66e7d19540..731ee86439eff 100644 --- a/pom.xml +++ b/pom.xml @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.0 + 3.0.7 2.22.2 2.9.3 3.5.2 From dc365422bb337d19ef39739c7c3cf9e53ec85d09 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 1 Dec 2017 10:53:16 +0800 Subject: [PATCH 329/712] =?UTF-8?q?[SPARK-22653]=20executorAddress=20regis?= =?UTF-8?q?tered=20in=20CoarseGrainedSchedulerBac=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://issues.apache.org/jira/browse/SPARK-22653 executorRef.address can be null, pass the executorAddress which accounts for it being null a few lines above the fix. Manually tested this patch. You can reproduce the issue by running a simple spark-shell in yarn client mode with dynamic allocation and request some executors up front. Let those executors idle timeout. Get a heap dump. Without this fix, you will see that addressToExecutorId still contains the ids, with the fix addressToExecutorId is properly cleaned up. Author: Thomas Graves Closes #19850 from tgravescs/SPARK-22653. 
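As a side note, here is a minimal, hypothetical sketch of the fallback pattern the description refers to (simplified stand-in types, not the actual CoarseGrainedSchedulerBackend code):

```scala
// Simplified sketch: an executor endpoint may report a null RPC address on some transports,
// so registration falls back to the address the registration message was sent from. Storing
// this resolved, never-null value (rather than executorRef.address) in ExecutorData is what
// lets addressToExecutorId be cleaned up correctly when the executor is removed.
case class RpcAddr(host: String, port: Int) // stand-in for Spark's internal RpcAddress

def resolveExecutorAddress(endpointAddress: RpcAddr, senderAddress: RpcAddr): RpcAddr =
  if (endpointAddress != null) endpointAddress else senderAddress
```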
--- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 22d9c4cf81c55..7bfb4d53c1834 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -182,7 +182,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val data = new ExecutorData(executorRef, executorRef.address, hostname, + val data = new ExecutorData(executorRef, executorAddress, hostname, cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors From 16adaf634bcca3074b448d95e72177eefdf50069 Mon Sep 17 00:00:00 2001 From: sujith71955 Date: Thu, 30 Nov 2017 20:45:30 -0800 Subject: [PATCH 330/712] [SPARK-22601][SQL] Data load is getting displayed successful on providing non existing nonlocal file path ## What changes were proposed in this pull request? When a user tries to load data from a non-existing HDFS file path, the system does not validate it and the load command completes successfully, which is misleading to the user. A validation already exists for a non-existing local file path; this PR adds the same validation for a non-existing HDFS file path. ## How was this patch tested? A unit test has been added to verify the issue; snapshots were also taken after verifying the fix on a Spark YARN cluster. Author: sujith71955 Closes #19823 from sujith71955/master_LoadComand_Issue.
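For reference, the new behavior can be seen with a snippet along these lines in spark-shell (assuming an active `spark` session; it mirrors the test added in HiveDDLSuite below):

```scala
// After this change, loading from a non-existing non-local path fails during analysis
// with "LOAD DATA input path does not exist" instead of reporting success.
spark.sql("CREATE TABLE tbl(i INT, j STRING)")
spark.sql("LOAD DATA INPATH '/doesnotexist.csv' INTO TABLE tbl") // now throws AnalysisException
```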
--- .../org/apache/spark/sql/execution/command/tables.scala | 9 ++++++++- .../apache/spark/sql/hive/execution/HiveDDLSuite.scala | 9 +++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index c9f6e571ddab3..c42e6c3257fad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -340,7 +340,7 @@ case class LoadDataCommand( uri } else { val uri = new URI(path) - if (uri.getScheme() != null && uri.getAuthority() != null) { + val hdfsUri = if (uri.getScheme() != null && uri.getAuthority() != null) { uri } else { // Follow Hive's behavior: @@ -380,6 +380,13 @@ case class LoadDataCommand( } new URI(scheme, authority, absolutePath, uri.getQuery(), uri.getFragment()) } + val hadoopConf = sparkSession.sessionState.newHadoopConf() + val srcPath = new Path(hdfsUri) + val fs = srcPath.getFileSystem(hadoopConf) + if (!fs.exists(srcPath)) { + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } + hdfsUri } if (partition.nonEmpty) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 9063ef066aa84..6c11905ba8904 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2141,4 +2141,13 @@ class HiveDDLSuite } } } + + test("load command for non local invalid path validation") { + withTable("tbl") { + sql("CREATE TABLE tbl(i INT, j STRING)") + val e = intercept[AnalysisException]( + sql("load data inpath '/doesnotexist.csv' into table tbl")) + assert(e.message.contains("LOAD DATA input path does not exist")) + } + } } From 9d06a9e0cf05af99ba210fabae1e77eccfce7986 Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Fri, 1 Dec 2017 05:14:12 -0600 Subject: [PATCH 331/712] [SPARK-22393][SPARK-SHELL] spark-shell can't find imported types in class constructors, extends clause ## What changes were proposed in this pull request? [SPARK-22393](https://issues.apache.org/jira/browse/SPARK-22393) ## How was this patch tested? With a new test case in `RepSuite` ---- This code is a retrofit of the Scala [SI-9881](https://github.com/scala/bug/issues/9881) bug fix, which never made it into the Scala 2.11 branches. Pushing these changes directly to the Scala repo is not practical (see: https://github.com/scala/scala/pull/6195). Author: Mark Petruska Closes #19846 from mpetruska/SPARK-22393. 
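For reference, the failure mode fixed here shows up with a snippet like the following typed into spark-shell (it mirrors the new ReplSuite case below):

```scala
// Before this fix, the imported type was not visible in class constructors or in the
// extends clause, failing with "error: not found: type Partition".
import org.apache.spark.Partition

class P(p: Partition)
class Q(val index: Int) extends Partition
```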
--- .../apache/spark/repl/SparkExprTyper.scala | 74 +++++++++++++ .../org/apache/spark/repl/SparkILoop.scala | 4 + .../spark/repl/SparkILoopInterpreter.scala | 103 ++++++++++++++++++ .../org/apache/spark/repl/ReplSuite.scala | 10 ++ 4 files changed, 191 insertions(+) create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala create mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 0000000000000..724ce9af49f77 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import scala.tools.nsc.interpreter.{ExprTyper, IR} + +trait SparkExprTyper extends ExprTyper { + + import repl._ + import global.{reporter => _, Import => _, _} + import naming.freshInternalVarName + + def doInterpret(code: String): IR.Result = { + // interpret/interpretSynthetic may change the phase, + // which would have unintended effects on types. + val savedPhase = phase + try interpretSynthetic(code) finally phase = savedPhase + } + + override def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. 
+ val line = "def " + name + " = " + code + + doInterpret(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) + case _ => NoSymbol + } + } + + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + doInterpret(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + + def asError(): Symbol = { + doInterpret(code) + NoSymbol + } + + beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() + } + +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 3ce7cc7c85f74..e69441a475e9a 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -35,6 +35,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + override def createInterpreter(): Unit = { + intp = new SparkILoopInterpreter(settings, out) + } + val initializationCommands: Seq[String] = Seq( """ @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala new file mode 100644 index 0000000000000..0803426403af5 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.repl + +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter._ + +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { + self => + + override lazy val memberHandlers = new { + val intp: self.type = self + } with MemberHandlers { + import intp.global._ + + override def chooseHandler(member: intp.global.Tree): MemberHandler = member match { + case member: Import => new SparkImportHandler(member) + case _ => super.chooseHandler (member) + } + + class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) { + + override def targetType: Type = intp.global.rootMirror.getModuleIfDefined("" + expr) match { + case NoSymbol => intp.typeOfExpression("" + expr) + case sym => sym.tpe + } + + private def safeIndexOf(name: Name, s: String): Int = fixIndexOf(name, pos(name, s)) + private def fixIndexOf(name: Name, idx: Int): Int = if (idx == name.length) -1 else idx + private def pos(name: Name, s: String): Int = { + var i = name.pos(s.charAt(0), 0) + val sLen = s.length() + if (sLen == 1) return i + while (i + sLen <= name.length) { + var j = 1 + while (s.charAt(j) == name.charAt(i + j)) { + j += 1 + if (j == sLen) return i + } + i = name.pos(s.charAt(0), i + 1) + } + name.length + } + + private def isFlattenedSymbol(sym: Symbol): Boolean = + sym.owner.isPackageClass && + sym.name.containsName(nme.NAME_JOIN_STRING) && + sym.owner.info.member(sym.name.take( + safeIndexOf(sym.name, nme.NAME_JOIN_STRING))) != NoSymbol + + private def importableTargetMembers = + importableMembers(exitingTyper(targetType)).filterNot(isFlattenedSymbol).toList + + def isIndividualImport(s: ImportSelector): Boolean = + s.name != nme.WILDCARD && s.rename != nme.WILDCARD + def isWildcardImport(s: ImportSelector): Boolean = + s.name == nme.WILDCARD + + // non-wildcard imports + private def individualSelectors = selectors filter isIndividualImport + + override val importsWildcard: Boolean = selectors exists isWildcardImport + + lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = { + val selectorRenameMap = + individualSelectors.flatMap(x => x.name.bothNames zip x.rename.bothNames).toMap + importableTargetMembers flatMap (m => selectorRenameMap.get(m.name) map (m -> _)) + } + + override lazy val individualSymbols: List[Symbol] = importableSymbolsWithRenames map (_._1) + override lazy val wildcardSymbols: List[Symbol] = + if (importsWildcard) importableTargetMembers else Nil + + } + + } + + object expressionTyper extends { + val repl: SparkILoopInterpreter.this.type = self + } with SparkExprTyper { } + + override def symbolOfLine(code: String): global.Symbol = + expressionTyper.symbolOfLine(code) + + override def typeOfExpression(expr: String, silent: Boolean): global.Type = + expressionTyper.typeOfExpression(expr, silent) + +} diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 905b41cdc1594..a5053521f8e31 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -227,4 +227,14 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error: not found: value sc", output) } + test("spark-shell should find imported types in class constructors and extends clause") { + val output = runInterpreter("local", + """ + |import org.apache.spark.Partition + |class P(p: Partition) + |class P(val index: Int) extends Partition + """.stripMargin) + 
assertDoesNotContain("error: not found: type Partition", output) + } + } From ee10ca7ec6cf7fbaab3f95a097b46936d97d0835 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 1 Dec 2017 13:02:03 -0800 Subject: [PATCH 332/712] [SPARK-22638][SS] Use a separate queue for StreamingQueryListenerBus ## What changes were proposed in this pull request? Use a separate Spark event queue for StreamingQueryListenerBus so that if there are many non-streaming events, streaming query listeners don't need to wait for other Spark listeners and can catch up. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #19838 from zsxwing/SPARK-22638. --- .../scala/org/apache/spark/scheduler/LiveListenerBus.scala | 4 +++- .../sql/execution/streaming/StreamingQueryListenerBus.scala | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 2f93c497c5771..23121402b1025 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -87,7 +87,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { * of each other (each one uses a separate thread for delivering events), allowing slower * listeners to be somewhat isolated from others. */ - private def addToQueue(listener: SparkListenerInterface, queue: String): Unit = synchronized { + private[spark] def addToQueue( + listener: SparkListenerInterface, + queue: String): Unit = synchronized { if (stopped.get()) { throw new IllegalStateException("LiveListenerBus is stopped.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index 07e39023c8366..7dd491ede9d05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -40,7 +40,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) import StreamingQueryListener._ - sparkListenerBus.addToSharedQueue(this) + sparkListenerBus.addToQueue(this, StreamingQueryListenerBus.STREAM_EVENT_QUERY) /** * RunIds of active queries whose events are supposed to be forwarded by this ListenerBus @@ -130,3 +130,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) } } } + +object StreamingQueryListenerBus { + val STREAM_EVENT_QUERY = "streams" +} From aa4cf2b19e4cf5588af7e2192e0e9f687cd84bc5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 2 Dec 2017 11:55:43 +0900 Subject: [PATCH 333/712] [SPARK-22651][PYTHON][ML] Prevent initiating multiple Hive clients for ImageSchema.readImages ## What changes were proposed in this pull request? Calling `ImageSchema.readImages` multiple times as below in PySpark shell: ```python from pyspark.ml.image import ImageSchema data_path = 'data/mllib/images/kittens' _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() ``` throws an error as below: ``` ... org.datanucleus.exceptions.NucleusDataStoreException: Unable to open a test connection to the given database. JDBC url = jdbc:derby:;databaseName=metastore_db;create=true, username = APP. 
Terminating connection pool (set lazyInit to true if you expect to start your database after your app). Original Exception: ------ java.sql.SQLException: Failed to start database 'metastore_db' with class loader org.apache.spark.sql.hive.client.IsolatedClientLoader$$anon$1742f639f, see the next exception for details. ... at org.apache.derby.jdbc.AutoloadedDriver.connect(Unknown Source) ... at org.apache.hadoop.hive.metastore.HiveMetaStore.newRetryingHMSHandler(HiveMetaStore.java:5762) ... at org.apache.spark.sql.hive.client.HiveClientImpl.newState(HiveClientImpl.scala:180) ... at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$databaseExists$1.apply$mcZ$sp(HiveExternalCatalog.scala:195) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$databaseExists$1.apply(HiveExternalCatalog.scala:195) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$databaseExists$1.apply(HiveExternalCatalog.scala:195) at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:97) at org.apache.spark.sql.hive.HiveExternalCatalog.databaseExists(HiveExternalCatalog.scala:194) at org.apache.spark.sql.internal.SharedState.externalCatalog$lzycompute(SharedState.scala:100) at org.apache.spark.sql.internal.SharedState.externalCatalog(SharedState.scala:88) at org.apache.spark.sql.hive.HiveSessionStateBuilder.externalCatalog(HiveSessionStateBuilder.scala:39) at org.apache.spark.sql.hive.HiveSessionStateBuilder.catalog$lzycompute(HiveSessionStateBuilder.scala:54) at org.apache.spark.sql.hive.HiveSessionStateBuilder.catalog(HiveSessionStateBuilder.scala:52) at org.apache.spark.sql.hive.HiveSessionStateBuilder$$anon$1.(HiveSessionStateBuilder.scala:69) at org.apache.spark.sql.hive.HiveSessionStateBuilder.analyzer(HiveSessionStateBuilder.scala:69) at org.apache.spark.sql.internal.BaseSessionStateBuilder$$anonfun$build$2.apply(BaseSessionStateBuilder.scala:293) at org.apache.spark.sql.internal.BaseSessionStateBuilder$$anonfun$build$2.apply(BaseSessionStateBuilder.scala:293) at org.apache.spark.sql.internal.SessionState.analyzer$lzycompute(SessionState.scala:79) at org.apache.spark.sql.internal.SessionState.analyzer(SessionState.scala:79) at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:70) at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:68) at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:51) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:70) at org.apache.spark.sql.SparkSession.internalCreateDataFrame(SparkSession.scala:574) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:593) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:348) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:348) at org.apache.spark.ml.image.ImageSchema$$anonfun$readImages$2$$anonfun$apply$1.apply(ImageSchema.scala:253) ... Caused by: ERROR XJ040: Failed to start database 'metastore_db' with class loader org.apache.spark.sql.hive.client.IsolatedClientLoader$$anon$1742f639f, see the next exception for details. at org.apache.derby.iapi.error.StandardException.newException(Unknown Source) at org.apache.derby.impl.jdbc.SQLExceptionFactory.wrapArgsForTransportAcrossDRDA(Unknown Source) ... 121 more Caused by: ERROR XSDB6: Another instance of Derby may have already booted the database /.../spark/metastore_db. ... 
Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/ml/image.py", line 190, in readImages dropImageFailures, float(sampleRatio), seed) File "/.../spark/python/lib/py4j-0.10.6-src.zip/py4j/java_gateway.py", line 1160, in __call__ File "/.../spark/python/pyspark/sql/utils.py", line 69, in deco raise AnalysisException(s.split(': ', 1)[1], stackTrace) pyspark.sql.utils.AnalysisException: u'java.lang.RuntimeException: java.lang.RuntimeException: Unable to instantiate org.apache.hadoop.hive.ql.metadata.SessionHiveMetaStoreClient;' ``` Seems we better stick to `SparkSession.builder.getOrCreate()` like: https://github.com/apache/spark/blob/51620e288b5e0a7fffc3899c9deadabace28e6d7/python/pyspark/sql/streaming.py#L329 https://github.com/apache/spark/blob/dc5d34d8dcd6526d1dfdac8606661561c7576a62/python/pyspark/sql/column.py#L541 https://github.com/apache/spark/blob/33d43bf1b6f55594187066f0e38ba3985fa2542b/python/pyspark/sql/readwriter.py#L105 ## How was this patch tested? This was tested as below in PySpark shell: ```python from pyspark.ml.image import ImageSchema data_path = 'data/mllib/images/kittens' _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() _ = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True).collect() ``` Author: hyukjinkwon Closes #19845 from HyukjinKwon/SPARK-22651. --- python/pyspark/ml/image.py | 5 ++--- python/pyspark/ml/tests.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 2b61aa9c0d9e9..384599dc0c532 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -201,9 +201,8 @@ def readImages(self, path, recursive=False, numPartitions=-1, .. versionadded:: 2.3.0 """ - ctx = SparkContext._active_spark_context - spark = SparkSession(ctx) - image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema + spark = SparkSession.builder.getOrCreate() + image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema jsession = spark._jsparkSession jresult = image_schema.readImages(path, jsession, recursive, numPartitions, dropImageFailures, float(sampleRatio), seed) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 89ef555cf3442..3a0b816c367ec 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -67,7 +67,7 @@ from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams, JavaWrapper from pyspark.serializers import PickleSerializer -from pyspark.sql import DataFrame, Row, SparkSession +from pyspark.sql import DataFrame, Row, SparkSession, HiveContext from pyspark.sql.functions import rand from pyspark.sql.types import DoubleType, IntegerType from pyspark.storagelevel import * @@ -1855,6 +1855,35 @@ def test_read_images(self): lambda: ImageSchema.toImage("a")) +class ImageReaderTest2(PySparkTestCase): + + @classmethod + def setUpClass(cls): + PySparkTestCase.setUpClass() + # Note that here we enable Hive's support. 
+ try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") + except TypeError: + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") + cls.spark = HiveContext._createForTesting(cls.sc) + + @classmethod + def tearDownClass(cls): + PySparkTestCase.tearDownClass() + cls.spark.sparkSession.stop() + + def test_read_images_multiple_times(self): + # This test case is to check if `ImageSchema.readImages` tries to + # initiate Hive client multiple times. See SPARK-22651. + data_path = 'data/mllib/images/kittens' + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): From d2cf95aa63f5f5c9423f0455c2bfbee7833c9982 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Dec 2017 07:37:02 -0600 Subject: [PATCH 334/712] [SPARK-22634][BUILD] Update Bouncy Castle to 1.58 ## What changes were proposed in this pull request? Update Bouncy Castle to 1.58, and jets3t to 0.9.4 to (sort of) match. ## How was this patch tested? Existing tests Author: Sean Owen Closes #19859 from srowen/SPARK-22634. --- dev/deps/spark-deps-hadoop-2.6 | 8 +++----- dev/deps/spark-deps-hadoop-2.7 | 8 +++----- pom.xml | 8 +++++++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 8f508219c2ded..2c68b73095c4d 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -21,7 +21,7 @@ avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar -bcprov-jdk15on-1.51.jar +bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -97,7 +97,7 @@ jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.7.jar -java-xmlbuilder-1.0.jar +java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -115,7 +115,7 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.3.jar +jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar @@ -137,14 +137,12 @@ log4j-1.2.17.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar -mail-1.4.7.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar -mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 68e937f50b391..2aaac600b3ec3 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -21,7 +21,7 @@ avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar -bcprov-jdk15on-1.51.jar +bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -97,7 +97,7 @@ jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.7.jar -java-xmlbuilder-1.0.jar +java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -115,7 +115,7 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.3.jar +jets3t-0.9.4.jar 
jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar @@ -138,14 +138,12 @@ log4j-1.2.17.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar -mail-1.4.7.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar -mx4j-3.0.2.jar netty-3.9.9.Final.jar netty-all-4.0.47.Final.jar objenesis-2.1.jar diff --git a/pom.xml b/pom.xml index 731ee86439eff..07bca9d267da0 100644 --- a/pom.xml +++ b/pom.xml @@ -141,7 +141,7 @@ 3.1.5 1.7.7 hadoop2 - 0.9.3 + 0.9.4 1.7.3 1.11.76 @@ -985,6 +985,12 @@ + + org.bouncycastle + bcprov-jdk15on + + 1.58 + org.apache.hadoop hadoop-yarn-api From f23dddf105aef88531b3572ad70889cf2fc300c9 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 3 Dec 2017 22:21:44 +0800 Subject: [PATCH 335/712] [SPARK-20682][SPARK-15474][SPARK-21791] Add new ORCFileFormat based on ORC 1.4.1 ## What changes were proposed in this pull request? Since [SPARK-2883](https://issues.apache.org/jira/browse/SPARK-2883), Apache Spark supports Apache ORC inside `sql/hive` module with Hive dependency. This PR aims to add a new ORC data source inside `sql/core` and to replace the old ORC data source eventually. This PR resolves the following three issues. - [SPARK-20682](https://issues.apache.org/jira/browse/SPARK-20682): Add new ORCFileFormat based on Apache ORC 1.4.1 - [SPARK-15474](https://issues.apache.org/jira/browse/SPARK-15474): ORC data source fails to write and read back empty dataframe - [SPARK-21791](https://issues.apache.org/jira/browse/SPARK-21791): ORC should support column names with dot ## How was this patch tested? Pass the Jenkins with the existing all tests and new tests for SPARK-15474 and SPARK-21791. Author: Dongjoon Hyun Author: Wenchen Fan Closes #19651 from dongjoon-hyun/SPARK-20682. --- .../datasources/orc/OrcDeserializer.scala | 243 ++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 139 +++++++++- .../datasources/orc/OrcFilters.scala | 210 +++++++++++++++ .../datasources/orc/OrcOutputWriter.scala | 53 ++++ .../datasources/orc/OrcSerializer.scala | 228 ++++++++++++++++ .../execution/datasources/orc/OrcUtils.scala | 113 ++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 25 ++ 7 files changed, 1009 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala new file mode 100644 index 0000000000000..4ecc54bd2fd96 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.io._ +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize ORC structs to Spark rows. + */ +class OrcDeserializer( + dataSchema: StructType, + requiredSchema: StructType, + requestedColIds: Array[Int]) { + + private val resultRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + + private val fieldWriters: Array[WritableComparable[_] => Unit] = { + requiredSchema.zipWithIndex + // The value of missing columns are always null, do not need writers. + .filterNot { case (_, index) => requestedColIds(index) == -1 } + .map { case (f, index) => + val writer = newWriter(f.dataType, new RowUpdater(resultRow)) + (value: WritableComparable[_]) => writer(index, value) + }.toArray + } + + private val validColIds = requestedColIds.filterNot(_ == -1) + + def deserialize(orcStruct: OrcStruct): InternalRow = { + var i = 0 + while (i < validColIds.length) { + val value = orcStruct.getFieldValue(validColIds(i)) + if (value == null) { + resultRow.setNullAt(i) + } else { + fieldWriters(i)(value) + } + i += 1 + } + resultRow + } + + /** + * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. 
+ */ + private def newWriter( + dataType: DataType, updater: CatalystDataUpdater): (Int, WritableComparable[_]) => Unit = + dataType match { + case NullType => (ordinal, _) => + updater.setNullAt(ordinal) + + case BooleanType => (ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) + + case ByteType => (ordinal, value) => + updater.setByte(ordinal, value.asInstanceOf[ByteWritable].get) + + case ShortType => (ordinal, value) => + updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get) + + case IntegerType => (ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[IntWritable].get) + + case LongType => (ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[LongWritable].get) + + case FloatType => (ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) + + case DoubleType => (ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) + + case StringType => (ordinal, value) => + updater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes)) + + case BinaryType => (ordinal, value) => + val binary = value.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + updater.set(ordinal, bytes) + + case DateType => (ordinal, value) => + updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) + + case TimestampType => (ordinal, value) => + updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) + + case DecimalType.Fixed(precision, scale) => (ordinal, value) => + val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + updater.set(ordinal, v) + + case st: StructType => (ordinal, value) => + val result = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(result) + val fieldConverters = st.map(_.dataType).map { dt => + newWriter(dt, fieldUpdater) + }.toArray + val orcStruct = value.asInstanceOf[OrcStruct] + + var i = 0 + while (i < st.length) { + val value = orcStruct.getFieldValue(i) + if (value == null) { + result.setNullAt(i) + } else { + fieldConverters(i)(i, value) + } + i += 1 + } + + updater.set(ordinal, result) + + case ArrayType(elementType, _) => (ordinal, value) => + val orcArray = value.asInstanceOf[OrcList[WritableComparable[_]]] + val length = orcArray.size() + val result = createArrayData(elementType, length) + val elementUpdater = new ArrayDataUpdater(result) + val elementConverter = newWriter(elementType, elementUpdater) + + var i = 0 + while (i < length) { + val value = orcArray.get(i) + if (value == null) { + result.setNullAt(i) + } else { + elementConverter(i, value) + } + i += 1 + } + + updater.set(ordinal, result) + + case MapType(keyType, valueType, _) => (ordinal, value) => + val orcMap = value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + val length = orcMap.size() + val keyArray = createArrayData(keyType, length) + val keyUpdater = new ArrayDataUpdater(keyArray) + val keyConverter = newWriter(keyType, keyUpdater) + val valueArray = createArrayData(valueType, length) + val valueUpdater = new ArrayDataUpdater(valueArray) + val valueConverter = newWriter(valueType, valueUpdater) + + var i = 0 + val it = orcMap.entrySet().iterator() + while (it.hasNext) { + val entry = it.next() + keyConverter(i, entry.getKey) + val 
value = entry.getValue + if (value == null) { + valueArray.setNullAt(i) + } else { + valueConverter(i, value) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater) + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = 
array.setFloat(ordinal, value) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 215740e90fe84..75c42213db3c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -17,10 +17,29 @@ package org.apache.spark.sql.execution.datasources.orc -import org.apache.orc.TypeDescription +import java.io._ +import java.net.URI +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc._ +import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA} +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce._ + +import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration private[sql] object OrcFileFormat { private def checkFieldName(name: String): Unit = { @@ -39,3 +58,119 @@ private[sql] object OrcFileFormat { names.foreach(checkFieldName) } } + +/** + * New ORC File Format based on Apache ORC. + */ +class OrcFileFormat + extends FileFormat + with DataSourceRegister + with Serializable { + + override def shortName(): String = "orc" + + override def toString: String = "ORC" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat] + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + OrcUtils.readSchema(sparkSession, files) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) + + val conf = job.getConfiguration + + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) + + conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + + conf.asInstanceOf[JobConf] + .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]]) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" + } + } + } + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + true + } + + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: 
StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { + OrcFilters.createFilter(dataSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) + } + } + + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + + if (requestedColIdsOrEmptyFile.isEmpty) { + Iterator.empty + } else { + val requestedColIds = requestedColIdsOrEmptyFile.get + assert(requestedColIds.length == requiredSchema.length, + "[BUG] requested column IDs do not match required schema") + conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + requestedColIds.filter(_ != -1).sorted.mkString(",")) + + val fileSplit = + new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val unsafeProjection = UnsafeProjection.create(requiredSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala new file mode 100644 index 0000000000000..cec256cc1b498 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory} +import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.serde2.io.HiveDecimalWritable + +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types._ + +/** + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. 
+ * + * Due to a limitation of the ORC `SearchArgument` builder, we end up with a rather awkward double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate the internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we have no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is + * found, we may already have a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert the `left` child successfully, but find that the `right` + * child is inconvertible. Alas, the `b.startAnd()` call can't be rolled back, and `b` is now + * inconsistent. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of + * the builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. + */ +private[orc] object OrcFilters { + + /** + * Creates an ORC filter as a SearchArgument instance. + */ + def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + + // First, tries to convert each filter individually to see whether it's convertible, and then + // collects all convertible ones to build the final `SearchArgument`. + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + } yield filter + + for { + // Combines all convertible filters using `And` to produce a single conjunction + conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() + } + + /** + * Returns true if this is a searchable type in ORC. + * Both CharType and VarcharType are cleaned at AstBuilder. + */ + private def isSearchableType(dataType: DataType) = dataType match { + // TODO: SPARK-21787 Support for pushing down filters for DateType in ORC + case BinaryType | DateType => false + case _: AtomicType => true + case _ => false + } + + /** + * Gets the PredicateLeaf.Type corresponding to the given DataType.
+ */ + private def getPredicateLeafType(dataType: DataType) = dataType match { + case BooleanType => PredicateLeaf.Type.BOOLEAN + case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG + case FloatType | DoubleType => PredicateLeaf.Type.FLOAT + case StringType => PredicateLeaf.Type.STRING + case DateType => PredicateLeaf.Type.DATE + case TimestampType => PredicateLeaf.Type.TIMESTAMP + case _: DecimalType => PredicateLeaf.Type.DECIMAL + case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + } + + /** + * Cast literal values for filters. + * + * We need to cast to long because ORC raises exceptions + * at 'checkLiteralType' of SearchArgumentImpl.java. + */ + private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match { + case ByteType | ShortType | IntegerType | LongType => + value.asInstanceOf[Number].longValue + case FloatType | DoubleType => + value.asInstanceOf[Number].doubleValue() + case _: DecimalType => + val decimal = value.asInstanceOf[java.math.BigDecimal] + val decimalWritable = new HiveDecimalWritable(decimal.longValue) + decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale) + decimalWritable + case _ => value + } + + /** + * Build a SearchArgument and return the builder so far. + */ + private def buildSearchArgument( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder): Option[Builder] = { + def newBuilder = SearchArgumentFactory.newBuilder() + + def getType(attribute: String): PredicateLeaf.Type = + getPredicateLeafType(dataTypeMap(attribute)) + + import org.apache.spark.sql.sources._ + + expression match { + case And(left, right) => + // At here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) + } yield rhs.end() + + case Or(left, right) => + for { + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(dataTypeMap, child, newBuilder) + negate <- buildSearchArgument(dataTypeMap, child, builder.startNot()) + } yield negate.end() + + // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` + // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be + // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). 
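+ // Note that the conversions below use only the builder's `lessThan`/`lessThanEquals` comparison + // leaves; `GreaterThan`/`GreaterThanOrEqual` are therefore expressed by negating the complementary + // comparison with `startNot()`, e.g. `a > v` becomes `NOT(a <= v)`.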
+ + case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end()) + + case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end()) + + case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end()) + + case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end()) + + case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end()) + + case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end()) + + case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + Some(builder.startAnd().isNull(attribute, getType(attribute)).end()) + + case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + Some(builder.startNot().isNull(attribute, getType(attribute)).end()) + + case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + Some(builder.startAnd().in(attribute, getType(attribute), + castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) + + case _ => None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala new file mode 100644 index 0000000000000..84755bfa301f0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.OrcOutputFormat + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.types._ + +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private[this] val serializer = new OrcSerializer(dataSchema) + + private val recordWriter = { + new OrcOutputFormat[OrcStruct]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + }.getRecordWriter(context) + } + + override def write(row: InternalRow): Unit = { + recordWriter.write(NullWritable.get(), serializer.serialize(row)) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala new file mode 100644 index 0000000000000..899af0750cadf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.io._ +import org.apache.orc.TypeDescription +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize Spark rows to ORC structs. + */ +class OrcSerializer(dataSchema: StructType) { + + private val result = createOrcValue(dataSchema).asInstanceOf[OrcStruct] + private val converters = dataSchema.map(_.dataType).map(newConverter(_)).toArray + + def serialize(row: InternalRow): OrcStruct = { + var i = 0 + while (i < converters.length) { + if (row.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, converters(i)(row, i)) + } + i += 1 + } + result + } + + private type Converter = (SpecializedGetters, Int) => WritableComparable[_] + + /** + * Creates a converter to convert Catalyst data at the given ordinal to ORC values. 
+ */ + private def newConverter( + dataType: DataType, + reuseObj: Boolean = true): Converter = dataType match { + case NullType => (getter, ordinal) => null + + case BooleanType => + if (reuseObj) { + val result = new BooleanWritable() + (getter, ordinal) => + result.set(getter.getBoolean(ordinal)) + result + } else { + (getter, ordinal) => new BooleanWritable(getter.getBoolean(ordinal)) + } + + case ByteType => + if (reuseObj) { + val result = new ByteWritable() + (getter, ordinal) => + result.set(getter.getByte(ordinal)) + result + } else { + (getter, ordinal) => new ByteWritable(getter.getByte(ordinal)) + } + + case ShortType => + if (reuseObj) { + val result = new ShortWritable() + (getter, ordinal) => + result.set(getter.getShort(ordinal)) + result + } else { + (getter, ordinal) => new ShortWritable(getter.getShort(ordinal)) + } + + case IntegerType => + if (reuseObj) { + val result = new IntWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new IntWritable(getter.getInt(ordinal)) + } + + + case LongType => + if (reuseObj) { + val result = new LongWritable() + (getter, ordinal) => + result.set(getter.getLong(ordinal)) + result + } else { + (getter, ordinal) => new LongWritable(getter.getLong(ordinal)) + } + + case FloatType => + if (reuseObj) { + val result = new FloatWritable() + (getter, ordinal) => + result.set(getter.getFloat(ordinal)) + result + } else { + (getter, ordinal) => new FloatWritable(getter.getFloat(ordinal)) + } + + case DoubleType => + if (reuseObj) { + val result = new DoubleWritable() + (getter, ordinal) => + result.set(getter.getDouble(ordinal)) + result + } else { + (getter, ordinal) => new DoubleWritable(getter.getDouble(ordinal)) + } + + + // Don't reuse the result object for string and binary as it would cause extra data copy. + case StringType => (getter, ordinal) => + new Text(getter.getUTF8String(ordinal).getBytes) + + case BinaryType => (getter, ordinal) => + new BytesWritable(getter.getBinary(ordinal)) + + case DateType => + if (reuseObj) { + val result = new DateWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new DateWritable(getter.getInt(ordinal)) + } + + // The following cases are already expensive, reusing object or not doesn't matter. + + case TimestampType => (getter, ordinal) => + val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal)) + val result = new OrcTimestamp(ts.getTime) + result.setNanos(ts.getNanos) + result + + case DecimalType.Fixed(precision, scale) => (getter, ordinal) => + val d = getter.getDecimal(ordinal, precision, scale) + new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) + + case st: StructType => (getter, ordinal) => + val result = createOrcValue(st).asInstanceOf[OrcStruct] + val fieldConverters = st.map(_.dataType).map(newConverter(_)) + val numFields = st.length + val struct = getter.getStruct(ordinal, numFields) + var i = 0 + while (i < numFields) { + if (struct.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, fieldConverters(i)(struct, i)) + } + i += 1 + } + result + + case ArrayType(elementType, _) => (getter, ordinal) => + val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. 
+ val elementConverter = newConverter(elementType, reuseObj = false) + val array = getter.getArray(ordinal) + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(array, i)) + } + i += 1 + } + result + + case MapType(keyType, valueType, _) => (getter, ordinal) => + val result = createOrcValue(dataType) + .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. + val keyConverter = newConverter(keyType, reuseObj = false) + val valueConverter = newConverter(valueType, reuseObj = false) + val map = getter.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + val key = keyConverter(keyArray, i) + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case udt: UserDefinedType[_] => newConverter(udt.sqlType) + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + /** + * Return a Orc value object for the given Spark schema. + */ + private def createOrcValue(dataType: DataType) = { + OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala new file mode 100644 index 0000000000000..b03ee06d04a16 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.orc.{OrcFile, TypeDescription} + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types._ + +object OrcUtils extends Logging { + + // The extensions for ORC compression codecs + val extensionsForCompressionCodecNames = Map( + "NONE" -> "", + "SNAPPY" -> ".snappy", + "ZLIB" -> ".zlib", + "LZO" -> ".lzo") + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDirectory) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + paths + } + + def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } + + def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) + : Option[StructType] = { + val conf = sparkSession.sessionState.newHadoopConf() + // TODO: We need to support merge schema. Please see SPARK-11412. + files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + logDebug(s"Reading schema from file $files, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] + } + } + + /** + * Returns the requested column ids from the given ORC file. Column id can be -1, which means the + * requested column doesn't exist in the ORC file. Returns None if the given ORC file is empty. + */ + def requestedColumnIds( + isCaseSensitive: Boolean, + dataSchema: StructType, + requiredSchema: StructType, + file: Path, + conf: Configuration): Option[Array[Int]] = { + val fs = file.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(file, readerOptions) + val orcFieldNames = reader.getSchema.getFieldNames.asScala + if (orcFieldNames.isEmpty) { + // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. + None + } else { + if (orcFieldNames.forall(_.startsWith("_col"))) { + // This is a ORC file written by Hive, no field names in the physical schema, assume the + // physical schema maps to the data scheme by index. 
+ assert(orcFieldNames.length <= dataSchema.length, "The given data schema " + + s"${dataSchema.simpleString} has fewer fields than the actual ORC physical schema, " + + "no idea which columns were dropped, fail to read.") + Some(requiredSchema.fieldNames.map { name => + val index = dataSchema.fieldIndex(name) + if (index < orcFieldNames.length) { + index + } else { + -1 + } + }) + } else { + val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution + Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) }) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5d0bba69daca1..31d9b909ad463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2757,4 +2757,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + // Only the new OrcFileFormat supports this + Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, + "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-empty schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF.limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } + + test("SPARK-21791 ORC should support column names with dot") { + val orc = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + withTempDir { dir => + val path = new File(dir, "orc").getCanonicalPath + Seq(Some(1), None).toDF("col.dots").write.format(orc).save(path) + assert(spark.read.format(orc).load(path).collect().length == 2) + } + } } From 2c16267f7ca392d717942a7654e90db60ba60770 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 3 Dec 2017 22:56:03 +0800 Subject: [PATCH 336/712] [SPARK-22669][SQL] Avoid unnecessary function calls in code generation ## What changes were proposed in this pull request? In many parts of the code generation codebase, we split the generated code to avoid exceptions due to the 64KB method size limit. This generates a lot of methods which are called every time, even though sometimes this is not needed. As pointed out here: https://github.com/apache/spark/pull/19752#discussion_r153081547, this is a non-negligible overhead which can be avoided. The PR applies the same approach used in #19752 to the other places where this was feasible. ## How was this patch tested? Existing UTs. Author: Marco Gaido Closes #19860 from mgaido91/SPARK-22669.
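To illustrate the pattern this PR applies, here is a minimal, self-contained sketch (hypothetical helper and method names, not the actual `CodegenContext.splitExpressions` API) of how calls to the split methods can be folded so that evaluation stops at the first expression that produces a value:

```scala
object SplitCodegenSketch {
  // Each `funcCall` invokes a split method whose body is itself wrapped in
  // `do { ... } while (false);`, so a `continue` there only exits that method's block.
  def foldFunctions(funcCalls: Seq[String], isNull: String): String =
    funcCalls.map { call =>
      s"""|$call;
          |if (!$isNull) {
          |  continue;  // a value was found, skip the remaining split methods
          |}""".stripMargin
    }.mkString("\n")

  def main(args: Array[String]): Unit = {
    // Fold two hypothetical split methods generated for a Coalesce over two children.
    val body = foldFunctions(Seq("coalesce_0(i)", "coalesce_1(i)"), "isNull_1")
    // The caller wraps the folded calls in the outer do/while(false) block that `continue` targets.
    println(s"do {\n$body\n} while (false);")
  }
}
```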
--- .../expressions/nullExpressions.scala | 141 ++++++++++++------ .../sql/catalyst/expressions/predicates.scala | 68 ++++++--- 2 files changed, 140 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 173e171910b69..3b52a0efd404a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -75,23 +75,51 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) ctx.addMutableState(ctx.javaType(dataType), ev.value) + // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) s""" - if (${ev.isNull}) { - ${eval.code} - if (!${eval.isNull}) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - } - """ + |${eval.code} + |if (!${eval.isNull}) { + | ${ev.isNull} = false; + | ${ev.value} = ${eval.value}; + | continue; + |} + """.stripMargin } + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + evals.mkString("\n") + } else { + ctx.splitExpressions(evals, "coalesce", + ("InternalRow", ctx.INPUT_ROW) :: Nil, + makeSplitFunction = { + func => + s""" + |do { + | $func + |} while (false); + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map { funcCall => + s""" + |$funcCall; + |if (!${ev.isNull}) { + | continue; + |} + """.stripMargin + }.mkString + }) + } - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.splitExpressions(evals)}""") + ev.copy(code = + s""" + |${ev.isNull} = true; + |${ev.value} = ${ctx.defaultValue(dataType)}; + |do { + | $code + |} while (false); + """.stripMargin) } } @@ -358,53 +386,70 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val nonnull = ctx.freshName("nonnull") + // all evals are meant to be inside a do { ... 
} while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) e.dataType match { case DoubleType | FloatType => s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull} && !Double.isNaN(${eval.value})) { - $nonnull += 1; - } - } - """ + |if ($nonnull < $n) { + | ${eval.code} + | if (!${eval.isNull} && !Double.isNaN(${eval.value})) { + | $nonnull += 1; + | } + |} else { + | continue; + |} + """.stripMargin case _ => s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull}) { - $nonnull += 1; - } - } - """ + |if ($nonnull < $n) { + | ${eval.code} + | if (!${eval.isNull}) { + | $nonnull += 1; + | } + |} else { + | continue; + |} + """.stripMargin } } val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - evals.mkString("\n") - } else { - ctx.splitExpressions( - expressions = evals, - funcName = "atLeastNNonNulls", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, - returnType = "int", - makeSplitFunction = { body => - s""" - $body - return $nonnull; - """ - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n") - } - ) - } + evals.mkString("\n") + } else { + ctx.splitExpressions( + expressions = evals, + funcName = "atLeastNNonNulls", + arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil, + returnType = ctx.JAVA_INT, + makeSplitFunction = { body => + s""" + |do { + | $body + |} while (false); + |return $nonnull; + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => + s""" + |$nonnull = $funcCall; + |if ($nonnull >= $n) { + | continue; + |} + """.stripMargin).mkString("\n") + } + ) + } - ev.copy(code = s""" - int $nonnull = 0; - $code - boolean ${ev.value} = $nonnull >= $n;""", isNull = "false") + ev.copy(code = + s""" + |${ctx.JAVA_INT} $nonnull = 0; + |do { + | $code + |} while (false); + |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1aaaaf1db48d1..75cc9b3bd8045 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -234,36 +234,62 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val valueArg = ctx.freshName("valueArg") + // All the blocks are meant to be inside a do { ... } while (false); loop. + // The evaluation of variables can be stopped when we find a matching value. 
val listCode = listGen.map(x => s""" - if (!${ev.value}) { - ${x.code} - if (${x.isNull}) { - ${ev.isNull} = true; - } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - ${ev.isNull} = false; - ${ev.value} = true; + |${x.code} + |if (${x.isNull}) { + | ${ev.isNull} = true; + |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | continue; + |} + """.stripMargin) + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + listCode.mkString("\n") + } else { + ctx.splitExpressions( + expressions = listCode, + funcName = "valueIn", + arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil, + makeSplitFunction = { body => + s""" + |do { + | $body + |} while (false); + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => + s""" + |$funcCall; + |if (${ev.value}) { + | continue; + |} + """.stripMargin).mkString("\n") } - } - """) - val listCodes = ctx.splitExpressions( - expressions = listCode, - funcName = "valueIn", - extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil) - ev.copy(code = s""" - ${valueGen.code} - ${ev.value} = false; - ${ev.isNull} = ${valueGen.isNull}; - if (!${ev.isNull}) { - ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value}; - $listCodes + ) } - """) + ev.copy(code = + s""" + |${valueGen.code} + |${ev.value} = false; + |${ev.isNull} = ${valueGen.isNull}; + |if (!${ev.isNull}) { + | $javaDataType $valueArg = ${valueGen.value}; + | do { + | $code + | } while (false); + |} + """.stripMargin) } override def sql: String = { From dff440f1ecdbce65cad377c44fd2abfd4eff9b44 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 3 Dec 2017 23:05:31 +0800 Subject: [PATCH 337/712] [SPARK-22626][SQL] deals with wrong Hive's statistics (zero rowCount) This pr to ensure that the Hive's statistics `totalSize` (or `rawDataSize`) > 0, `rowCount` also must be > 0. Otherwise may cause OOM when CBO is enabled. unit tests Author: Yuming Wang Closes #19831 from wangyum/SPARK-22626. --- .../sql/hive/client/HiveClientImpl.scala | 6 +++--- .../spark/sql/hive/StatisticsSuite.scala | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 47ce6ba83866c..77e836003b39f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -418,7 +418,7 @@ private[hive] class HiveClientImpl( // Note that this statistics could be overridden by Spark's statistics if that's available. val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_)) val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_)) - val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)).filter(_ >= 0) + val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. @@ -430,9 +430,9 @@ private[hive] class HiveClientImpl( // so when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, // return None. 
Later, we will use the other ways to estimate the statistics. if (totalSize.isDefined && totalSize.get > 0L) { - Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount)) + Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount.filter(_ > 0))) } else if (rawDataSize.isDefined && rawDataSize.get > 0) { - Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount)) + Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount.filter(_ > 0))) } else { // TODO: still fill the rowCount even if sizeInBytes is empty. Might break anything? None diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0cdd9305c6b6f..ee027e5308265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -1360,4 +1360,23 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + + test("Deals with wrong Hive's statistics (zero rowCount)") { + withTable("maybe_big") { + sql("CREATE TABLE maybe_big (c1 bigint)" + + "TBLPROPERTIES ('numRows'='0', 'rawDataSize'='60000000000', 'totalSize'='8000000000000')") + + val relation = spark.table("maybe_big").queryExecution.analyzed.children.head + .asInstanceOf[HiveTableRelation] + + val properties = relation.tableMeta.ignoredProperties + assert(properties("totalSize").toLong > 0) + assert(properties("rawDataSize").toLong > 0) + assert(properties("numRows").toLong == 0) + + assert(relation.stats.sizeInBytes > 0) + // May be cause OOM if rowCount == 0 when enables CBO, see SPARK-22626 for details. + assert(relation.stats.rowCount.isEmpty) + } + } } From 4131ad03f4d2dfcfb1e166e5dfdf0752479f7340 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 3 Dec 2017 23:52:37 -0800 Subject: [PATCH 338/712] [SPARK-22489][DOC][FOLLOWUP] Update broadcast behavior changes in migration section ## What changes were proposed in this pull request? Update broadcast behavior changes in migration section. ## How was this patch tested? N/A Author: Yuming Wang Closes #19858 from wangyum/SPARK-22489-migration. --- docs/sql-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a1b9c3bbfd059..b76be9132dd03 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1776,6 +1776,8 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. + + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. 
For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). ## Upgrading From Spark SQL 2.1 to 2.2 From 3927bb9b460d2d944ecf3c8552d71e8a25d29655 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 4 Dec 2017 11:07:27 -0600 Subject: [PATCH 339/712] [SPARK-22473][FOLLOWUP][TEST] Remove deprecated Date functions ## What changes were proposed in this pull request? #19696 replaced the deprecated usages for `Date` and `Waiter`, but a few methods were missed. The PR fixes the forgotten deprecated usages. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #19875 from mgaido91/SPARK-22473_FOLLOWUP. --- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 865092a659f26..0079e4e8d6f74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -291,26 +291,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { private val udt = new ExamplePointUDT private val smallValues = - Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1), + Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), Date.valueOf("2000-01-01"), new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L)) .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")), Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))), Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt)) private val largeValues = - Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2), + Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), Date.valueOf("2000-01-02"), new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L)) .map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")), Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))), Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt)) private val equalValues1 = - Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1), + Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), Date.valueOf("2000-01-01"), new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L)) .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")), Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))), Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt)) private val equalValues2 = - Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1), + Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), Date.valueOf("2000-01-01"), new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L)) .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")), Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))), From f81401e1cb39f2d6049b79dc8d61305f3371276f Mon Sep 17 00:00:00 2001 From: Reza Safi Date: Mon, 4 Dec 2017 09:23:48 -0800 Subject: [PATCH 340/712] [SPARK-22162] Executors and the driver should use consistent JobIDs in the RDD commit protocol I have modified SparkHadoopWriter so that executors and the driver always use consistent JobIds during the hadoop commit. 
Before SPARK-18191, Spark always used the rddId; it just incorrectly named the variable stageId. After SPARK-18191, it used the rddId as the jobId on the driver's side, and the stageId as the jobId on the executors' side. With this change, executors and the driver consistently use the rddId as the jobId. Also with this change, during the Hadoop commit protocol Spark uses the actual stageId to check whether a stage can be committed, instead of the executors' jobId as before. In addition to the existing unit tests, a test has been added to check whether executors and the driver are using the same JobId. The test failed before this change and passed after applying this fix. Author: Reza Safi Closes #19848 from rezasafi/stagerddsimple. --- .../spark/internal/io/SparkHadoopWriter.scala | 12 ++--- .../spark/mapred/SparkHadoopMapRedUtil.scala | 5 ++- .../spark/rdd/PairRDDFunctionsSuite.scala | 44 +++++++++++++++++++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 949d8c677998f..abf39213fa0d2 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -60,17 +60,17 @@ object SparkHadoopWriter extends Logging { config: HadoopWriteConfigUtil[K, V]): Unit = { // Extract context and configuration from RDD. val sparkContext = rdd.context - val stageId = rdd.id + val commitJobId = rdd.id // Set up a job. val jobTrackerId = createJobTrackerID(new Date()) - val jobContext = config.createJobContext(jobTrackerId, stageId) + val jobContext = config.createJobContext(jobTrackerId, commitJobId) config.initOutputFormat(jobContext) // Assert the output format/key/value class is set in JobConf. config.assertConf(jobContext, rdd.conf) - val committer = config.createCommitter(stageId) + val committer = config.createCommitter(commitJobId) committer.setupJob(jobContext) // Try to write all RDD partitions as a Hadoop OutputFormat. @@ -80,7 +80,7 @@ object SparkHadoopWriter extends Logging { context = context, config = config, jobTrackerId = jobTrackerId, - sparkStageId = context.stageId, + commitJobId = commitJobId, sparkPartitionId = context.partitionId, sparkAttemptNumber = context.attemptNumber, committer = committer, @@ -102,14 +102,14 @@ object SparkHadoopWriter extends Logging { context: TaskContext, config: HadoopWriteConfigUtil[K, V], jobTrackerId: String, - sparkStageId: Int, + commitJobId: Int, sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, iterator: Iterator[(K, V)]): TaskCommitMessage = { // Set up a task.
val taskContext = config.createTaskAttemptContext( - jobTrackerId, sparkStageId, sparkPartitionId, sparkAttemptNumber) + jobTrackerId, commitJobId, sparkPartitionId, sparkAttemptNumber) committer.setupTask(taskContext) val (outputMetrics, callback) = initHadoopOutputMetrics(context) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 607283a306b8f..764735dc4eae7 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -70,7 +70,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator val taskAttemptNumber = TaskContext.get().attemptNumber() - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) + val stageId = TaskContext.get().stageId() + val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -80,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) + throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 0a248b6064ee8..65d35264dc108 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.{Job => NewJob, JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.hadoop.util.Progressable +import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.Partitioner @@ -524,6 +525,15 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored") } + test("The JobId on the driver and executors should be the same during the commit") { + // Create more than one rdd to mimic stageId not equal to rddId + val pairs = sc.parallelize(Array((1, 2), (2, 3)), 2) + .map { p => (new Integer(p._1 + 1), new Integer(p._2 + 1)) } + .filter { p => p._1 > 0 } + pairs.saveAsNewAPIHadoopFile[YetAnotherFakeFormat]("ignored") + assert(JobID.jobid != -1) + } + test("saveAsHadoopFile should respect configured output committers") { val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) val conf = new JobConf() @@ -908,6 +918,40 @@ class NewFakeFormatWithCallback() extends NewFakeFormat { } } +class YetAnotherFakeCommitter extends NewOutputCommitter with Assertions { + def setupJob(j: NewJobContext): Unit = { + JobID.jobid = j.getJobID().getId + } + + def needsTaskCommit(t: NewTaskAttempContext): Boolean = false + + def setupTask(t: NewTaskAttempContext): Unit = { + val jobId = t.getTaskAttemptID().getJobID().getId + assert(jobId === JobID.jobid) + } + + def commitTask(t: NewTaskAttempContext): Unit = {} + + def 
abortTask(t: NewTaskAttempContext): Unit = {} +} + +class YetAnotherFakeFormat() extends NewOutputFormat[Integer, Integer]() { + + def checkOutputSpecs(j: NewJobContext): Unit = {} + + def getRecordWriter(t: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = { + new NewFakeWriter() + } + + def getOutputCommitter(t: NewTaskAttempContext): NewOutputCommitter = { + new YetAnotherFakeCommitter() + } +} + +object JobID { + var jobid = -1 +} + class ConfigTestFormat() extends NewFakeFormat() with Configurable { var setConfCalled = false From e1dd03e42c2131b167b1e80c761291e88bfdf03f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 4 Dec 2017 11:05:03 -0800 Subject: [PATCH 341/712] [SPARK-22372][CORE, YARN] Make cluster submission use SparkApplication. The main goal of this change is to allow multiple cluster-mode submissions from the same JVM, without having them end up with mixed configuration. That is done by extending the SparkApplication trait, and doing so was reasonably trivial for standalone and mesos modes. For YARN mode, there was a complication. YARN used a "SPARK_YARN_MODE" system property to control behavior indirectly in a whole bunch of places, mainly in the SparkHadoopUtil / YarnSparkHadoopUtil classes. Most of the changes here are removing that. Since we removed support for Hadoop 1.x, some methods that lived in YarnSparkHadoopUtil can now live in SparkHadoopUtil. The remaining methods don't need to be part of the class, and can be called directly from the YarnSparkHadoopUtil object, so now there's a single implementation of SparkHadoopUtil. There were two places in the code that relied on SPARK_YARN_MODE to make decisions about YARN-specific functionality, and now explicitly check the master from the configuration for that instead: * fetching the external shuffle service port, which can come from the YARN configuration. * propagation of the authentication secret using Hadoop credentials. This also was cleaned up a little to not need so many methods in `SparkHadoopUtil`. With those out of the way, actually changing the YARN client to extend SparkApplication was easy. Tested with existing unit tests, and also by running YARN apps with auth and kerberos both on and off in a real cluster. Author: Marcelo Vanzin Closes #19631 from vanzin/SPARK-22372. 
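For illustration, a cluster-mode client under this model looks roughly like the following minimal sketch (`FakeClusterApp` is hypothetical and not part of this patch; only the `SparkApplication.start` signature is taken from the change):

```scala
package org.apache.spark.deploy

import org.apache.spark.SparkConf

// Hypothetical example of the entry-point pattern this change standardizes on: instead of a
// static main() that reads JVM-global state, each cluster-mode client implements
// SparkApplication and receives an isolated SparkConf for every submission.
private[spark] class FakeClusterApp extends SparkApplication {

  override def start(args: Array[String], conf: SparkConf): Unit = {
    // Everything comes from `conf` and `args`; nothing is stashed in system properties, so two
    // submissions from the same JVM cannot leak configuration into each other.
    val master = conf.get("spark.master")
    // ... parse `args`, build the client for `master`, submit the application, wait for completion ...
  }
}
```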
--- .../org/apache/spark/SecurityManager.scala | 102 ++++++------- .../scala/org/apache/spark/SparkContext.scala | 3 - .../scala/org/apache/spark/SparkEnv.scala | 4 + .../org/apache/spark/deploy/Client.scala | 8 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 48 +------ .../org/apache/spark/deploy/SparkSubmit.scala | 31 ++-- .../deploy/rest/RestSubmissionClient.scala | 31 ++-- .../CoarseGrainedExecutorBackend.scala | 10 +- .../scala/org/apache/spark/util/Utils.scala | 6 +- .../apache/spark/SecurityManagerSuite.scala | 31 +++- .../spark/deploy/SparkSubmitSuite.scala | 8 +- .../rest/StandaloneRestSubmitSuite.scala | 2 +- project/MimaExcludes.scala | 6 + .../spark/deploy/yarn/ApplicationMaster.scala | 54 +++---- .../org/apache/spark/deploy/yarn/Client.scala | 65 ++++----- .../spark/deploy/yarn/YarnRMClient.scala | 2 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 135 ++++++------------ .../yarn/security/AMCredentialRenewer.scala | 2 +- .../cluster/YarnClientSchedulerBackend.scala | 4 +- .../cluster/YarnClusterSchedulerBackend.scala | 2 +- .../deploy/yarn/BaseYarnClusterSuite.scala | 5 - .../spark/deploy/yarn/ClientSuite.scala | 31 +--- .../spark/deploy/yarn/YarnClusterSuite.scala | 20 ++- .../yarn/YarnSparkHadoopUtilSuite.scala | 49 +------ ...ARNHadoopDelegationTokenManagerSuite.scala | 11 +- 25 files changed, 274 insertions(+), 396 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2480e56b72ccf..4c1dbe3ffb4ad 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} +import java.nio.charset.StandardCharsets.UTF_8 import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate import javax.net.ssl._ @@ -26,10 +27,11 @@ import javax.net.ssl._ import com.google.common.hash.HashCodes import com.google.common.io.Files import org.apache.hadoop.io.Text +import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.sasl.SecretKeyHolder import org.apache.spark.util.Utils @@ -225,7 +227,6 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); - private val secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -416,50 +417,6 @@ private[spark] class SecurityManager( def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey - /** - * Generates or looks up the secret key. - * - * The way the key is stored depends on the Spark deployment mode. Yarn - * uses the Hadoop UGI. - * - * For non-Yarn deployments, If the config variable is not set - * we throw an exception. - */ - private def generateSecretKey(): String = { - if (!isAuthenticationEnabled) { - null - } else if (SparkHadoopUtil.get.isYarnMode) { - // In YARN mode, the secure cookie will be created by the driver and stashed in the - // user's credentials, where executors can get it. 
The check for an array of size 0 - // is because of the test code in YarnSparkHadoopUtilSuite. - val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY) - if (secretKey == null || secretKey.length == 0) { - logDebug("generateSecretKey: yarn mode, secret key from credentials is null") - val rnd = new SecureRandom() - val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE - val secret = new Array[Byte](length) - rnd.nextBytes(secret) - - val cookie = HashCodes.fromBytes(secret).toString() - SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie) - cookie - } else { - new Text(secretKey).toString - } - } else { - // user must have set spark.authenticate.secret config - // For Master/Worker, auth secret is in conf; for Executors, it is in env variable - Option(sparkConf.getenv(SecurityManager.ENV_AUTH_SECRET)) - .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match { - case Some(value) => value - case None => - throw new IllegalArgumentException( - "Error: a secret key must be specified via the " + - SecurityManager.SPARK_AUTH_SECRET_CONF + " config") - } - } - } - /** * Check to see if Acls for the UI are enabled * @return true if UI authentication is enabled, otherwise false @@ -542,7 +499,51 @@ private[spark] class SecurityManager( * Gets the secret key. * @return the secret key as a String if authentication is enabled, otherwise returns null */ - def getSecretKey(): String = secretKey + def getSecretKey(): String = { + if (isAuthenticationEnabled) { + val creds = UserGroupInformation.getCurrentUser().getCredentials() + Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) + .map { bytes => new String(bytes, UTF_8) } + .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) + .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) + .getOrElse { + throw new IllegalArgumentException( + s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config") + } + } else { + null + } + } + + /** + * Initialize the authentication secret. + * + * If authentication is disabled, do nothing. + * + * In YARN mode, generate a new secret and store it in the current user's credentials. + * + * In other modes, assert that the auth secret is set in the configuration. + */ + def initializeAuth(): Unit = { + if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { + return + } + + if (sparkConf.get(SparkLauncher.SPARK_MASTER, null) != "yarn") { + require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), + s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") + return + } + + val rnd = new SecureRandom() + val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE + val secretBytes = new Array[Byte](length) + rnd.nextBytes(secretBytes) + + val creds = new Credentials() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretBytes) + UserGroupInformation.getCurrentUser().addCredentials(creds) + } // Default SecurityManager only has a single secret key, so ignore appId. override def getSaslUser(appId: String): String = getSaslUser() @@ -551,13 +552,12 @@ private[spark] class SecurityManager( private[spark] object SecurityManager { - val SPARK_AUTH_CONF: String = "spark.authenticate" - val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret" + val SPARK_AUTH_CONF = NETWORK_AUTH_ENABLED.key + val SPARK_AUTH_SECRET_CONF = "spark.authenticate.secret" // This is used to set auth secret to an executor's env variable. 
It should have the same // value as SPARK_AUTH_SECRET_CONF set in SparkConf val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET" // key used to store the spark secret in the Hadoop UGI - val SECRET_LOOKUP_KEY = "sparkCookie" - + val SECRET_LOOKUP_KEY = new Text("sparkCookie") } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c174939ca2e54..71f1e7c7321bc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -413,8 +413,6 @@ class SparkContext(config: SparkConf) extends Logging { } } - if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") - _listenerBus = new LiveListenerBus(_conf) // Initialize the app status store and listener before SparkEnv is created so that it gets @@ -1955,7 +1953,6 @@ class SparkContext(config: SparkConf) extends Logging { // `SparkContext` is stopped. localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. - System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() logInfo("Successfully stopped SparkContext") } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 24928150315e8..72123f2232532 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -234,6 +234,10 @@ object SparkEnv extends Logging { } val securityManager = new SecurityManager(conf, ioEncryptionKey) + if (isDriver) { + securityManager.initializeAuth() + } + ioEncryptionKey.foreach { _ => if (!securityManager.isEncryptionEnabled()) { logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 7acb5c55bb252..d5145094ec079 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -217,8 +217,13 @@ object Client { println("Use ./bin/spark-submit with \"--master spark://host:port\"") } // scalastyle:on println + new ClientApp().start(args, new SparkConf()) + } +} - val conf = new SparkConf() +private[spark] class ClientApp extends SparkApplication { + + override def start(args: Array[String], conf: SparkConf): Unit = { val driverArgs = new ClientArguments(args) if (!conf.contains("spark.rpc.askTimeout")) { @@ -235,4 +240,5 @@ object Client { rpcEnv.awaitTermination() } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 17c7319b40f24..e14f9845e6db6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -75,9 +75,7 @@ class SparkHadoopUtil extends Logging { } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { - for (token <- source.getTokens.asScala) { - dest.addToken(token) - } + dest.addCredentials(source.getCredentials()) } /** @@ -120,16 +118,9 @@ class SparkHadoopUtil extends Logging { * Add any user credentials to the job conf which are necessary for running on a secure Hadoop * cluster. 
*/ - def addCredentials(conf: JobConf) {} - - def isYarnMode(): Boolean = { false } - - def addSecretKeyToUserCredentials(key: String, secret: String) {} - - def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null } - - def getCurrentUserCredentials(): Credentials = { - UserGroupInformation.getCurrentUser().getCredentials() + def addCredentials(conf: JobConf): Unit = { + val jobCreds = conf.getCredentials() + jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) } def addCurrentUserCredentials(creds: Credentials): Unit = { @@ -328,17 +319,6 @@ class SparkHadoopUtil extends Logging { } } - /** - * Start a thread to periodically update the current user's credentials with new credentials so - * that access to secured service does not fail. - */ - private[spark] def startCredentialUpdater(conf: SparkConf) {} - - /** - * Stop the thread that does the credential updates. - */ - private[spark] def stopCredentialUpdater() {} - /** * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. @@ -441,14 +421,7 @@ class SparkHadoopUtil extends Logging { object SparkHadoopUtil { - private lazy val hadoop = new SparkHadoopUtil - private lazy val yarn = try { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .newInstance() - .asInstanceOf[SparkHadoopUtil] - } catch { - case e: Exception => throw new SparkException("Unable to load YARN support", e) - } + private lazy val instance = new SparkHadoopUtil val SPARK_YARN_CREDS_TEMP_EXTENSION = ".tmp" @@ -462,16 +435,7 @@ object SparkHadoopUtil { */ private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 - def get: SparkHadoopUtil = { - // Check each time to support changing to/from YARN - val yarnMode = java.lang.Boolean.parseBoolean( - System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if (yarnMode) { - yarn - } else { - hadoop - } - } + def get: SparkHadoopUtil = instance /** * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 73b956ef3e470..cfcdce648d330 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -92,6 +92,12 @@ object SparkSubmit extends CommandLineUtils with Logging { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // Following constants are visible for testing. + private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.yarn.YarnClusterApplication" + private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() + private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + // scalastyle:off println private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to @@ -281,7 +287,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } // Make sure YARN is included in our build if we're trying to use it - if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) { + if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { printErrorAndExit( "Could not load YARN classes. 
" + "This copy of Spark may not have been compiled with YARN support.") @@ -363,10 +369,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull - // This security manager will not need an auth secret, but set a dummy value in case - // spark.authenticate is enabled, otherwise an exception is thrown. - lazy val downloadConf = sparkConf.clone().set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") - lazy val secMgr = new SecurityManager(downloadConf) + lazy val secMgr = new SecurityManager(sparkConf) // In client mode, download remote files. var localPrimaryResource: String = null @@ -374,13 +377,13 @@ object SparkSubmit extends CommandLineUtils with Logging { var localPyFiles: String = null if (deployMode == CLIENT) { localPrimaryResource = Option(args.primaryResource).map { - downloadFile(_, targetDir, downloadConf, hadoopConf, secMgr) + downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr) }.orNull localJars = Option(args.jars).map { - downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) }.orNull localPyFiles = Option(args.pyFiles).map { - downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) }.orNull } @@ -391,8 +394,6 @@ object SparkSubmit extends CommandLineUtils with Logging { // For yarn client mode, since we already download them with above code, so we only need to // figure out the local path and replace the remote one. if (clusterManager == YARN) { - sparkConf.setIfMissing(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") - val secMgr = new SecurityManager(sparkConf) val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) def shouldDownload(scheme: String): Boolean = { @@ -409,7 +410,7 @@ object SparkSubmit extends CommandLineUtils with Logging { if (file.exists()) { file.toURI.toString } else { - downloadFile(resource, targetDir, downloadConf, hadoopConf, secMgr) + downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) } case _ => uri.toString } @@ -634,11 +635,11 @@ object SparkSubmit extends CommandLineUtils with Logging { // All Spark parameters are expected to be passed to the client through system properties. 
if (args.isStandaloneCluster) { if (args.useRest) { - childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childMainClass = REST_CLUSTER_SUBMIT_CLASS childArgs += (args.primaryResource, args.mainClass) } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class - childMainClass = "org.apache.spark.deploy.Client" + childMainClass = STANDALONE_CLUSTER_SUBMIT_CLASS if (args.supervise) { childArgs += "--supervise" } Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } @@ -663,7 +664,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { - childMainClass = "org.apache.spark.deploy.yarn.Client" + childMainClass = YARN_CLUSTER_SUBMIT_CLASS if (args.isPython) { childArgs += ("--primary-py-file", args.primaryResource) childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") @@ -684,7 +685,7 @@ object SparkSubmit extends CommandLineUtils with Logging { if (isMesosCluster) { assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") - childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childMainClass = REST_CLUSTER_SUBMIT_CLASS if (args.isPython) { // Second argument is main class childArgs += (args.primaryResource, "") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 21cb94142b15b..742a95841a138 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -32,6 +32,7 @@ import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonProcessingException import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf, SparkException} +import org.apache.spark.deploy.SparkApplication import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -398,9 +399,20 @@ private[spark] object RestSubmissionClient { val PROTOCOL_VERSION = "v1" /** - * Submit an application, assuming Spark parameters are specified through the given config. - * This is abstracted to its own method for testing purposes. + * Filter non-spark environment variables from any environment. */ + private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { + env.filterKeys { k => + // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345) + (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") || + k.startsWith("MESOS_") + } + } +} + +private[spark] class RestSubmissionClientApp extends SparkApplication { + + /** Submits a request to run the application and return the response. Visible for testing. 
*/ def run( appResource: String, mainClass: String, @@ -417,7 +429,7 @@ private[spark] object RestSubmissionClient { client.createSubmission(submitRequest) } - def main(args: Array[String]): Unit = { + override def start(args: Array[String], conf: SparkConf): Unit = { if (args.length < 2) { sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]") sys.exit(1) @@ -425,19 +437,8 @@ private[spark] object RestSubmissionClient { val appResource = args(0) val mainClass = args(1) val appArgs = args.slice(2, args.length) - val conf = new SparkConf - val env = filterSystemEnvironment(sys.env) + val env = RestSubmissionClient.filterSystemEnvironment(sys.env) run(appResource, mainClass, appArgs, conf, env) } - /** - * Filter non-spark environment variables from any environment. - */ - private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { - env.filterKeys { k => - // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345) - (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") || - k.startsWith("MESOS_") - } - } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index acefc9d2436d0..4c1f92a1bcbf2 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -220,7 +220,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { if (driverConf.contains("spark.yarn.credentials.file")) { logInfo("Will periodically update credentials from: " + driverConf.get("spark.yarn.credentials.file")) - SparkHadoopUtil.get.startCredentialUpdater(driverConf) + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + .getMethod("startCredentialUpdater", classOf[SparkConf]) + .invoke(null, driverConf) } cfg.hadoopDelegationCreds.foreach { tokens => @@ -236,7 +238,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() - SparkHadoopUtil.get.stopCredentialUpdater() + if (driverConf.contains("spark.yarn.credentials.file")) { + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + .getMethod("stopCredentialUpdater") + .invoke(null) + } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 930e09d90c2f5..51bf91614c866 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -50,6 +50,7 @@ import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ @@ -59,6 +60,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} @@ -2405,8 +2407,8 @@ 
private[spark] object Utils extends Logging { */ def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { val sparkValue = conf.get(key, default) - if (SparkHadoopUtil.get.isYarnMode) { - SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue) + if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { + new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, sparkValue) } else { sparkValue } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 9801b2638cc15..cf59265dd646d 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -18,7 +18,13 @@ package org.apache.spark import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 +import java.security.PrivilegedExceptionAction +import org.apache.hadoop.security.UserGroupInformation + +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.security.GroupMappingServiceProvider import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils} @@ -411,8 +417,12 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { test("missing secret authentication key") { val conf = new SparkConf().set("spark.authenticate", "true") + val mgr = new SecurityManager(conf) + intercept[IllegalArgumentException] { + mgr.getSecretKey() + } intercept[IllegalArgumentException] { - new SecurityManager(conf) + mgr.initializeAuth() } } @@ -430,5 +440,24 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } + test("secret key generation in yarn mode") { + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(SparkLauncher.SPARK_MASTER, "yarn") + val mgr = new SecurityManager(conf) + + UserGroupInformation.createUserForTesting("authTest", Array()).doAs( + new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + mgr.initializeAuth() + val creds = UserGroupInformation.getCurrentUser().getCredentials() + val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) + assert(secret != null) + assert(new String(secret, UTF_8) === mgr.getSecretKey()) + } + } + ) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d0a34c5cdcf57..e200755e639e1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -235,7 +235,7 @@ class SparkSubmitSuite childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include regex ("--jar .*thejar.jar") - mainClass should be ("org.apache.spark.deploy.yarn.Client") + mainClass should be (SparkSubmit.YARN_CLUSTER_SUBMIT_CLASS) // In yarn cluster mode, also adding jars to classpath classpath(0) should endWith ("thejar.jar") @@ -323,11 +323,11 @@ class SparkSubmitSuite val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.rest.RestSubmissionClient") + mainClass should be (SparkSubmit.REST_CLUSTER_SUBMIT_CLASS) } else { childArgsStr should startWith ("--supervise --memory 4g --cores 5") 
childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" - mainClass should be ("org.apache.spark.deploy.Client") + mainClass should be (SparkSubmit.STANDALONE_CLUSTER_SUBMIT_CLASS) } classpath should have size 0 sys.props("SPARK_SUBMIT") should be ("true") @@ -402,7 +402,7 @@ class SparkSubmitSuite conf.get("spark.executor.memory") should be ("5g") conf.get("spark.master") should be ("yarn") conf.get("spark.submit.deployMode") should be ("cluster") - mainClass should be ("org.apache.spark.deploy.yarn.Client") + mainClass should be (SparkSubmit.YARN_CLUSTER_SUBMIT_CLASS) } test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 490baf040491f..e505bc018857d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -92,7 +92,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this - val response = RestSubmissionClient.run("app-resource", "main-class", appArgs, conf) + val response = new RestSubmissionClientApp().run("app-resource", "main-class", appArgs, conf) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5b8dcd0338cce..9be01f617217b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,12 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // SPARK-22372: Make cluster submission use SparkApplication. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getSecretKeyFromUserCredentials"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.isYarnMode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getCurrentUserCredentials"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.addSecretKeyToUserCredentials"), + // SPARK-18085: Better History Server scalability for many / large applications ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index ca0aa0ea3bc73..b2576b0d72633 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -56,11 +56,28 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. 
- private val sparkConf = new SparkConf() - private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf) - .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null + private val sparkConf = new SparkConf() + if (args.propertiesFile != null) { + Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) => + sparkConf.set(k, v) + } + } + + private val securityMgr = new SecurityManager(sparkConf) + + // Set system properties for each config entry. This covers two use cases: + // - The default configuration stored by the SparkHadoopUtil class + // - The user application creating a new SparkConf in cluster mode + // + // Both cases create a new SparkConf object which reads these configs from system properties. + sparkConf.getAll.foreach { case (k, v) => + sys.props(k) = v + } + + private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) + private val ugi = { val original = UserGroupInformation.getCurrentUser() @@ -311,7 +328,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val credentialManager = new YARNHadoopDelegationTokenManager( sparkConf, yarnConf, - conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf)) + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) val credentialRenewer = new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) @@ -323,13 +340,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends credentialRenewerThread.join() } - // Call this to force generation of secret so it gets populated into the Hadoop UGI. - val securityMgr = new SecurityManager(sparkConf) - if (isClusterMode) { - runDriver(securityMgr) + runDriver() } else { - runExecutorLauncher(securityMgr) + runExecutorLauncher() } } catch { case e: Exception => @@ -410,8 +424,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends _sparkConf: SparkConf, _rpcEnv: RpcEnv, driverRef: RpcEndpointRef, - uiAddress: Option[String], - securityMgr: SecurityManager) = { + uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = @@ -463,7 +476,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends YarnSchedulerBackend.ENDPOINT_NAME) } - private def runDriver(securityMgr: SecurityManager): Unit = { + private def runDriver(): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -479,7 +492,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val driverRef = createSchedulerRef( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port")) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr) + registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl)) registered = true } else { // Sanity check; should never happen in normal operation, since sc should only be null @@ -498,15 +511,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { + private def runExecutorLauncher(): Unit = { val hostname = Utils.localHostName val amCores = sparkConf.get(AM_CORES) rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr, amCores, true) val driverRef = waitForSparkDriver() addAmIpFilter(Some(driverRef)) - registerAM(sparkConf, rpcEnv, 
driverRef, sparkConf.getOption("spark.driver.appUIAddress"), - securityMgr) + registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress")) registered = true // In client mode the actor will stop the reporter thread. @@ -686,6 +698,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { // TODO(davies): add R dependencies here } + val mainMethod = userClassLoader.loadClass(args.userClass) .getMethod("main", classOf[Array[String]]) @@ -809,15 +822,6 @@ object ApplicationMaster extends Logging { def main(args: Array[String]): Unit = { SignalUtils.registerLogger(log) val amArgs = new ApplicationMasterArguments(args) - - // Load the properties file with the Spark configuration and set entries as system properties, - // so that user code run inside the AM also has access to them. - // Note: we must do this before SparkHadoopUtil instantiated - if (amArgs.propertiesFile != null) { - Utils.getPropertiesFromFile(amArgs.propertiesFile).foreach { case (k, v) => - sys.props(k) = v - } - } master = new ApplicationMaster(amArgs) System.exit(master.run()) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 99e7d46ca5c96..3781b261a0381 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -48,7 +48,7 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.{SparkApplication, SparkHadoopUtil} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging @@ -58,18 +58,14 @@ import org.apache.spark.util.{CallerContext, Utils} private[spark] class Client( val args: ClientArguments, - val hadoopConf: Configuration, val sparkConf: SparkConf) extends Logging { import Client._ import YarnSparkHadoopUtil._ - def this(clientArgs: ClientArguments, spConf: SparkConf) = - this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) - private val yarnClient = YarnClient.createYarnClient - private val yarnConf = new YarnConfiguration(hadoopConf) + private val hadoopConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) private val isClusterMode = sparkConf.get("spark.submit.deployMode", "client") == "cluster" @@ -125,7 +121,7 @@ private[spark] class Client( private val credentialManager = new YARNHadoopDelegationTokenManager( sparkConf, hadoopConf, - conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf)) + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) @@ -134,8 +130,6 @@ private[spark] class Client( def stop(): Unit = { launcherBackend.close() yarnClient.stop() - // Unset YARN mode system env variable, to allow switching between cluster types. - System.clearProperty("SPARK_YARN_MODE") } /** @@ -152,7 +146,7 @@ private[spark] class Client( // Setup the credentials before doing anything else, // so we have don't have issues at any point. 
setupCredentials() - yarnClient.init(yarnConf) + yarnClient.init(hadoopConf) yarnClient.start() logInfo("Requesting a new application from cluster with %d NodeManagers" @@ -398,7 +392,7 @@ private[spark] class Client( if (SparkHadoopUtil.get.isProxyUser(currentUser)) { currentUser.addCredentials(credentials) } - logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) + logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } // If we use principal and keytab to login, also credentials can be renewed some time @@ -758,12 +752,14 @@ private[spark] class Client( // Save the YARN configuration into a separate file that will be overlayed on top of the // cluster's Hadoop conf. confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE)) - yarnConf.writeXml(confStream) + hadoopConf.writeXml(confStream) confStream.closeEntry() - // Save Spark configuration to a file in the archive. + // Save Spark configuration to a file in the archive, but filter out the app's secret. val props = new Properties() - sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + sparkConf.getAll.foreach { case (k, v) => + props.setProperty(k, v) + } // Override spark.yarn.key to point to the location in distributed cache which will be used // by AM. Option(amKeytabFileName).foreach { k => props.setProperty(KEYTAB.key, k) } @@ -786,8 +782,7 @@ private[spark] class Client( pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() - populateClasspath(args, yarnConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) - env("SPARK_YARN_MODE") = "true" + populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() if (loginFromKeytab) { @@ -861,6 +856,7 @@ private[spark] class Client( } else { Nil } + val launchEnv = setupLaunchEnv(appStagingDirPath, pySparkArchives) val localResources = prepareLocalResources(appStagingDirPath, pySparkArchives) @@ -991,7 +987,11 @@ private[spark] class Client( logDebug("YARN AM launch context:") logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") logDebug(" env:") - launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } + if (log.isDebugEnabled) { + Utils.redact(sparkConf, launchEnv.toSeq).foreach { case (k, v) => + logDebug(s" $k -> $v") + } + } logDebug(" resources:") localResources.foreach { case (k, v) => logDebug(s" $k -> $v")} logDebug(" command:") @@ -1185,24 +1185,6 @@ private[spark] class Client( private object Client extends Logging { - def main(argStrings: Array[String]) { - if (!sys.props.contains("SPARK_SUBMIT")) { - logWarning("WARNING: This client is deprecated and will be removed in a " + - "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") - } - - // Set an env variable indicating we are running in YARN mode. - // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes - System.setProperty("SPARK_YARN_MODE", "true") - val sparkConf = new SparkConf - // SparkSubmit would use yarn cache to distribute files & jars in yarn mode, - // so remove them from sparkConf here for yarn mode. 
- sparkConf.remove("spark.jars") - sparkConf.remove("spark.files") - val args = new ClientArguments(argStrings) - new Client(args, sparkConf).run() - } - // Alias for the user jar val APP_JAR_NAME: String = "__app__.jar" @@ -1506,3 +1488,16 @@ private object Client extends Logging { } } + +private[spark] class YarnClusterApplication extends SparkApplication { + + override def start(args: Array[String], conf: SparkConf): Unit = { + // SparkSubmit would use yarn cache to distribute files & jars in yarn mode, + // so remove them from sparkConf here for yarn mode. + conf.remove("spark.jars") + conf.remove("spark.files") + + new Client(new ClientArguments(args), conf).run() + } + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 72f4d273ab53b..c1ae12aabb8cc 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -92,7 +92,7 @@ private[spark] class YarnRMClient extends Logging { /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId = { - YarnSparkHadoopUtil.get.getContainerId.getApplicationAttemptId() + YarnSparkHadoopUtil.getContainerId.getApplicationAttemptId() } /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 9c1472cb50e3a..f406fabd61860 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,21 +17,14 @@ package org.apache.spark.deploy.yarn -import java.nio.charset.StandardCharsets.UTF_8 -import java.util.regex.Matcher -import java.util.regex.Pattern +import java.util.regex.{Matcher, Pattern} import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.{JobConf, Master} -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} -import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -43,87 +36,10 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils - -/** - * Contains util methods to interact with Hadoop from spark. - */ -class YarnSparkHadoopUtil extends SparkHadoopUtil { +object YarnSparkHadoopUtil { private var credentialUpdater: CredentialUpdater = _ - override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { - dest.addCredentials(source.getCredentials()) - } - - // Note that all params which start with SPARK are propagated all the way through, so if in yarn - // mode, this MUST be set to true. - override def isYarnMode(): Boolean = { true } - - // Return an appropriate (subclass) of Configuration. 
Creating a config initializes some Hadoop - // subsystems. Always create a new config, don't reuse yarnConf. - override def newConfiguration(conf: SparkConf): Configuration = { - val hadoopConf = new YarnConfiguration(super.newConfiguration(conf)) - hadoopConf.addResource(Client.SPARK_HADOOP_CONF_FILE) - hadoopConf - } - - // Add any user credentials to the job conf which are necessary for running on a secure Hadoop - // cluster - override def addCredentials(conf: JobConf) { - val jobCreds = conf.getCredentials() - jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) - } - - override def addSecretKeyToUserCredentials(key: String, secret: String) { - val creds = new Credentials() - creds.addSecretKey(new Text(key), secret.getBytes(UTF_8)) - addCurrentUserCredentials(creds) - } - - override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { - val credentials = getCurrentUserCredentials() - if (credentials != null) credentials.getSecretKey(new Text(key)) else null - } - - private[spark] override def startCredentialUpdater(sparkConf: SparkConf): Unit = { - val hadoopConf = newConfiguration(sparkConf) - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf)) - credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) - credentialUpdater.start() - } - - private[spark] override def stopCredentialUpdater(): Unit = { - if (credentialUpdater != null) { - credentialUpdater.stop() - credentialUpdater = null - } - } - - private[spark] def getContainerId: ContainerId = { - val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) - ConverterUtils.toContainerId(containerIdString) - } - - /** The filesystems for which YARN should fetch delegation tokens. */ - private[spark] def hadoopFSsToAccess( - sparkConf: SparkConf, - hadoopConf: Configuration): Set[FileSystem] = { - val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) - .map(new Path(_).getFileSystem(hadoopConf)) - .toSet - - val stagingFS = sparkConf.get(STAGING_DIR) - .map(new Path(_).getFileSystem(hadoopConf)) - .getOrElse(FileSystem.get(hadoopConf)) - - filesystemsToAccess + stagingFS - } -} - -object YarnSparkHadoopUtil { // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. @@ -137,14 +53,6 @@ object YarnSparkHadoopUtil { // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = Priority.newInstance(1) - def get: YarnSparkHadoopUtil = { - val yarnMode = java.lang.Boolean.parseBoolean( - System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if (!yarnMode) { - throw new SparkException("YarnSparkHadoopUtil is not available in non-YARN mode!") - } - SparkHadoopUtil.get.asInstanceOf[YarnSparkHadoopUtil] - } /** * Add a path variable to the given environment map. * If the map already contains this key, append the value to the existing value instead. @@ -277,5 +185,42 @@ object YarnSparkHadoopUtil { securityMgr.getModifyAclsGroups) ) } -} + def getContainerId: ContainerId = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + ConverterUtils.toContainerId(containerIdString) + } + + /** The filesystems for which YARN should fetch delegation tokens. 
*/ + def hadoopFSsToAccess( + sparkConf: SparkConf, + hadoopConf: Configuration): Set[FileSystem] = { + val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) + .map(new Path(_).getFileSystem(hadoopConf)) + .toSet + + val stagingFS = sparkConf.get(STAGING_DIR) + .map(new Path(_).getFileSystem(hadoopConf)) + .getOrElse(FileSystem.get(hadoopConf)) + + filesystemsToAccess + stagingFS + } + + def startCredentialUpdater(sparkConf: SparkConf): Unit = { + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) + credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) + credentialUpdater.start() + } + + def stopCredentialUpdater(): Unit = { + if (credentialUpdater != null) { + credentialUpdater.stop() + credentialUpdater = null + } + } + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index 6134757a82fdc..eaf2cff111a49 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -62,7 +62,7 @@ private[yarn] class AMCredentialRenewer( private val credentialRenewerThread: ScheduledExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") - private val hadoopUtil = YarnSparkHadoopUtil.get + private val hadoopUtil = SparkHadoopUtil.get private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index b722cc401bb73..0c6206eebe41d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -66,7 +66,7 @@ private[spark] class YarnClientSchedulerBackend( // reads the credentials from HDFS, just like the executors and updates its own credentials // cache. 
if (conf.contains("spark.yarn.credentials.file")) { - YarnSparkHadoopUtil.get.startCredentialUpdater(conf) + YarnSparkHadoopUtil.startCredentialUpdater(conf) } monitorThread = asyncMonitorApplication() monitorThread.start() @@ -153,7 +153,7 @@ private[spark] class YarnClientSchedulerBackend( client.reportLauncherState(SparkAppHandle.State.FINISHED) super.stop() - YarnSparkHadoopUtil.get.stopCredentialUpdater() + YarnSparkHadoopUtil.stopCredentialUpdater() client.stop() logInfo("Stopped") } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index e2d477be329c3..62bf9818ee248 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -41,7 +41,7 @@ private[spark] class YarnClusterSchedulerBackend( var driverLogs: Option[Map[String, String]] = None try { val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) - val containerId = YarnSparkHadoopUtil.get.getContainerId + val containerId = YarnSparkHadoopUtil.getContainerId val httpAddress = System.getenv(Environment.NM_HOST.name()) + ":" + System.getenv(Environment.NM_HTTP_PORT.name()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 9c3b18e4ec5f3..ac67f2196e0a0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -62,18 +62,14 @@ abstract class BaseYarnClusterSuite protected var hadoopConfDir: File = _ private var logConfDir: File = _ - var oldSystemProperties: Properties = null - def newYarnConfig(): YarnConfiguration override def beforeAll() { super.beforeAll() - oldSystemProperties = SerializationUtils.clone(System.getProperties) tempDir = Utils.createTempDir() logConfDir = new File(tempDir, "log4j") logConfDir.mkdir() - System.setProperty("SPARK_YARN_MODE", "true") val logConfFile = new File(logConfDir, "log4j.properties") Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8) @@ -124,7 +120,6 @@ abstract class BaseYarnClusterSuite try { yarnCluster.stop() } finally { - System.setProperties(oldSystemProperties) super.afterAll() } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 6cf68427921fd..9d5f5eb621118 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -24,7 +24,6 @@ import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap => MutableHashMap} -import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig @@ -36,34 +35,18 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito._ -import 
org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TestUtils} import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils} +import org.apache.spark.util.{SparkConfWithEnv, Utils} -class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll - with ResetSystemProperties { +class ClientSuite extends SparkFunSuite with Matchers { import Client._ var oldSystemProperties: Properties = null - override def beforeAll(): Unit = { - super.beforeAll() - oldSystemProperties = SerializationUtils.clone(System.getProperties) - System.setProperty("SPARK_YARN_MODE", "true") - } - - override def afterAll(): Unit = { - try { - System.setProperties(oldSystemProperties) - oldSystemProperties = null - } finally { - super.afterAll() - } - } - test("default Yarn application classpath") { getDefaultYarnApplicationClasspath should be(Fixtures.knownDefYarnAppCP) } @@ -185,7 +168,6 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll } test("configuration and args propagate through createApplicationSubmissionContext") { - val conf = new Configuration() // When parsing tags, duplicates and leading/trailing whitespace should be removed. // Spaces between non-comma strings should be preserved as single tags. Empty strings may or // may not be removed depending on the version of Hadoop being used. @@ -200,7 +182,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext]) - val client = new Client(args, conf, sparkConf) + val client = new Client(args, sparkConf) client.createApplicationSubmissionContext( new YarnClientApplication(getNewApplicationResponse, appContext), containerLaunchContext) @@ -407,15 +389,14 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll private def createClient( sparkConf: SparkConf, - conf: Configuration = new Configuration(), args: Array[String] = Array()): Client = { val clientArgs = new ClientArguments(args) - spy(new Client(clientArgs, conf, sparkConf)) + spy(new Client(clientArgs, sparkConf)) } private def classpath(client: Client): Array[String] = { val env = new MutableHashMap[String, String]() - populateClasspath(null, client.hadoopConf, client.sparkConf, env) + populateClasspath(null, new Configuration(), client.sparkConf, env) classpath(env) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index d5de19072ce29..ab0005d7b53a8 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -115,8 +115,13 @@ class YarnClusterSuite extends BaseYarnClusterSuite { )) } - test("run Spark in yarn-cluster mode with using SparkHadoopUtil.conf") { - testYarnAppUseSparkHadoopUtilConf() + test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") { + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(false, + mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass), + appArgs = Seq("key=value", result.getAbsolutePath()), + extraConf = Map("spark.hadoop.key" -> 
"value")) + checkResult(finalState, result) } test("run Spark in yarn-client mode with additional jar") { @@ -216,15 +221,6 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, result) } - private def testYarnAppUseSparkHadoopUtilConf(): Unit = { - val result = File.createTempFile("result", null, tempDir) - val finalState = runSpark(false, - mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass), - appArgs = Seq("key=value", result.getAbsolutePath()), - extraConf = Map("spark.hadoop.key" -> "value")) - checkResult(finalState, result) - } - private def testWithAddJar(clientMode: Boolean): Unit = { val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) val driverResult = File.createTempFile("driver", null, tempDir) @@ -424,7 +420,7 @@ private object YarnClusterDriver extends Logging with Matchers { s"Driver logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " ) } - val containerId = YarnSparkHadoopUtil.get.getContainerId + val containerId = YarnSparkHadoopUtil.getContainerId val user = Utils.getCurrentUserName() assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index a057618b39950..f21353aa007c8 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -71,14 +71,10 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging test("Yarn configuration override") { val key = "yarn.nodemanager.hostname" - val default = new YarnConfiguration() - val sparkConf = new SparkConf() .set("spark.hadoop." + key, "someHostName") - val yarnConf = new YarnSparkHadoopUtil().newConfiguration(sparkConf) - - yarnConf.getClass() should be (classOf[YarnConfiguration]) - yarnConf.get(key) should not be default.get(key) + val yarnConf = new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(sparkConf)) + yarnConf.get(key) should be ("someHostName") } @@ -145,45 +141,4 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } - test("check different hadoop utils based on env variable") { - try { - System.setProperty("SPARK_YARN_MODE", "true") - assert(SparkHadoopUtil.get.getClass === classOf[YarnSparkHadoopUtil]) - System.setProperty("SPARK_YARN_MODE", "false") - assert(SparkHadoopUtil.get.getClass === classOf[SparkHadoopUtil]) - } finally { - System.clearProperty("SPARK_YARN_MODE") - } - } - - - - // This test needs to live here because it depends on isYarnMode returning true, which can only - // happen in the YARN module. 
- test("security manager token generation") { - try { - System.setProperty("SPARK_YARN_MODE", "true") - val initial = SparkHadoopUtil.get - .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY) - assert(initial === null || initial.length === 0) - - val conf = new SparkConf() - .set(SecurityManager.SPARK_AUTH_CONF, "true") - .set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") - val sm = new SecurityManager(conf) - - val generated = SparkHadoopUtil.get - .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY) - assert(generated != null) - val genString = new Text(generated).toString() - assert(genString != "unused") - assert(sm.getSecretKey() === genString) - } finally { - // removeSecretKey() was only added in Hadoop 2.6, so instead we just set the secret - // to an empty string. - SparkHadoopUtil.get.addSecretKeyToUserCredentials(SecurityManager.SECRET_LOOKUP_KEY, "") - System.clearProperty("SPARK_YARN_MODE") - } - } - } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index c918998bde07c..3c7cdc0f1dab8 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -31,24 +31,15 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers override def beforeAll(): Unit = { super.beforeAll() - - System.setProperty("SPARK_YARN_MODE", "true") - sparkConf = new SparkConf() hadoopConf = new Configuration() } - override def afterAll(): Unit = { - super.afterAll() - - System.clearProperty("SPARK_YARN_MODE") - } - test("Correctly loads credential providers") { credentialManager = new YARNHadoopDelegationTokenManager( sparkConf, hadoopConf, - conf => YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, conf)) + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) credentialManager.credentialProviders.get("yarn-test") should not be (None) } From 1d5597b408485e41812f3645a670864ad88570a0 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 4 Dec 2017 15:08:07 -0800 Subject: [PATCH 342/712] [SPARK-22626][SQL][FOLLOWUP] improve documentation and simplify test case ## What changes were proposed in this pull request? This PR improves documentation for not using zero `numRows` statistics and simplifies the test case. The reason why some Hive tables have zero `numRows` is that, in Hive, when stats gathering is disabled, `numRows` is always zero after INSERT command: ``` hive> create table src (key int, value string) stored as orc; hive> desc formatted src; Table Parameters: COLUMN_STATS_ACCURATE {\"BASIC_STATS\":\"true\"} numFiles 0 numRows 0 rawDataSize 0 totalSize 0 transient_lastDdlTime 1512399590 hive> set hive.stats.autogather=false; hive> insert into src select 1, 'a'; hive> desc formatted src; Table Parameters: numFiles 1 numRows 0 rawDataSize 0 totalSize 275 transient_lastDdlTime 1512399647 hive> insert into src select 1, 'b'; hive> desc formatted src; Table Parameters: numFiles 2 numRows 0 rawDataSize 0 totalSize 550 transient_lastDdlTime 1512399687 ``` ## How was this patch tested? Modified existing test. Author: Zhenhua Wang Closes #19880 from wzhfy/doc_zero_rowCount. 
--- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 8 +++++--- .../org/apache/spark/sql/hive/StatisticsSuite.scala | 11 +++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 77e836003b39f..08eb5c74f06d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -426,9 +426,11 @@ private[hive] class HiveClientImpl( // TODO: stats should include all the other two fields (`numFiles` and `numPartitions`). // (see StatsSetupConst in Hive) val stats = - // When table is external, `totalSize` is always zero, which will influence join strategy - // so when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, - // return None. Later, we will use the other ways to estimate the statistics. + // When table is external, `totalSize` is always zero, which will influence join strategy. + // So when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, + // return None. + // In Hive, when statistics gathering is disabled, `rawDataSize` and `numRows` is always + // zero after INSERT command. So they are used here only if they are larger than zero. if (totalSize.isDefined && totalSize.get > 0L) { Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount.filter(_ > 0))) } else if (rawDataSize.isDefined && rawDataSize.get > 0) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ee027e5308265..13f06a2503656 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -1366,17 +1366,16 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql("CREATE TABLE maybe_big (c1 bigint)" + "TBLPROPERTIES ('numRows'='0', 'rawDataSize'='60000000000', 'totalSize'='8000000000000')") - val relation = spark.table("maybe_big").queryExecution.analyzed.children.head - .asInstanceOf[HiveTableRelation] + val catalogTable = getCatalogTable("maybe_big") - val properties = relation.tableMeta.ignoredProperties + val properties = catalogTable.ignoredProperties assert(properties("totalSize").toLong > 0) assert(properties("rawDataSize").toLong > 0) assert(properties("numRows").toLong == 0) - assert(relation.stats.sizeInBytes > 0) - // May be cause OOM if rowCount == 0 when enables CBO, see SPARK-22626 for details. - assert(relation.stats.rowCount.isEmpty) + val catalogStats = catalogTable.stats.get + assert(catalogStats.sizeInBytes > 0) + assert(catalogStats.rowCount.isEmpty) } } } From 3887b7eef7b89d3aeecadebc0fdafa47586a232b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 4 Dec 2017 17:08:56 -0800 Subject: [PATCH 343/712] [SPARK-22665][SQL] Avoid repartitioning with empty list of expressions ## What changes were proposed in this pull request? Repartitioning by empty set of expressions is currently possible, even though it is a case which is not handled properly. Indeed, in `HashExpression` there is a check to avoid to run it on an empty set, but this check is not performed while repartitioning. Thus, the PR adds a check to avoid this wrong situation. ## How was this patch tested? 
added UT Author: Marco Gaido Closes #19870 from mgaido91/SPARK-22665. --- .../plans/logical/basicLogicalOperators.scala | 12 +++++++----- .../spark/sql/catalyst/analysis/AnalysisSuite.scala | 5 ++++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 93de7c1daf5c2..ba5f97d608feb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, + RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -847,14 +847,16 @@ case class RepartitionByExpression( "`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " + "means `HashPartitioning`. In this case we have:" + s""" - |SortOrder: ${sortOrder} - |NonSortOrder: ${nonSortOrder} + |SortOrder: $sortOrder + |NonSortOrder: $nonSortOrder """.stripMargin) if (sortOrder.nonEmpty) { RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions) - } else { + } else if (nonSortOrder.nonEmpty) { HashPartitioning(nonSortOrder, numPartitions) + } else { + RoundRobinPartitioning(numPartitions) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0e2e706a31a05..109fb32aa4a12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, + RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.types._ @@ -530,6 +531,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkPartitioning[RangePartitioning](numPartitions = 10, exprs = SortOrder('a.attr, Ascending), SortOrder('b.attr, Descending)) + checkPartitioning[RoundRobinPartitioning](numPartitions = 10, exprs = Seq.empty: _*) + intercept[IllegalArgumentException] { checkPartitioning(numPartitions = 0, exprs = Literal(20)) } From 295df746ecb1def5530a044d6670b28821da89f0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Dec 2017 12:38:26 +0800 Subject: [PATCH 344/712] [SPARK-22677][SQL] cleanup whole stage codegen for hash aggregate ## What changes 
were proposed in this pull request? The `HashAggregateExec` whole stage codegen path is a little messy and hard to understand, this code cleans it up a little bit, especially for the fast hash map part. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19869 from cloud-fan/hash-agg. --- .../aggregate/HashAggregateExec.scala | 402 +++++++++--------- 1 file changed, 195 insertions(+), 207 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 913978892cd8c..26d8cd7278353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.MutableColumnarRow +import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -444,6 +444,7 @@ case class HashAggregateExec( val funcName = ctx.freshName("doAggregateWithKeysOutput") val keyTerm = ctx.freshName("keyTerm") val bufferTerm = ctx.freshName("bufferTerm") + val numOutput = metricTerm(ctx, "numOutputRows") val body = if (modes.contains(Final) || modes.contains(Complete)) { @@ -520,6 +521,7 @@ case class HashAggregateExec( s""" private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm) throws java.io.IOException { + $numOutput.add(1); $body } """) @@ -549,7 +551,7 @@ case class HashAggregateExec( isSupported && isNotByteArrayDecimalType } - private def enableTwoLevelHashMap(ctx: CodegenContext) = { + private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = { if (!checkIfFastHashMapSupported(ctx)) { if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but" @@ -560,9 +562,8 @@ case class HashAggregateExec( // This is for testing/benchmarking only. // We enforce to first level to be a vectorized hashmap, instead of the default row-based one. 
- sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match { - case "true" => isVectorizedHashMapEnabled = true - case null | "" | "false" => None } + isVectorizedHashMapEnabled = sqlContext.getConf( + "spark.sql.codegen.aggregate.map.vectorized.enable", "false") == "true" } } @@ -573,94 +574,84 @@ case class HashAggregateExec( enableTwoLevelHashMap(ctx) } else { sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match { - case "true" => logWarning("Two level hashmap is disabled but vectorized hashmap is " + - "enabled.") - case null | "" | "false" => None + case "true" => + logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.") + case _ => } } - fastHashMapTerm = ctx.freshName("fastHashMap") - val fastHashMapClassName = ctx.freshName("FastHashMap") - val fastHashMapGenerator = - if (isVectorizedHashMapEnabled) { - new VectorizedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema) - } else { - new RowBasedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema) - } val thisPlan = ctx.addReferenceObj("plan", this) - // Create a name for iterator from vectorized HashMap + // Create a name for the iterator from the fast hash map. val iterTermForFastHashMap = ctx.freshName("fastHashMapIter") if (isFastHashMapEnabled) { + // Generates the fast hash map class and creates the fash hash map term. + fastHashMapTerm = ctx.freshName("fastHashMap") + val fastHashMapClassName = ctx.freshName("FastHashMap") if (isVectorizedHashMapEnabled) { + val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + ctx.addInnerClass(generatedMap) + ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName();") ctx.addMutableState( - "java.util.Iterator", + s"java.util.Iterator<${classOf[ColumnarRow].getName}>", iterTermForFastHashMap) } else { + val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + ctx.addInnerClass(generatedMap) + ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName(" + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") ctx.addMutableState( - "org.apache.spark.unsafe.KVIterator", + "org.apache.spark.unsafe.KVIterator", iterTermForFastHashMap) } } + // Create a name for the iterator from the regular hash map. 
+ val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm) // create hashMap hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm) + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm) - // Create a name for iterator from HashMap - val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm) - - def generateGenerateCode(): String = { - if (isFastHashMapEnabled) { - if (isVectorizedHashMapEnabled) { - s""" - | ${fastHashMapGenerator.asInstanceOf[VectorizedHashMapGenerator].generate()} - """.stripMargin - } else { - s""" - | ${fastHashMapGenerator.asInstanceOf[RowBasedHashMapGenerator].generate()} - """.stripMargin - } - } else "" - } - ctx.addInnerClass(generateGenerateCode()) - val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") val avgHashProbe = metricTerm(ctx, "avgHashProbe") - val doAggFuncName = ctx.addNewFunction(doAgg, - s""" - private void $doAgg() throws java.io.IOException { - $hashMapTerm = $thisPlan.createHashMap(); - ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - ${if (isFastHashMapEnabled) { - s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""} + val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" + + s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);" + val finishHashMap = if (isFastHashMapEnabled) { + s""" + |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator(); + |$finishRegularHashMap + """.stripMargin + } else { + finishRegularHashMap + } - $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize, - $avgHashProbe); - } - """) + val doAggFuncName = ctx.addNewFunction(doAgg, + s""" + |private void $doAgg() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | $finishHashMap + |} + """.stripMargin) // generate code for output val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") val outputFunc = generateResultFunction(ctx) - val numOutput = metricTerm(ctx, "numOutputRows") - def outputFromGeneratedMap: String = { + def outputFromFastHashMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { outputFromVectorizedMap @@ -672,48 +663,56 @@ case class HashAggregateExec( def outputFromRowBasedMap: String = { s""" - while ($iterTermForFastHashMap.next()) { - $numOutput.add(1); - UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); - UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); - $outputFunc($keyTerm, $bufferTerm); - - if (shouldStop()) return; - } - $fastHashMapTerm.close(); - """ + |while ($iterTermForFastHashMap.next()) { + | UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); + | UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); + | $outputFunc($keyTerm, $bufferTerm); + | + | if (shouldStop()) return; + |} + |$fastHashMapTerm.close(); + """.stripMargin } // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow def outputFromVectorizedMap: String = { - val row = ctx.freshName("fastHashMapRow") - ctx.currentVars = 
null - ctx.INPUT_ROW = row - val generateKeyRow = GenerateUnsafeProjection.createCode(ctx, - groupingKeySchema.toAttributes.zipWithIndex + val row = ctx.freshName("fastHashMapRow") + ctx.currentVars = null + ctx.INPUT_ROW = row + val generateKeyRow = GenerateUnsafeProjection.createCode(ctx, + groupingKeySchema.toAttributes.zipWithIndex .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) } - ) - val generateBufferRow = GenerateUnsafeProjection.createCode(ctx, - bufferSchema.toAttributes.zipWithIndex - .map { case (attr, i) => - BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) - s""" - | while ($iterTermForFastHashMap.hasNext()) { - | $numOutput.add(1); - | org.apache.spark.sql.execution.vectorized.ColumnarRow $row = - | (org.apache.spark.sql.execution.vectorized.ColumnarRow) - | $iterTermForFastHashMap.next(); - | ${generateKeyRow.code} - | ${generateBufferRow.code} - | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value}); - | - | if (shouldStop()) return; - | } - | - | $fastHashMapTerm.close(); - """.stripMargin + ) + val generateBufferRow = GenerateUnsafeProjection.createCode(ctx, + bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) => + BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) + }) + val columnarRowCls = classOf[ColumnarRow].getName + s""" + |while ($iterTermForFastHashMap.hasNext()) { + | $columnarRowCls $row = ($columnarRowCls) $iterTermForFastHashMap.next(); + | ${generateKeyRow.code} + | ${generateBufferRow.code} + | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value}); + | + | if (shouldStop()) return; + |} + | + |$fastHashMapTerm.close(); + """.stripMargin } + def outputFromRegularHashMap: String = { + s""" + |while ($iterTerm.next()) { + | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + | $outputFunc($keyTerm, $bufferTerm); + | + | if (shouldStop()) return; + |} + """.stripMargin + } val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") @@ -726,16 +725,8 @@ case class HashAggregateExec( } // output the result - ${outputFromGeneratedMap} - - while ($iterTerm.next()) { - $numOutput.add(1); - UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); - UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); - $outputFunc($keyTerm, $bufferTerm); - - if (shouldStop()) return; - } + $outputFromFastHashMap + $outputFromRegularHashMap $iterTerm.close(); if ($sorterTerm == null) { @@ -745,13 +736,11 @@ case class HashAggregateExec( } private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // create grouping key - ctx.currentVars = input val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) val fastRowKeys = ctx.generateExpressions( - groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) val unsafeRowKeys = unsafeRowKeyCode.value val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") val fastRowBuffer = ctx.freshName("fastAggBuffer") @@ -768,12 +757,8 @@ case class HashAggregateExec( // generate hash code for key val hashExpr = Murmur3Hash(groupingExpressions, 42) - ctx.currentVars = input val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) - val inputAttr = aggregateBufferAttributes ++ child.output 
- ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input - val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.freshName("fallbackCounter") @@ -784,86 +769,65 @@ case class HashAggregateExec( ("true", "true", "", "") } - // We first generate code to probe and update the fast hash map. If the probe is - // successful the corresponding fast row buffer will hold the mutable row - val findOrInsertFastHashMap: Option[String] = { + val findOrInsertRegularHashMap: String = + s""" + |// generate grouping key + |${unsafeRowKeyCode.code.trim} + |${hashEval.code.trim} + |if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + |} + |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based + |// aggregation after processing all input rows. + |if ($unsafeRowBuffer == null) { + | if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + | } else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + | } + | $resetCounter + | // the hash map had be spilled, it should have enough memory now, + | // try to allocate buffer again. + | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( + | $unsafeRowKeys, ${hashEval.value}); + | if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new OutOfMemoryError("No enough memory for aggregation"); + | } + |} + """.stripMargin + + val findOrInsertHashMap: String = { if (isFastHashMapEnabled) { - Option( - s""" - | - |if ($checkFallbackForGeneratedHashMap) { - | ${fastRowKeys.map(_.code).mkString("\n")} - | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { - | $fastRowBuffer = $fastHashMapTerm.findOrInsert( - | ${fastRowKeys.map(_.value).mkString(", ")}); - | } - |} - """.stripMargin) + // If fast hash map is on, we first generate code to probe and update the fast hash map. + // If the probe is successful the corresponding fast row buffer will hold the mutable row. + s""" + |if ($checkFallbackForGeneratedHashMap) { + | ${fastRowKeys.map(_.code).mkString("\n")} + | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $fastRowBuffer = $fastHashMapTerm.findOrInsert( + | ${fastRowKeys.map(_.value).mkString(", ")}); + | } + |} + |// Cannot find the key in fast hash map, try regular hash map. + |if ($fastRowBuffer == null) { + | $findOrInsertRegularHashMap + |} + """.stripMargin } else { - None + findOrInsertRegularHashMap } } + val inputAttr = aggregateBufferAttributes ++ child.output + // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when + // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while + // generating input columns, we use `currentVars`. 
+ ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input - def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = { - ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized) - } - Option( - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(fastRowEvals)} - |// update fast row - |${updateFastRow.mkString("\n").trim} - | - """.stripMargin) - } - - // Next, we generate code to probe and update the unsafe row hash map. - val findOrInsertInUnsafeRowMap: String = { - s""" - | if ($fastRowBuffer == null) { - | // generate grouping key - | ${unsafeRowKeyCode.code.trim} - | ${hashEval.code.trim} - | if ($checkFallbackForBytesToBytesMap) { - | // try to get the buffer from hash map - | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); - | } - | // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based - | // aggregation after processing all input rows. - | if ($unsafeRowBuffer == null) { - | if ($sorterTerm == null) { - | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); - | } else { - | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); - | } - | $resetCounter - | // the hash map had be spilled, it should have enough memory now, - | // try to allocate buffer again. - | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); - | if ($unsafeRowBuffer == null) { - | // failed to allocate the first page - | throw new OutOfMemoryError("No enough memory for aggregation"); - | } - | } - | } - """.stripMargin - } - - val updateRowInUnsafeRowMap: String = { + val updateRowInRegularHashMap: String = { ctx.INPUT_ROW = unsafeRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) @@ -882,45 +846,69 @@ case class HashAggregateExec( |${evaluateVariables(unsafeRowBufferEvals)} |// update unsafe row buffer |${updateUnsafeRowBuffer.mkString("\n").trim} - """.stripMargin + """.stripMargin + } + + val updateRowInHashMap: String = { + if (isFastHashMapEnabled) { + ctx.INPUT_ROW = fastRowBuffer + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn( + fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) + } + + // If fast hash map is on, we first generate code to update row in fast hash map, if the + // previous loop up hit fast hash map. Otherwise, update row in regular hash map. 
+ s""" + |if ($fastRowBuffer != null) { + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(fastRowEvals)} + | // update fast row + | ${updateFastRow.mkString("\n").trim} + |} else { + | $updateRowInRegularHashMap + |} + """.stripMargin + } else { + updateRowInRegularHashMap + } } + val declareRowBuffer: String = if (isFastHashMapEnabled) { + val fastRowType = if (isVectorizedHashMapEnabled) { + classOf[MutableColumnarRow].getName + } else { + "UnsafeRow" + } + s""" + |UnsafeRow $unsafeRowBuffer = null; + |$fastRowType $fastRowBuffer = null; + """.stripMargin + } else { + s"UnsafeRow $unsafeRowBuffer = null;" + } // We try to do hash map based in-memory aggregation first. If there is not enough memory (the // hash map will return null for new key), we spill the hash map to disk to free memory, then // continue to do in-memory aggregation and spilling until all the rows had been processed. // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" - UnsafeRow $unsafeRowBuffer = null; - ${ - if (isVectorizedHashMapEnabled) { - s""" - | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null; - """.stripMargin - } else { - s""" - | UnsafeRow $fastRowBuffer = null; - """.stripMargin - } - } + $declareRowBuffer - ${findOrInsertFastHashMap.getOrElse("")} - - $findOrInsertInUnsafeRowMap + $findOrInsertHashMap $incCounter - if ($fastRowBuffer != null) { - // update fast row - ${ - if (isFastHashMapEnabled) { - updateRowInFastHashMap(isVectorizedHashMapEnabled).getOrElse("") - } else "" - } - } else { - // update unsafe row - $updateRowInUnsafeRowMap - } + $updateRowInHashMap """ } From a8af4da12ce43cd5528a53b5f7f454e9dbe71d6e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Dec 2017 12:43:05 +0800 Subject: [PATCH 345/712] [SPARK-22682][SQL] HashExpression does not need to create global variables ## What changes were proposed in this pull request? It turns out that `HashExpression` can pass around some values via parameter when splitting codes into methods, to save some global variable slots. This can also prevent a weird case that global variable appears in parameter list, which is discovered by https://github.com/apache/spark/pull/19865 ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19878 from cloud-fan/minor. 
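As a rough, hypothetical illustration of the resulting code shape (plain Scala rather than the actual generated Java, with made-up method names): the running hash is threaded through method parameters and return values, so no mutable global slot has to be registered via `ctx.addMutableState`.

```scala
object HashSplitSketch {
  // Each "split" method handles one batch of child expressions and returns
  // the updated hash instead of mutating an instance field.
  private def computeHash_0(row: Array[Any], value: Int): Int =
    31 * value + row(0).hashCode
  private def computeHash_1(row: Array[Any], value: Int): Int =
    31 * value + row(1).hashCode

  def hash(row: Array[Any]): Int = {
    var value = 42                      // seed lives in a local variable
    value = computeHash_0(row, value)   // foldFunctions: value = computeHash_0(...)
    value = computeHash_1(row, value)
    value
  }

  def main(args: Array[String]): Unit = println(hash(Array("a", "b")))
}
```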
--- .../spark/sql/catalyst/expressions/hash.scala | 118 +++++++++++++----- .../expressions/HashExpressionsSuite.scala | 34 +++-- 2 files changed, 106 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index c3289b8299933..d0ed2ab8f3f0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,17 +270,36 @@ abstract class HashExpression[E] extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" - val childrenHash = ctx.splitExpressions(children.map { child => + + val childrenHash = children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } - }) + } + + val hashResultType = ctx.javaType(dataType) + val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + childrenHash.mkString("\n") + } else { + ctx.splitExpressions( + expressions = childrenHash, + funcName = "computeHash", + arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> ev.value), + returnType = hashResultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + } - ctx.addMutableState(ctx.javaType(dataType), ev.value) - ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + ev.copy(code = + s""" + |$hashResultType ${ev.value} = $seed; + |$codes + """.stripMargin) } protected def nullSafeElementHash( @@ -389,13 +408,21 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { - val hashes = fields.zipWithIndex.map { case (field, index) => + val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } + val hashResultType = ctx.javaType(dataType) ctx.splitExpressions( - expressions = hashes, - funcName = "getHash", - arguments = Seq("InternalRow" -> input)) + expressions = fieldsHash, + funcName = "computeHashForStruct", + arguments = Seq("InternalRow" -> input, hashResultType -> result), + returnType = hashResultType, + makeSplitFunction = body => + s""" + |$body + |return $result; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) } @tailrec @@ -610,25 +637,44 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" + val childHash = ctx.freshName("childHash") - val childrenHash = ctx.splitExpressions(children.map { child => + val childrenHash = children.map { child => val childGen = child.genCode(ctx) val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, childHash, ctx) } s""" |${childGen.code} + |$childHash = 0; |$codeToComputeHash |${ev.value} = (31 * ${ev.value}) + $childHash; - |$childHash = 0; """.stripMargin - }) + } - ctx.addMutableState(ctx.javaType(dataType), ev.value) - ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;") - ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + val codes = if 
(ctx.INPUT_ROW == null || ctx.currentVars != null) { + childrenHash.mkString("\n") + } else { + ctx.splitExpressions( + expressions = childrenHash, + funcName = "computeHash", + arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> ev.value), + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |${ctx.JAVA_INT} $childHash = 0; + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + } + + ev.copy(code = + s""" + |${ctx.JAVA_INT} ${ev.value} = $seed; + |${ctx.JAVA_INT} $childHash = 0; + |$codes + """.stripMargin) } override def eval(input: InternalRow = null): Int = { @@ -730,23 +776,29 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { input: String, result: String, fields: Array[StructField]): String = { - val localResult = ctx.freshName("localResult") val childResult = ctx.freshName("childResult") - fields.zipWithIndex.map { case (field, index) => + val fieldsHash = fields.zipWithIndex.map { case (field, index) => + val computeFieldHash = nullSafeElementHash( + input, index.toString, field.nullable, field.dataType, childResult, ctx) s""" - $childResult = 0; - ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType, - childResult, ctx)} - $localResult = (31 * $localResult) + $childResult; - """ - }.mkString( - s""" - int $localResult = 0; - int $childResult = 0; - """, - "", - s"$result = (31 * $result) + $localResult;" - ) + |$childResult = 0; + |$computeFieldHash + |$result = (31 * $result) + $childResult; + """.stripMargin + } + + s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + expressions = fieldsHash, + funcName = "computeHashForStruct", + arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |${ctx.JAVA_INT} $childResult = 0; + |$body + |return $result; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 112a4a09728ae..4281c89ac475d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -620,23 +621,30 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-18207: Compute hash for a lot of expressions") { + def checkResult(schema: StructType, input: InternalRow): Unit = { + val exprs = schema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val murmur3HashExpr = Murmur3Hash(exprs, 42) + val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) + val murmursHashEval = Murmur3Hash(exprs, 42).eval(input) + assert(murmur3HashPlan(input).getInt(0) == murmursHashEval) + + val 
hiveHashExpr = HiveHash(exprs) + val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr)) + val hiveHashEval = HiveHash(exprs).eval(input) + assert(hiveHashPlan(input).getInt(0) == hiveHashEval) + } + val N = 1000 val wideRow = new GenericInternalRow( Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any]) - val schema = StructType((1 to N).map(i => StructField("", StringType))) - - val exprs = schema.fields.zipWithIndex.map { case (f, i) => - BoundReference(i, f.dataType, true) - } - val murmur3HashExpr = Murmur3Hash(exprs, 42) - val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) - val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow) - assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval) + val schema = StructType((1 to N).map(i => StructField(i.toString, StringType))) + checkResult(schema, wideRow) - val hiveHashExpr = HiveHash(exprs) - val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr)) - val hiveHashEval = HiveHash(exprs).eval(wideRow) - assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval) + val nestedRow = InternalRow(wideRow) + val nestedSchema = new StructType().add("nested", schema) + checkResult(nestedSchema, nestedRow) } test("SPARK-22284: Compute hash for nested structs") { From 53e5251bb36cfa36ffa9058887a9944f87f26879 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Dec 2017 20:43:02 +0800 Subject: [PATCH 346/712] [SPARK-22675][SQL] Refactoring PropagateTypes in TypeCoercion ## What changes were proposed in this pull request? PropagateTypes are called twice in TypeCoercion. We do not need to call it twice. Instead, we should call it after each change on the types. ## How was this patch tested? The existing tests Author: gatorsmile Closes #19874 from gatorsmile/deduplicatePropagateTypes. 
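In other words, type propagation becomes a template method that each coercion rule inherits, so it runs only when a rule actually changed the plan. A simplified sketch of the pattern (toy plan type, not the real Catalyst classes):

```scala
// Toy stand-ins for Rule and LogicalPlan, just to show the control flow:
// propagation runs only when the coercion step changed the plan.
trait Rule[P] { def apply(plan: P): P }

trait TypeCoercionRuleSketch extends Rule[String] {
  final override def apply(plan: String): String = {
    val newPlan = coerceTypes(plan)
    if (newPlan == plan) plan else propagateTypes(newPlan)
  }
  protected def coerceTypes(plan: String): String
  private def propagateTypes(plan: String): String = s"$plan [types propagated]"
}

// An individual rule now only implements coerceTypes.
object DivisionSketch extends TypeCoercionRuleSketch {
  override protected def coerceTypes(plan: String): String =
    plan.replace("int / int", "double / double")
}
```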
--- .../catalyst/analysis/DecimalPrecision.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 116 ++++++++++-------- .../catalyst/expressions/Canonicalize.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../apache/spark/sql/TPCDSQuerySuite.scala | 3 + .../execution/HiveCompatibilitySuite.scala | 2 +- 6 files changed, 71 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index fd2ac78b25dbd..070bc542e4852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -58,7 +58,7 @@ import org.apache.spark.sql.types._ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE */ // scalastyle:on -object DecimalPrecision extends Rule[LogicalPlan] { +object DecimalPrecision extends TypeCoercionRule { import scala.math.{max, min} private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType @@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { PromotePrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 28be955e08a0d..1ee2f6e941045 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -22,6 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -45,8 +46,7 @@ import org.apache.spark.sql.types._ object TypeCoercion { val typeCoercionRules = - PropagateTypes :: - InConversion :: + InConversion :: WidenSetOperationTypes :: PromoteStrings :: DecimalPrecision :: @@ -56,7 +56,6 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - PropagateTypes :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -220,38 +219,6 @@ object TypeCoercion { private def haveSameType(exprs: Seq[Expression]): Boolean = exprs.map(_.dataType).distinct.length == 1 - /** - * Applies any changes to [[AttributeReference]] data types that are made by other rules to - * instances higher in the query tree. - */ - object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - - // No propagation required for leaf nodes. - case q: LogicalPlan if q.children.isEmpty => q - - // Don't propagate types from unresolved children. 
- case q: LogicalPlan if !q.childrenResolved => q - - case q: LogicalPlan => - val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap - q transformExpressions { - case a: AttributeReference => - inputMap.get(a.exprId) match { - // This can happen when an Attribute reference is born in a non-leaf node, for - // example due to a call to an external script like in the Transform operator. - // TODO: Perhaps those should actually be aliases? - case None => a - // Leave the same if the dataTypes match. - case Some(newType) if a.dataType == newType.dataType => a - case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}") - newType - } - } - } - } - /** * Widens numeric types and converts strings to numbers when appropriate. * @@ -345,7 +312,7 @@ object TypeCoercion { /** * Promotes strings that appear in arithmetic expressions. */ - object PromoteStrings extends Rule[LogicalPlan] { + object PromoteStrings extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { (expr.dataType, targetType) match { case (NullType, dt) => Literal.create(null, targetType) @@ -354,7 +321,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -403,7 +370,7 @@ object TypeCoercion { * operator type is found the original expression will be returned and an * Analysis Exception will be raised at the type checking phase. */ - object InConversion extends Rule[LogicalPlan] { + object InConversion extends TypeCoercionRule { private def flattenExpr(expr: Expression): Seq[Expression] = { expr match { // Multi columns in IN clause is represented as a CreateNamedStruct. @@ -413,7 +380,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -512,8 +479,8 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object FunctionArgumentConversion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -602,8 +569,8 @@ object TypeCoercion { * Hive only performs integral division with the DIV operator. The arguments to / are always * converted to fractional types. */ - object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object Division extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -624,8 +591,8 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. 
*/ - object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object CaseWhenCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => @@ -654,8 +621,8 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object IfCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => @@ -674,8 +641,8 @@ object TypeCoercion { /** * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ - object StackCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + object StackCoercion extends TypeCoercionRule { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. @@ -711,8 +678,8 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object ImplicitTypeCasts extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -828,8 +795,8 @@ object TypeCoercion { /** * Cast WindowFrame boundaries to the type they operate upon. */ - object WindowFrameCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object WindowFrameCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -850,3 +817,46 @@ object TypeCoercion { } } } + +trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { + /** + * Applies any changes to [[AttributeReference]] data types that are made by the transform method + * to instances higher in the query tree. + */ + def apply(plan: LogicalPlan): LogicalPlan = { + val newPlan = coerceTypes(plan) + if (plan.fastEquals(newPlan)) { + plan + } else { + propagateTypes(newPlan) + } + } + + protected def coerceTypes(plan: LogicalPlan): LogicalPlan + + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + // No propagation required for leaf nodes. + case q: LogicalPlan if q.children.isEmpty => q + + // Don't propagate types from unresolved children. 
+ case q: LogicalPlan if !q.childrenResolved => q + + case q: LogicalPlan => + val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap + q transformExpressions { + case a: AttributeReference => + inputMap.get(a.exprId) match { + // This can happen when an Attribute reference is born in a non-leaf node, for + // example due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave the same if the dataTypes match. + case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logDebug( + s"Promoting $a from ${a.dataType} to ${newType.dataType} in ${q.simpleString}") + newType + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 65e497afc12cd..d848ba18356d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -31,7 +31,7 @@ package org.apache.spark.sql.catalyst.expressions * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. */ -object Canonicalize extends { +object Canonicalize { def execute(e: Expression): Expression = { expressionReorder(ignoreNamesTypes(e)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index e3901af4b9988..cac3e12dde3e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -300,7 +300,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { Locale.setDefault(originalLocale) // For debugging dump some statistics about how much time was spent in various optimizer rules - logWarning(RuleExecutor.dumpTimeSpent()) + logInfo(RuleExecutor.dumpTimeSpent()) } finally { super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index e47d4b0ee25d4..dbea03689785c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -39,6 +40,8 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte */ protected override def afterAll(): Unit = { try { + // For debugging dump some statistics about how much time was spent in various optimizer rules + logInfo(RuleExecutor.dumpTimeSpent()) spark.sessionState.catalog.reset() } finally { super.afterAll() diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 45791c69b4cb7..def70a516a96b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -76,7 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules - logWarning(RuleExecutor.dumpTimeSpent()) + logInfo(RuleExecutor.dumpTimeSpent()) } finally { super.afterAll() } From 326f1d6728a7734c228d8bfaa69442a1c7b92e9b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 5 Dec 2017 20:46:35 +0800 Subject: [PATCH 347/712] [SPARK-20728][SQL] Make OrcFileFormat configurable between sql/hive and sql/core ## What changes were proposed in this pull request? This PR aims to provide a configuration to choose the default `OrcFileFormat` from legacy `sql/hive` module or new `sql/core` module. For example, this configuration will affects the following operations. ```scala spark.read.orc(...) ``` ```sql CREATE TABLE t USING ORC ... ``` ## How was this patch tested? Pass the Jenkins with new test suites. Author: Dongjoon Hyun Closes #19871 from dongjoon-hyun/spark-sql-orc-enabled. --- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../spark/sql/execution/command/tables.scala | 6 +- .../execution/datasources/DataSource.scala | 23 +++++- .../sql/execution/datasources/rules.scala | 5 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 23 +++++- .../sql/sources/DDLSourceLoadSuite.scala | 7 -- .../sql/test/DataFrameReaderWriterSuite.scala | 79 +++++++++++-------- .../spark/sql/hive/HiveStrategies.scala | 16 +++- .../spark/sql/hive/orc/OrcQuerySuite.scala | 19 ++++- 12 files changed, 136 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8abb4262d7352..ce9cc562b220f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -363,6 +363,14 @@ object SQLConf { .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) .createWithDefault("snappy") + val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") + .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + + "1.2.1. 
It is 'hive' by default prior to Spark 2.3.") + .internal() + .stringConf + .checkValues(Set("hive", "native")) + .createWithDefault("native") + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 0c5f3f22e31e8..6cdfe2fae5642 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,6 +1,7 @@ org.apache.spark.sql.execution.datasources.csv.CSVFileFormat org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider org.apache.spark.sql.execution.datasources.json.JsonFileFormat +org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 17966eecfc051..ea1cf66775235 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -182,7 +182,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { "read files of Hive data source directly.") } - val cls = DataSource.lookupDataSource(source) + val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val options = new DataSourceV2Options(extraOptions.asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 35abeccfd514a..59a01e61124f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -234,7 +234,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") - val cls = DataSource.lookupDataSource(source) + val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { cls.newInstance() match { case ds: WriteSupport => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index c42e6c3257fad..e400975f19708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils @@ -190,7 +191,7 @@ case class AlterTableAddColumnsCommand( colsToAdd: Seq[StructField]) extends 
RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - val catalogTable = verifyAlterTableAddColumn(catalog, table) + val catalogTable = verifyAlterTableAddColumn(sparkSession.sessionState.conf, catalog, table) try { sparkSession.catalog.uncacheTable(table.quotedString) @@ -216,6 +217,7 @@ case class AlterTableAddColumnsCommand( * For datasource table, it currently only supports parquet, json, csv. */ private def verifyAlterTableAddColumn( + conf: SQLConf, catalog: SessionCatalog, table: TableIdentifier): CatalogTable = { val catalogTable = catalog.getTempViewOrPermanentTableMetadata(table) @@ -229,7 +231,7 @@ case class AlterTableAddColumnsCommand( } if (DDLUtils.isDatasourceTable(catalogTable)) { - DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { + DataSource.lookupDataSource(catalogTable.provider.get, conf).newInstance() match { // For datasource table, this command can only support the following File format. // TextFileFormat only default to one column "value" // Hive type is already considered as hive serde table, so the logic will not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b43d282bd434c..5f12d5f93a35c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -36,8 +36,10 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} @@ -85,7 +87,8 @@ case class DataSource( case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) - lazy val providingClass: Class[_] = DataSource.lookupDataSource(className) + lazy val providingClass: Class[_] = + DataSource.lookupDataSource(className, sparkSession.sessionState.conf) lazy val sourceInfo: SourceInfo = sourceSchema() private val caseInsensitiveOptions = CaseInsensitiveMap(options) private val equality = sparkSession.sessionState.conf.resolver @@ -537,6 +540,7 @@ object DataSource extends Logging { val csv = classOf[CSVFileFormat].getCanonicalName val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" + val nativeOrc = classOf[OrcFileFormat].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -553,6 +557,8 @@ object DataSource extends Logging { "org.apache.spark.sql.execution.datasources.parquet.DefaultSource" -> parquet, "org.apache.spark.sql.hive.orc.DefaultSource" -> orc, "org.apache.spark.sql.hive.orc" -> orc, + "org.apache.spark.sql.execution.datasources.orc.DefaultSource" -> nativeOrc, + "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc, "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, 
"com.databricks.spark.csv" -> csv @@ -568,8 +574,16 @@ object DataSource extends Logging { "org.apache.spark.Logging") /** Given a provider name, look up the data source class definition. */ - def lookupDataSource(provider: String): Class[_] = { - val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) + def lookupDataSource(provider: String, conf: SQLConf): Class[_] = { + val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match { + case name if name.equalsIgnoreCase("orc") && + conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" => + classOf[OrcFileFormat].getCanonicalName + case name if name.equalsIgnoreCase("orc") && + conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" => + "org.apache.spark.sql.hive.orc.OrcFileFormat" + case name => name + } val provider2 = s"$provider1.DefaultSource" val loader = Utils.getContextOrSparkClassLoader val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) @@ -587,7 +601,8 @@ object DataSource extends Logging { if (provider1.toLowerCase(Locale.ROOT) == "orc" || provider1.startsWith("org.apache.spark.sql.hive.orc")) { throw new AnalysisException( - "The ORC data source must be used with Hive support enabled") + "Hive-based ORC data source must be used with Hive support enabled. " + + "Please use native ORC data source instead") } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || provider1 == "com.databricks.spark.avro") { throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 60c430bcfece2..6e08df75b8a4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -108,8 +108,9 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi } // Check if the specified data source match the data source of the existing table. - val existingProvider = DataSource.lookupDataSource(existingTable.provider.get) - val specifiedProvider = DataSource.lookupDataSource(tableDesc.provider.get) + val conf = sparkSession.sessionState.conf + val existingProvider = DataSource.lookupDataSource(existingTable.provider.get, conf) + val specifiedProvider = DataSource.lookupDataSource(tableDesc.provider.get, conf) // TODO: Check that options from the resolved relation match the relation that we are // inserting into (i.e. using the same compression). 
if (existingProvider != specifiedProvider) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 31d9b909ad463..86bd9b95bca6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1664,7 +1666,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { e = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.hive.orc`.`file_path`") } - assert(e.message.contains("The ORC data source must be used with Hive support enabled")) + assert(e.message.contains("Hive-based ORC data source must be used with Hive support")) e = intercept[AnalysisException] { sql(s"select id from `com.databricks.spark.avro`.`file_path`") @@ -2782,4 +2784,23 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(spark.read.format(orc).load(path).collect().length == 2) } } + + test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") { + val e = intercept[AnalysisException] { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + } + assert(e.message.contains("Hive-based ORC data source must be used with Hive support")) + } + + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { + withTable("spark_20728") { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { + case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass + } + assert(fileFormat == Some(classOf[OrcFileFormat])) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 3ce6ae3c52927..f22d843bfabde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -53,13 +53,6 @@ class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))) } - - test("should fail to load ORC without Hive Support") { - val e = intercept[AnalysisException] { - spark.read.format("orc").load() - } - assert(e.message.contains("The ORC data source must be used with Hive support enabled")) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a5d7e6257a6df..8c9bb7d56a35f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -155,7 +155,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } - test("resolve default source") { spark.read .format("org.apache.spark.sql.test") @@ -478,42 +477,56 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be spark.read.schema(userSchema).parquet(Seq(dir, dir): _*), expData ++ expData, userSchema) } - /** - * This only tests whether API compiles, but does not run it as orc() - * cannot be run without Hive classes. - */ - ignore("orc - API") { - // Reader, with user specified schema - // Refer to csv-specific test suites for behavior without user specified schema - spark.read.schema(userSchema).orc() - spark.read.schema(userSchema).orc(dir) - spark.read.schema(userSchema).orc(dir, dir, dir) - spark.read.schema(userSchema).orc(Seq(dir, dir): _*) - Option(dir).map(spark.read.schema(userSchema).orc) + test("orc - API and behavior regarding schema") { + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { + // Writer + spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).orc(dir) + val df = spark.read.orc(dir) + checkAnswer(df, spark.createDataset(data).toDF()) + val schema = df.schema - // Writer - spark.range(10).write.orc(dir) + // Reader, without user specified schema + intercept[AnalysisException] { + testRead(spark.read.orc(), Seq.empty, schema) + } + testRead(spark.read.orc(dir), data, schema) + testRead(spark.read.orc(dir, dir), data ++ data, schema) + testRead(spark.read.orc(Seq(dir, dir): _*), data ++ data, schema) + // Test explicit calls to single arg method - SPARK-16009 + testRead(Option(dir).map(spark.read.orc).get, data, schema) + + // Reader, with user specified schema, data should be nulls as schema in file different + // from user schema + val expData = Seq[String](null, null, null) + testRead(spark.read.schema(userSchema).orc(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchema).orc(dir), expData, userSchema) + testRead(spark.read.schema(userSchema).orc(dir, dir), expData ++ expData, userSchema) + testRead( + spark.read.schema(userSchema).orc(Seq(dir, dir): _*), expData ++ expData, userSchema) + } } test("column nullability and comment - write and then read") { - Seq("json", "parquet", "csv").foreach { format => - val schema = StructType( - StructField("cl1", IntegerType, nullable = false).withComment("test") :: - StructField("cl2", IntegerType, nullable = true) :: - StructField("cl3", IntegerType, nullable = true) :: Nil) - val row = Row(3, null, 4) - val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) - - val tableName = "tab" - withTable(tableName) { - df.write.format(format).mode("overwrite").saveAsTable(tableName) - // Verify the DDL command result: DESCRIBE TABLE - checkAnswer( - sql(s"desc $tableName").select("col_name", "comment").where($"comment" === "test"), - Row("cl1", "test") :: Nil) - // Verify the schema - val expectedFields = schema.fields.map(f => f.copy(nullable = true)) - assert(spark.table(tableName).schema == schema.copy(fields = expectedFields)) + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { + Seq("json", "orc", "parquet", "csv").foreach { format => + val schema = StructType( + StructField("cl1", IntegerType, nullable = false).withComment("test") :: + StructField("cl2", IntegerType, nullable = true) :: + StructField("cl3", IntegerType, nullable = true) :: Nil) + val row = Row(3, null, 4) + val df = 
spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) + + val tableName = "tab" + withTable(tableName) { + df.write.format(format).mode("overwrite").saveAsTable(tableName) + // Verify the DDL command result: DESCRIBE TABLE + checkAnswer( + sql(s"desc $tableName").select("col_name", "comment").where($"comment" === "test"), + Row("cl1", "test") :: Nil) + // Verify the schema + val expectedFields = schema.fields.map(f => f.copy(nullable = true)) + assert(spark.table(tableName).schema == schema.copy(fields = expectedFields)) + } } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ee1f6ee173063..3018c0642f062 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.orc.OrcFileFormat import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -195,8 +194,19 @@ case class RelationConversions( .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { val options = relation.tableMeta.storage.properties - sessionCatalog.metastoreCatalog - .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") + if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { + sessionCatalog.metastoreCatalog.convertToLogicalRelation( + relation, + options, + classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat], + "orc") + } else { + sessionCatalog.metastoreCatalog.convertToLogicalRelation( + relation, + options, + classOf[org.apache.spark.sql.hive.orc.OrcFileFormat], + "orc") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 1fa9091f967a3..1ffaf30311037 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -28,7 +28,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.HiveTableRelation -import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -621,4 +621,21 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { makeOrcFile((1 to 10).map(Tuple1.apply), path2) assertResult(20)(read.orc(path1.getCanonicalPath, path2.getCanonicalPath).count()) } + + test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { + Seq( + ("native", classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat]), + ("hive", classOf[org.apache.spark.sql.hive.orc.OrcFileFormat])).foreach { case (i, format) => + + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> i) { + withTable("spark_20728") { + 
sql("CREATE TABLE spark_20728(a INT) USING ORC") + val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { + case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass + } + assert(fileFormat == Some(format)) + } + } + } + } } From 03fdc92e42d260a2b7c0090115f162ba5c091aae Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Tue, 5 Dec 2017 09:15:22 -0800 Subject: [PATCH 348/712] [SPARK-22681] Accumulator should only be updated once for each task in result stage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? As the doc says "For accumulator updates performed inside actions only, Spark guarantees that each task’s update to the accumulator will only be applied once, i.e. restarted tasks will not update the value." But currently the code doesn't guarantee this. ## How was this patch tested? New added tests. Author: Carson Wang Closes #19877 from carsonwang/fixAccum. --- .../apache/spark/scheduler/DAGScheduler.scala | 14 ++++++++++--- .../spark/scheduler/DAGSchedulerSuite.scala | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 9153751d03c1b..c2498d4808e91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1187,9 +1187,17 @@ class DAGScheduler( // only updated in certain cases. event.reason match { case Success => - stage match { - case rs: ResultStage if rs.activeJob.isEmpty => - // Ignore update if task's job has finished. + task match { + case rt: ResultTask[_, _] => + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.activeJob match { + case Some(job) => + // Only update the accumulator once for each result task. + if (!job.finished(rt.outputId)) { + updateAccumulators(event) + } + case None => // Ignore update if task's job has finished. + } case _ => updateAccumulators(event) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index feefb6a4d73f0..d812b5bd92c1b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1832,6 +1832,26 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("accumulator not calculated for resubmitted task in result stage") { + val accum = AccumulatorSuite.createLongAccum("a") + val finalRdd = new MyRDD(sc, 2, Nil) + submit(finalRdd, Array(0, 1)) + // finish the first task + completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) + // verify stage exists + assert(scheduler.stageIdToStage.contains(0)) + + // finish the first task again (simulate a speculative task or a resubmitted task) + completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + + // The accumulator should only be updated once. 
+ assert(accum.value === 1) + + runEvent(makeCompletionEvent(taskSets(0).tasks(1), Success, 42)) + assertDataStructuresEmpty() + } + test("accumulators are updated on exception failures") { val acc1 = AccumulatorSuite.createLongAccum("ingenieur") val acc2 = AccumulatorSuite.createLongAccum("boulanger") From ced6ccf0d6f362e299f270ed2a474f2e14f845da Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Dec 2017 10:15:15 -0800 Subject: [PATCH 349/712] [SPARK-22701][SQL] add ctx.splitExpressionsWithCurrentInputs ## What changes were proposed in this pull request? This pattern appears many times in the codebase: ``` if (ctx.INPUT_ROW == null || ctx.currentVars != null) { exprs.mkString("\n") } else { ctx.splitExpressions(...) } ``` This PR adds a `ctx.splitExpressionsWithCurrentInputs` for this pattern ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19895 from cloud-fan/splitExpression. --- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 44 ++++----- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 2 +- .../expressions/complexTypeCreator.scala | 6 +- .../expressions/conditionalExpressions.scala | 84 ++++++++--------- .../sql/catalyst/expressions/generators.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 55 +++++------ .../expressions/nullExpressions.scala | 94 ++++++++----------- .../expressions/objects/objects.scala | 6 +- .../sql/catalyst/expressions/predicates.scala | 47 +++++----- .../expressions/stringExpressions.scala | 37 ++++---- 12 files changed, 179 insertions(+), 206 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d98f7b3d8efe6..739bd13c5078d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(evalChildren.map(updateEval)) + val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(evalChildren.map(updateEval)) + val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1645db12c53f0..670c82eff9286 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -781,29 +781,26 @@ class CodegenContext { * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it * instead, because classes have a constant pool limit of 65,536 named values. * - * Note that we will extract the current inputs of this context and pass them to the generated - * functions. 
The input is `INPUT_ROW` for normal codegen path, and `currentVars` for whole - * stage codegen path. Whole stage codegen path is not supported yet. - * - * @param expressions the codes to evaluate expressions. - */ - def splitExpressions(expressions: Seq[String]): String = { - splitExpressions(expressions, funcName = "apply", extraArguments = Nil) - } - - /** - * Similar to [[splitExpressions(expressions: Seq[String])]], but has customized function name - * and extra arguments. + * Note that different from `splitExpressions`, we will extract the current inputs of this + * context and pass them to the generated functions. The input is `INPUT_ROW` for normal codegen + * path, and `currentVars` for whole stage codegen path. Whole stage codegen path is not + * supported yet. * * @param expressions the codes to evaluate expressions. * @param funcName the split function name base. - * @param extraArguments the list of (type, name) of the arguments of the split function - * except for ctx.INPUT_ROW - */ - def splitExpressions( + * @param extraArguments the list of (type, name) of the arguments of the split function, + * except for the current inputs like `ctx.INPUT_ROW`. + * @param returnType the return type of the split function. + * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup. + * @param foldFunctions folds the split function calls. + */ + def splitExpressionsWithCurrentInputs( expressions: Seq[String], - funcName: String, - extraArguments: Seq[(String, String)]): String = { + funcName: String = "apply", + extraArguments: Seq[(String, String)] = Nil, + returnType: String = "void", + makeSplitFunction: String => String = identity, + foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { // TODO: support whole stage codegen if (INPUT_ROW == null || currentVars != null) { expressions.mkString("\n") @@ -811,13 +808,18 @@ class CodegenContext { splitExpressions( expressions, funcName, - arguments = ("InternalRow", INPUT_ROW) +: extraArguments) + ("InternalRow", INPUT_ROW) +: extraArguments, + returnType, + makeSplitFunction, + foldFunctions) } } /** * Splits the generated code of expressions into multiple functions, because function has - * 64kb code size limit in JVM + * 64kb code size limit in JVM. If the class to which the function would be inlined would grow + * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it + * instead, because classes have a constant pool limit of 65,536 named values. * * @param expressions the codes to evaluate expressions. * @param funcName the split function name base. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5fdbda51b4ad1..bd8312eb8b7fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -91,8 +91,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressions(projectionCodes) - val allUpdates = ctx.splitExpressions(updates) + val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes) + val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 5d35cce1a91cb..44e7148e5d98f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -159,7 +159,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - val allExpressions = ctx.splitExpressions(expressionCodes) + val allExpressions = ctx.splitExpressionsWithCurrentInputs(expressionCodes) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index fc68bf478e1c8..087b21043b309 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -108,7 +108,7 @@ private [sql] object GenArrayData { } """ } - val assignmentString = ctx.splitExpressions( + val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", extraArguments = ("Object[]", arrayDataName) :: Nil) @@ -139,7 +139,7 @@ private [sql] object GenArrayData { } """ } - val assignmentString = ctx.splitExpressions( + val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil) @@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"$values = null;") - val valuesCode = ctx.splitExpressions( + val valuesCode = ctx.splitExpressionsWithCurrentInputs( valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 43e643178c899..ae5f7140847db 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -219,57 +219,51 @@ case class CaseWhen( val allConditions = cases ++ elseCode - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - allConditions.mkString("\n") - } else { - // This generates code like: - // conditionMet = caseWhen_1(i); - // if(conditionMet) { - // continue; - // } - // conditionMet = caseWhen_2(i); - // if(conditionMet) { - // continue; - // } - // ... - // and the declared methods are: - // private boolean caseWhen_1234() { - // boolean conditionMet = false; - // do { - // // here the evaluation of the conditions - // } while (false); - // return conditionMet; - // } - ctx.splitExpressions(allConditions, "caseWhen", - ("InternalRow", ctx.INPUT_ROW) :: Nil, - returnType = ctx.JAVA_BOOLEAN, - makeSplitFunction = { - func => - s""" - ${ctx.JAVA_BOOLEAN} $conditionMet = false; - do { - $func - } while (false); - return $conditionMet; - """ - }, - foldFunctions = { funcCalls => - funcCalls.map { funcCall => - s""" - $conditionMet = $funcCall; - if ($conditionMet) { - continue; - }""" - }.mkString - }) - } + // This generates code like: + // conditionMet = caseWhen_1(i); + // if(conditionMet) { + // continue; + // } + // conditionMet = caseWhen_2(i); + // if(conditionMet) { + // continue; + // } + // ... + // and the declared methods are: + // private boolean caseWhen_1234() { + // boolean conditionMet = false; + // do { + // // here the evaluation of the conditions + // } while (false); + // return conditionMet; + // } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = allConditions, + funcName = "caseWhen", + returnType = ctx.JAVA_BOOLEAN, + makeSplitFunction = func => + s""" + |${ctx.JAVA_BOOLEAN} $conditionMet = false; + |do { + | $func + |} while (false); + |return $conditionMet; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$conditionMet = $funcCall; + |if ($conditionMet) { + | continue; + |} + """.stripMargin + }.mkString) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; ${ctx.JAVA_BOOLEAN} $conditionMet = false; do { - $code + $codes } while (false);""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f1aa130669266..cd38783a731ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) - val code = ctx.splitExpressions(Seq.tabulate(numRows) { row => + val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row => val fields = Seq.tabulate(numFields) { col => val index = row * numFields + col if (index < values.length) values(index) else Literal(null, dataTypes(col)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index d0ed2ab8f3f0e..055ebf6c0da54 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -279,21 +279,17 @@ abstract class HashExpression[E] extends Expression { } val hashResultType = ctx.javaType(dataType) - val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - childrenHash.mkString("\n") - } else { - ctx.splitExpressions( - expressions = childrenHash, - funcName = "computeHash", - arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> ev.value), - returnType = hashResultType, - makeSplitFunction = body => - s""" - |$body - |return ${ev.value}; - """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) - } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = childrenHash, + funcName = "computeHash", + extraArguments = Seq(hashResultType -> ev.value), + returnType = hashResultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" @@ -652,22 +648,19 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - childrenHash.mkString("\n") - } else { - ctx.splitExpressions( - expressions = childrenHash, - funcName = "computeHash", - arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> ev.value), - returnType = ctx.JAVA_INT, - makeSplitFunction = body => - s""" - |${ctx.JAVA_INT} $childHash = 0; - |$body - |return ${ev.value}; - """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) - } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = childrenHash, + funcName = "computeHash", + extraArguments = Seq(ctx.JAVA_INT -> ev.value), + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |${ctx.JAVA_INT} $childHash = 0; + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 3b52a0efd404a..26c9a41efc9f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -87,37 +87,32 @@ case class Coalesce(children: Seq[Expression]) extends Expression { |} """.stripMargin } - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - evals.mkString("\n") - } else { - ctx.splitExpressions(evals, "coalesce", - ("InternalRow", ctx.INPUT_ROW) :: Nil, - makeSplitFunction = { - func => - s""" - |do { - | $func - |} while (false); - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map { funcCall => - s""" - |$funcCall; - |if (!${ev.isNull}) { - | continue; - |} - """.stripMargin - }.mkString - }) - } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "coalesce", + makeSplitFunction = func => + s""" + |do { + | $func + |} while (false); + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$funcCall; + |if (!${ev.isNull}) { + | continue; + |} + """.stripMargin + }.mkString) + ev.copy(code = s""" 
|${ev.isNull} = true; |${ev.value} = ${ctx.defaultValue(dataType)}; |do { - | $code + | $codes |} while (false); """.stripMargin) } @@ -415,39 +410,32 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } } - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - evals.mkString("\n") - } else { - ctx.splitExpressions( - expressions = evals, - funcName = "atLeastNNonNulls", - arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, - makeSplitFunction = { body => - s""" - |do { - | $body - |} while (false); - |return $nonnull; - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => - s""" - |$nonnull = $funcCall; - |if ($nonnull >= $n) { - | continue; - |} - """.stripMargin).mkString("\n") - } - ) - } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "atLeastNNonNulls", + extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |do { + | $body + |} while (false); + |return $nonnull; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$nonnull = $funcCall; + |if ($nonnull >= $n) { + | continue; + |} + """.stripMargin + }.mkString) ev.copy(code = s""" |${ctx.JAVA_INT} $nonnull = 0; |do { - | $code + | $codes |} while (false); |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = "false") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e2bc79d98b33d..730b2ff96da6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression { """ } } - val argCode = ctx.splitExpressions(argCodes) + val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes) (argCode, argValues.mkString(", "), resultIsNull) } @@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) """ } - val childrenCode = ctx.splitExpressions(childrenCodes) + val childrenCode = ctx.splitExpressionsWithCurrentInputs(childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) val code = s""" @@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp ${javaBeanInstance}.$setterMethod(${fieldGen.value}); """ } - val initializeCode = ctx.splitExpressions(initialize.toSeq) + val initializeCode = ctx.splitExpressionsWithCurrentInputs(initialize.toSeq) val code = s""" ${instanceGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 75cc9b3bd8045..04e669492ec6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -253,31 +253,26 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { | continue; |} """.stripMargin) - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - listCode.mkString("\n") - } else { - ctx.splitExpressions( - expressions = listCode, - funcName = "valueIn", - arguments = 
("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil, - makeSplitFunction = { body => - s""" - |do { - | $body - |} while (false); - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => - s""" - |$funcCall; - |if (${ev.value}) { - | continue; - |} - """.stripMargin).mkString("\n") - } - ) - } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = listCode, + funcName = "valueIn", + extraArguments = (javaDataType, valueArg) :: Nil, + makeSplitFunction = body => + s""" + |do { + | $body + |} while (false); + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$funcCall; + |if (${ev.value}) { + | continue; + |} + """.stripMargin + }.mkString("\n")) + ev.copy(code = s""" |${valueGen.code} @@ -286,7 +281,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |if (!${ev.isNull}) { | $javaDataType $valueArg = ${valueGen.value}; | do { - | $code + | $codes | } while (false); |} """.stripMargin) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 34917ace001fa..47f0b5741f67f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -73,7 +73,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } - val codes = ctx.splitExpressions( + val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", extraArguments = ("UTF8String[]", args) :: Nil) @@ -152,7 +152,7 @@ case class ConcatWs(children: Seq[Expression]) "" } } - val codes = ctx.splitExpressions( + val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) @@ -200,31 +200,32 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressions(evals.map(_.code)) - val varargCounts = ctx.splitExpressions( + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + + val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, funcName = "varargCountsConcatWs", - arguments = ("InternalRow", ctx.INPUT_ROW) :: Nil, returnType = "int", makeSplitFunction = body => s""" - int $varargNum = 0; - $body - return $varargNum; - """, - foldFunctions = _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) - val varargBuilds = ctx.splitExpressions( + |int $varargNum = 0; + |$body + |return $varargNum; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$varargNum += $funcCall;").mkString("\n")) + + val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - arguments = - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" - $body - return $idxInVararg; - """, - foldFunctions = _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) + |$body + |return $idxInVararg; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + ev.copy( s""" $codes @@ -1380,7 +1381,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC 
$argList[$index] = $value; """ } - val argListCodes = ctx.splitExpressions( + val argListCodes = ctx.splitExpressionsWithCurrentInputs( expressions = argListCode, funcName = "valueFormatString", extraArguments = ("Object[]", argList) :: Nil) From 132a3f470811bb98f265d0c9ad2c161698e0237b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Dec 2017 11:40:13 -0800 Subject: [PATCH 350/712] [SPARK-22500][SQL][FOLLOWUP] cast for struct can split code even with whole stage codegen ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/19730, we can split the code for casting struct even with whole stage codegen. This PR also has some renaming to make the code easier to read. ## How was this patch tested? existing test Author: Wenchen Fan Closes #19891 from cloud-fan/cast. --- .../spark/sql/catalyst/expressions/Cast.scala | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f4ecbdb8393ad..b8d3661a00abc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -548,8 +548,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } - // three function arguments are: child.primitive, result.primitive and result.isNull - // it returns the code snippets to be put in null safe evaluation region + // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` + // in parameter list, because the returned code will be put in null safe evaluation region. private[this] type CastFunction = (String, String, String) => String private[this] def nullSafeCastFunction( @@ -584,15 +584,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String throw new SparkException(s"Cannot cast $from to $to.") } - // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's + // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. 
- private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String, - resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { + private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, + result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" - boolean $resultNull = $childNull; - ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; - if (!$childNull) { - ${cast(childPrim, resultPrim, resultNull)} + boolean $resultIsNull = $inputIsNull; + ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + if (!$inputIsNull) { + ${cast(input, result, resultIsNull)} } """ } @@ -1014,8 +1014,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } val rowClass = classOf[GenericInternalRow].getName - val result = ctx.freshName("result") - val tmpRow = ctx.freshName("tmpRow") + val tmpResult = ctx.freshName("tmpResult") + val tmpInput = ctx.freshName("tmpInput") val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => val fromFieldPrim = ctx.freshName("ffp") @@ -1024,37 +1024,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val toFieldNull = ctx.freshName("tfn") val fromType = ctx.javaType(from.fields(i).dataType) s""" - boolean $fromFieldNull = $tmpRow.isNullAt($i); + boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { - $result.setNullAt($i); + $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)}; + ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { - $result.setNullAt($i); + $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)}; + ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ } - val fieldsEvalCodes = if (ctx.currentVars == null) { - ctx.splitExpressions( - expressions = fieldsEvalCode, - funcName = "castStruct", - arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil) - } else { - fieldsEvalCode.mkString("\n") - } + val fieldsEvalCodes = ctx.splitExpressions( + expressions = fieldsEvalCode, + funcName = "castStruct", + arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) - (c, evPrim, evNull) => + (input, result, resultIsNull) => s""" - final $rowClass $result = new $rowClass(${fieldsCasts.length}); - final InternalRow $tmpRow = $c; + final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); + final InternalRow $tmpInput = $input; $fieldsEvalCodes - $evPrim = $result; + $result = $tmpResult; """ } From 1e17ab83de29bca1823a537d7c57ffc4de8a26ee Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 5 Dec 2017 15:15:32 -0800 Subject: [PATCH 351/712] [SPARK-22662][SQL] Failed to prune columns after rewriting predicate subquery ## What changes were proposed in this pull request? 
As a simple example: ``` spark-sql> create table base (a int, b int) using parquet; Time taken: 0.066 seconds spark-sql> create table relInSubq ( x int, y int, z int) using parquet; Time taken: 0.042 seconds spark-sql> explain select a from base where a in (select x from relInSubq); == Physical Plan == *Project [a#83] +- *BroadcastHashJoin [a#83], [x#85], LeftSemi, BuildRight :- *FileScan parquet default.base[a#83,b#84] Batched: true, Format: Parquet, Location: InMemoryFileIndex[hdfs://100.0.0.4:9000/wzh/base], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))) +- *Project [x#85] +- *FileScan parquet default.relinsubq[x#85] Batched: true, Format: Parquet, Location: InMemoryFileIndex[hdfs://100.0.0.4:9000/wzh/relinsubq], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` We only need column `a` in table `base`, but all columns (`a`, `b`) are fetched. The reason is that, in "Operator Optimizations" batch, `ColumnPruning` first produces a `Project` on table `base`, but then it's removed by `removeProjectBeforeFilter`. Because at that time, the predicate subquery is in filter form. Then, in "Rewrite Subquery" batch, `RewritePredicateSubquery` converts the subquery into a LeftSemi join, but this batch doesn't have the `ColumnPruning` rule. This results in reading all columns for the `base` table. ## How was this patch tested? Added a new test case. Author: Zhenhua Wang Closes #19855 from wzhfy/column_pruning_subquery. --- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../optimizer/RewriteSubquerySuite.scala | 55 +++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8a5c486912abf..484cd8c2475f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -141,7 +141,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CheckCartesianProducts) :: Batch("RewriteSubquery", Once, RewritePredicateSubquery, - CollapseProject) :: Nil + ColumnPruning, + CollapseProject, + RemoveRedundantProject) :: Nil } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala new file mode 100644 index 0000000000000..6b3739c372c3a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.ListQuery +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class RewriteSubquerySuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Column Pruning", FixedPoint(100), ColumnPruning) :: + Batch("Rewrite Subquery", FixedPoint(1), + RewritePredicateSubquery, + ColumnPruning, + CollapseProject, + RemoveRedundantProject) :: Nil + } + + test("Column pruning after rewriting predicate subquery") { + val relation = LocalRelation('a.int, 'b.int) + val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int) + + val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = relation + .select('a) + .join(relInSubquery.select('x), LeftSemi, Some('a === 'x)) + .analyze + + comparePlans(optimized, correctAnswer) + } + +} From 59aa3d56af956267261b135936651bbbd41b1bfc Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Tue, 5 Dec 2017 18:08:36 -0600 Subject: [PATCH 352/712] [SPARK-20706][SPARK-SHELL] Spark-shell not overriding method/variable definition ## What changes were proposed in this pull request? [SPARK-20706](https://issues.apache.org/jira/browse/SPARK-20706): Spark-shell not overriding method/variable definition This is a Scala repl bug ( [SI-9740](https://github.com/scala/bug/issues/9740) ), was fixed in version 2.11.9 ( [see the original PR](https://github.com/scala/scala/pull/5090) ) ## How was this patch tested? Added a new test case in `ReplSuite`. Author: Mark Petruska Closes #19879 from mpetruska/SPARK-20706. 
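For context on the symptom (SI-9740 in the class-based REPL mode used by spark-shell): re-defining a `def` or `val` on a later line could leave references still bound to the first definition, roughly because the import wrappers generated for earlier lines kept pointing at the old name. An illustrative session sketch, not a verbatim transcript:

```scala
scala> def myMethod() = "first definition"

scala> val tmp = myMethod(); val out = tmp   // out == "first definition", as expected

scala> def myMethod() = "second definition"

scala> val tmp = myMethod(); val out = tmp
// Expected: "second definition". Before this fix, the class-based REPL could
// still resolve myMethod to the first definition through the stale wrapper.
```

The new `ReplSuite` cases below assert the expected shadowing behavior for both re-defined methods and re-defined values.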
--- .../spark/repl/SparkILoopInterpreter.scala | 138 +++++++++++++++++- .../org/apache/spark/repl/ReplSuite.scala | 39 +++-- 2 files changed, 167 insertions(+), 10 deletions(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala index 0803426403af5..e736607a9a6b9 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -17,6 +17,7 @@ package org.apache.spark.repl +import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ @@ -30,7 +31,7 @@ class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain override def chooseHandler(member: intp.global.Tree): MemberHandler = member match { case member: Import => new SparkImportHandler(member) - case _ => super.chooseHandler (member) + case _ => super.chooseHandler(member) } class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) { @@ -100,4 +101,139 @@ class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain override def typeOfExpression(expr: String, silent: Boolean): global.Type = expressionTyper.typeOfExpression(expr, silent) + + import global.Name + override def importsCode(wanted: Set[Name], wrapper: Request#Wrapper, + definesClass: Boolean, generousImports: Boolean): ComputedImports = { + + import global._ + import definitions.{ ObjectClass, ScalaPackage, JavaLangPackage, PredefModule } + import memberHandlers._ + + val header, code, trailingBraces, accessPath = new StringBuilder + val currentImps = mutable.HashSet[Name]() + // only emit predef import header if name not resolved in history, loosely + var predefEscapes = false + + /** + * Narrow down the list of requests from which imports + * should be taken. Removes requests which cannot contribute + * useful imports for the specified set of wanted names. + */ + case class ReqAndHandler(req: Request, handler: MemberHandler) + + def reqsToUse: List[ReqAndHandler] = { + /** + * Loop through a list of MemberHandlers and select which ones to keep. + * 'wanted' is the set of names that need to be imported. + */ + def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = { + // Single symbol imports might be implicits! See bug #1752. Rather than + // try to finesse this, we will mimic all imports for now. + def keepHandler(handler: MemberHandler) = handler match { + // While defining classes in class based mode - implicits are not needed. 
+ case h: ImportHandler if isClassBased && definesClass => + h.importedNames.exists(x => wanted.contains(x)) + case _: ImportHandler => true + case x if generousImports => x.definesImplicit || + (x.definedNames exists (d => wanted.exists(w => d.startsWith(w)))) + case x => x.definesImplicit || + (x.definedNames exists wanted) + } + + reqs match { + case Nil => + predefEscapes = wanted contains PredefModule.name ; Nil + case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted) + case rh :: rest => + import rh.handler._ + val augment = rh match { + case ReqAndHandler(_, _: ImportHandler) => referencedNames + case _ => Nil + } + val newWanted = wanted ++ augment -- definedNames -- importedNames + rh :: select(rest, newWanted) + } + } + + /** Flatten the handlers out and pair each with the original request */ + select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse + } + + // add code for a new object to hold some imports + def addWrapper() { + import nme.{ INTERPRETER_IMPORT_WRAPPER => iw } + code append (wrapper.prewrap format iw) + trailingBraces append wrapper.postwrap + accessPath append s".$iw" + currentImps.clear() + } + + def maybeWrap(names: Name*) = if (names exists currentImps) addWrapper() + + def wrapBeforeAndAfter[T](op: => T): T = { + addWrapper() + try op finally addWrapper() + } + + // imports from Predef are relocated to the template header to allow hiding. + def checkHeader(h: ImportHandler) = h.referencedNames contains PredefModule.name + + // loop through previous requests, adding imports for each one + wrapBeforeAndAfter { + // Reusing a single temporary value when import from a line with multiple definitions. + val tempValLines = mutable.Set[Int]() + for (ReqAndHandler(req, handler) <- reqsToUse) { + val objName = req.lineRep.readPathInstance + handler match { + case h: ImportHandler if checkHeader(h) => + header.clear() + header append f"${h.member}%n" + // If the user entered an import, then just use it; add an import wrapping + // level if the import might conflict with some other import + case x: ImportHandler if x.importsWildcard => + wrapBeforeAndAfter(code append (x.member + "\n")) + case x: ImportHandler => + maybeWrap(x.importedNames: _*) + code append (x.member + "\n") + currentImps ++= x.importedNames + + case x if isClassBased => + for (sym <- x.definedSymbols) { + maybeWrap(sym.name) + x match { + case _: ClassHandler => + code.append(s"import ${objName}${req.accessPath}.`${sym.name}`\n") + case _ => + val valName = s"${req.lineRep.packageName}${req.lineRep.readName}" + if (!tempValLines.contains(req.lineRep.lineId)) { + code.append(s"val $valName: ${objName}.type = $objName\n") + tempValLines += req.lineRep.lineId + } + code.append(s"import ${valName}${req.accessPath}.`${sym.name}`\n") + } + currentImps += sym.name + } + // For other requests, import each defined name. + // import them explicitly instead of with _, so that + // ambiguity errors will not be generated. Also, quote + // the name of the variable, so that we don't need to + // handle quoting keywords separately. 
+ case x => + for (sym <- x.definedSymbols) { + maybeWrap(sym.name) + code append s"import ${x.path}\n" + currentImps += sym.name + } + } + } + } + + val computedHeader = if (predefEscapes) header.toString else "" + ComputedImports(computedHeader, code.toString, trailingBraces.toString, accessPath.toString) + } + + private def allReqAndHandlers = + prevRequestList flatMap (req => req.handlers map (req -> _)) + } diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index a5053521f8e31..cdd5cdd841740 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -227,14 +227,35 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error: not found: value sc", output) } - test("spark-shell should find imported types in class constructors and extends clause") { - val output = runInterpreter("local", - """ - |import org.apache.spark.Partition - |class P(p: Partition) - |class P(val index: Int) extends Partition - """.stripMargin) - assertDoesNotContain("error: not found: type Partition", output) - } + test("spark-shell should find imported types in class constructors and extends clause") { + val output = runInterpreter("local", + """ + |import org.apache.spark.Partition + |class P(p: Partition) + |class P(val index: Int) extends Partition + """.stripMargin) + assertDoesNotContain("error: not found: type Partition", output) + } + + test("spark-shell should shadow val/def definitions correctly") { + val output1 = runInterpreter("local", + """ + |def myMethod() = "first definition" + |val tmp = myMethod(); val out = tmp + |def myMethod() = "second definition" + |val tmp = myMethod(); val out = s"$tmp aabbcc" + """.stripMargin) + assertContains("second definition aabbcc", output1) + + val output2 = runInterpreter("local", + """ + |val a = 1 + |val b = a; val c = b; + |val a = 2 + |val b = a; val c = b; + |s"!!$b!!" + """.stripMargin) + assertContains("!!2!!", output2) + } } From 82183f7b57f2a93e646c56a9e37fac64b348ff0b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 6 Dec 2017 10:52:29 +0800 Subject: [PATCH 353/712] [SPARK-22686][SQL] DROP TABLE IF EXISTS should not show AnalysisException ## What changes were proposed in this pull request? During [SPARK-22488](https://github.com/apache/spark/pull/19713) to fix view resolution issue, there occurs a regression at `2.2.1` and `master` branch like the following. This PR fixes that. ```scala scala> spark.version res2: String = 2.2.1 scala> sql("DROP TABLE IF EXISTS t").show 17/12/04 21:01:06 WARN DropTableCommand: org.apache.spark.sql.AnalysisException: Table or view not found: t; org.apache.spark.sql.AnalysisException: Table or view not found: t; ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #19888 from dongjoon-hyun/SPARK-22686. 
--- .../spark/sql/execution/command/ddl.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 568567aa8ea88..0142f17ce62e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -203,14 +203,20 @@ case class DropTableCommand( case _ => } } - try { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) - } catch { - case _: NoSuchTableException if ifExists => - case NonFatal(e) => log.warn(e.toString, e) + + if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) { + try { + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + } catch { + case NonFatal(e) => log.warn(e.toString, e) + } + catalog.refreshTable(tableName) + catalog.dropTable(tableName, ifExists, purge) + } else if (ifExists) { + // no-op + } else { + throw new AnalysisException(s"Table or view not found: ${tableName.identifier}") } - catalog.refreshTable(tableName) - catalog.dropTable(tableName, ifExists, purge) Seq.empty[Row] } } From 00d176d2fe7bbdf55cb3146a9cb04ca99b1858b7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Dec 2017 21:43:41 -0800 Subject: [PATCH 354/712] [SPARK-20392][SQL] Set barrier to prevent re-entering a tree ## What changes were proposed in this pull request? The SQL `Analyzer` goes through a whole query plan even most part of it is analyzed. This increases the time spent on query analysis for long pipelines in ML, especially. This patch adds a logical node called `AnalysisBarrier` that wraps an analyzed logical plan to prevent it from analysis again. The barrier is applied to the analyzed logical plan in `Dataset`. It won't change the output of wrapped logical plan and just acts as a wrapper to hide it from analyzer. New operations on the dataset will be put on the barrier, so only the new nodes created will be analyzed. This analysis barrier will be removed at the end of analysis stage. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #19873 from viirya/SPARK-20392-reopen. 
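To make the barrier mechanism concrete, here is a simplified, self-contained sketch; the names (`Node`, `Op`, `Barrier`, `transformUp`) are illustrative stand-ins rather than Spark's actual classes:

```scala
// Simplified sketch (illustrative names, not Spark's classes): a bottom-up
// transform recurses into children, so a wrapper node that exposes no children
// keeps later rewrites from re-visiting an already-analyzed subtree.
sealed trait Node { def children: Seq[Node] }
case class Op(name: String, kids: Seq[Node] = Nil) extends Node {
  def children: Seq[Node] = kids
}
case class Barrier(wrapped: Node) extends Node {
  def children: Seq[Node] = Nil   // leaf-like: hides the wrapped subtree
}

def transformUp(n: Node)(rule: PartialFunction[Node, Node]): Node = {
  val next = n match {
    case Op(name, kids) => Op(name, kids.map(k => transformUp(k)(rule)))
    case other          => other  // Barrier: its wrapped subtree is not visited
  }
  rule.applyOrElse(next, (m: Node) => m)
}

// A rule applied over Barrier(plan) touches only the barrier node itself; new
// operators built on top of the barrier are still rewritten, and removing the
// barriers at the end of analysis exposes the original, already-analyzed plan.
```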
--- .../sql/catalyst/analysis/Analyzer.scala | 77 ++++++++++------ .../sql/catalyst/analysis/CheckAnalysis.scala | 4 - .../catalyst/analysis/DecimalPrecision.scala | 2 +- .../ResolveTableValuedFunctions.scala | 2 +- .../SubstituteUnresolvedOrdinals.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 32 ++++--- .../catalyst/analysis/timeZoneAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 49 ---------- .../plans/logical/basicLogicalOperators.scala | 19 ++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++ .../sql/catalyst/plans/LogicalPlanSuite.scala | 42 ++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 91 ++++++++++--------- .../sql/execution/datasources/rules.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 6 +- 17 files changed, 185 insertions(+), 165 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e5c93b5f0e059..0d5e866c0683e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -165,14 +165,15 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases) + CleanupAliases, + EliminateBarriers) ) /** * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -200,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -242,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -611,7 +612,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -666,7 +667,9 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): LogicalPlan = { + // Remove analysis barrier if any. 
+ val right = EliminateBarriers(originalRight) val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -709,7 +712,7 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - right + originalRight case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -722,7 +725,7 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - newRight + AnalysisBarrier(newRight) } } @@ -799,7 +802,7 @@ class Analyzer( case _ => e.mapChildren(resolve(_, q)) } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -993,7 +996,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1049,7 +1052,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1073,11 +1076,13 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, originalChild) if !s.resolved && originalChild.resolved => + val child = EliminateBarriers(originalChild) try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -1098,7 +1103,8 @@ class Analyzer( case ae: AnalysisException => s } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, originalChild) if !f.resolved && originalChild.resolved => + val child = EliminateBarriers(originalChild) try { val newCond = resolveExpressionRecursively(cond, child) val requiredAttrs = newCond.references.filter(_.resolved) @@ -1125,7 +1131,7 @@ class Analyzer( */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { - return plan + return AnalysisBarrier(plan) } plan match { case p: Project => @@ -1197,7 +1203,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. 
*/ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1334,7 +1340,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1350,7 +1356,7 @@ class Analyzer( */ object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => // Resolves output attributes if a query has alias names in its subquery: // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) @@ -1373,7 +1379,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1399,7 +1405,9 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => + apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1459,6 +1467,8 @@ class Analyzer( case ae: AnalysisException => filter } + case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1571,7 +1581,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1629,7 +1639,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1946,7 +1956,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. 
*/ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -1991,7 +2001,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2056,7 +2066,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2121,7 +2131,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2207,7 +2217,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2241,7 +2251,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2300,7 +2310,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2329,6 +2339,13 @@ object CleanupAliases extends Rule[LogicalPlan] { } } +/** Remove the barrier nodes of analysis */ +object EliminateBarriers extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case AnalysisBarrier(child) => child + } +} + /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. @@ -2379,7 +2396,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. 
*/ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b5e8bdd79869e..6894aed15c16f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -78,8 +78,6 @@ trait CheckAnalysis extends PredicateHelper { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { - case p if p.analyzed => // Skip already analyzed sub-plans - case u: UnresolvedRelation => u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") @@ -353,8 +351,6 @@ trait CheckAnalysis extends PredicateHelper { case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } - - plan.foreach(_.setAnalyzed()) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 070bc542e4852..a8100b9b24aac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -78,7 +78,7 @@ object DecimalPrecision extends TypeCoercionRule { PromotePrecision(Cast(e, dataType)) } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 7358f9ee36921..a214e59302cd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 860d20f897690..f9fd0df9e4010 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) 
extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1ee2f6e941045..2f306f58b7b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -247,9 +247,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p if p.analyzed => p - + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) @@ -321,7 +319,8 @@ object TypeCoercion { } } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -380,7 +379,8 @@ object TypeCoercion { } } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -439,7 +439,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -480,7 +480,8 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -570,7 +571,8 @@ object TypeCoercion { * converted to fractional types. */ object Division extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -592,7 +594,8 @@ object TypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. 
*/ object CaseWhenCoercion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => @@ -622,7 +625,8 @@ object TypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => @@ -662,7 +666,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -679,7 +683,8 @@ object TypeCoercion { * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -796,7 +801,8 @@ object TypeCoercion { * Cast WindowFrame boundaries to the type they operate upon. 
*/ object WindowFrameCoercion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index a27aa845bf0ae..af1f9165b0044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.resolveExpressions(transformTimeZoneExprs) + plan.transformAllExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index ea46dd7282401..3bbe41cf8f15e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 64b28565eb27c..2673bea648d09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -270,7 +270,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper /** * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case f @ Filter(_, a: Aggregate) => rewriteSubQueries(f, Seq(a, a.child)) // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 14188829db2af..a38458add7b5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -33,58 +33,9 @@ abstract class LogicalPlan with QueryPlanConstraints with Logging { - private var _analyzed: Boolean = false - - /** - * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]]. - */ - private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } - - /** - * Returns true if this node and its children have already been gone through analysis and - * verification. Note that this is only an optimization used to avoid analyzing trees that - * have already been analyzed, and can be reset by transformations. - */ - def analyzed: Boolean = _analyzed - /** Returns true if this subtree has data from a streaming data source. */ def isStreaming: Boolean = children.exists(_.isStreaming == true) - /** - * Returns a copy of this node where `rule` has been recursively applied first to all of its - * children and then itself (post-order). When `rule` does not apply to a given node, it is left - * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already - * been marked as analyzed. - * - * @param rule the function use to transform this nodes children - */ - def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - if (!analyzed) { - val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) - } - } - } else { - this - } - } - - /** - * Recursively transforms the expressions of a tree, skipping nodes that have already - * been analyzed. - */ - def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { - this resolveOperators { - case p => p.transformExpressions(r) - } - } - override def verboseStringWithSuffix: String = { super.verboseString + statsCache.map(", " + _.toString).getOrElse("") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ba5f97d608feb..cd474551622d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -883,3 +883,22 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } + +/** + * A logical plan for setting a barrier of analysis. + * + * The SQL Analyzer goes through a whole query plan even most part of it is analyzed. This + * increases the time spent on query analysis for long pipelines in ML, especially. + * + * This logical plan wraps an analyzed logical plan to prevent it from analysis again. The barrier + * is applied to the analyzed logical plan in Dataset. It won't change the output of wrapped + * logical plan and just acts as a wrapper to hide it from analyzer. 
New operations on the dataset + * will be put on the barrier, so only the new nodes created will be analyzed. + * + * This analysis barrier will be removed at the end of analysis stage. + */ +case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override def output: Seq[Attribute] = child.output + override def isStreaming: Boolean = child.isStreaming + override def doCanonicalize(): LogicalPlan = child.canonicalized +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 109fb32aa4a12..f4514205d3ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -543,4 +543,18 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkPartitioning(numPartitions = 10, exprs = SortOrder('a.attr, Ascending), 'b.attr) } } + + test("SPARK-20392: analysis barrier") { + // [[AnalysisBarrier]] will be removed after analysis + checkAnalysis( + Project(Seq(UnresolvedAttribute("tbl.a")), + AnalysisBarrier(SubqueryAlias("tbl", testRelation))), + Project(testRelation.output, SubqueryAlias("tbl", testRelation))) + + // Verify we won't go through a plan wrapped in a barrier. + // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. + val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), + SubqueryAlias("tbl", testRelation))) + assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index cdf912df7c76a..14041747fd20e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly - * skips sub-trees that have already been marked as analyzed. + * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown` plus analysis barrier + * and make sure it can correctly skip sub-trees that have already been analyzed. 
*/ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -36,39 +36,53 @@ class LogicalPlanSuite extends SparkFunSuite { private val testRelation = LocalRelation() - test("resolveOperator runs on operators") { + test("transformUp runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) - plan resolveOperators function + plan transformUp function assert(invocationCount === 1) + + invocationCount = 0 + plan transformDown function + assert(invocationCount === 1) } - test("resolveOperator runs on operators recursively") { + test("transformUp runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) - plan resolveOperators function + plan transformUp function assert(invocationCount === 2) + + invocationCount = 0 + plan transformDown function + assert(invocationCount === 2) } - test("resolveOperator skips all ready resolved plans") { + test("transformUp skips all ready resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan = Project(Nil, Project(Nil, testRelation)) - plan.foreach(_.setAnalyzed()) - plan resolveOperators function + val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) + plan transformUp function assert(invocationCount === 0) + + invocationCount = 0 + plan transformDown function + assert(invocationCount === 0) } - test("resolveOperator skips partially resolved plans") { + test("transformUp skips partially resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan1 = Project(Nil, testRelation) + val plan1 = AnalysisBarrier(Project(Nil, testRelation)) val plan2 = Project(Nil, plan1) - plan1.foreach(_.setAnalyzed()) - plan2 resolveOperators function + plan2 transformUp function assert(invocationCount === 1) + + invocationCount = 0 + plan2 transformDown function + assert(invocationCount === 1) } test("isStreaming") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 167c9d050c3c4..c34cf0a7a7718 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -191,6 +191,9 @@ class Dataset[T] private[sql]( } } + // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. + @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -403,7 +406,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. 
@@ -604,7 +607,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) } /** @@ -777,7 +780,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) } /** @@ -855,7 +858,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -916,7 +919,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -925,8 +928,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed + val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -958,7 +961,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) } /** @@ -990,8 +993,8 @@ class Dataset[T] private[sql]( // etc. 
val joined = sparkSession.sessionState.executePlan( Join( - this.logicalPlan, - other.logicalPlan, + this.planWithBarrier, + other.planWithBarrier, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1212,7 +1215,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + SubqueryAlias(alias, planWithBarrier) } /** @@ -1250,7 +1253,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), logicalPlan) + Project(cols.map(_.named), planWithBarrier) } /** @@ -1305,8 +1308,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, - logicalPlan) + val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, + planWithBarrier) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1324,8 +1327,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, logicalPlan.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) + columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1401,7 +1404,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + Filter(condition.expr, planWithBarrier) } /** @@ -1578,7 +1581,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = logicalPlan + val inputPlan = planWithBarrier val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1724,7 +1727,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + Limit(Literal(n), planWithBarrier) } /** @@ -1774,7 +1777,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) + CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) } /** @@ -1833,7 +1836,7 @@ class Dataset[T] private[sql]( // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
- CombineUnions(Union(logicalPlan, rightChild)) + CombineUnions(Union(logicalPlan, rightChild)).mapChildren(AnalysisBarrier) } /** @@ -1847,7 +1850,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(logicalPlan, other.logicalPlan) + Intersect(planWithBarrier, other.planWithBarrier) } /** @@ -1861,7 +1864,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(logicalPlan, other.logicalPlan) + Except(planWithBarrier, other.planWithBarrier) } /** @@ -1912,7 +1915,7 @@ class Dataset[T] private[sql]( */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan) + Sample(0.0, fraction, withReplacement, seed, planWithBarrier) } } @@ -1954,15 +1957,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. - val sortOrder = logicalPlan.output + val sortOrder = planWithBarrier.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, logicalPlan) + Sort(sortOrder, global = false, planWithBarrier) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - logicalPlan + planWithBarrier } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -2046,7 +2049,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2087,7 +2090,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2235,7 +2238,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.logicalPlan.output + val attrs = this.planWithBarrier.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2283,7 +2286,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, logicalPlan) + Deduplicate(groupCols, planWithBarrier) } /** @@ -2465,7 +2468,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2479,7 +2482,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2493,7 +2496,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + MapElements[T, U](func, planWithBarrier) } /** @@ -2508,7 +2511,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder 
- withTypedPlan(MapElements[T, U](func, logicalPlan)) + withTypedPlan(MapElements[T, U](func, planWithBarrier)) } /** @@ -2524,7 +2527,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, logicalPlan), + MapPartitions[T, U](func, planWithBarrier), implicitly[Encoder[U]]) } @@ -2555,7 +2558,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) } /** @@ -2719,7 +2722,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + Repartition(numPartitions, shuffle = true, planWithBarrier) } /** @@ -2742,7 +2745,7 @@ class Dataset[T] private[sql]( |For range partitioning use repartitionByRange(...) instead. """.stripMargin) withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) } } @@ -2779,7 +2782,7 @@ class Dataset[T] private[sql]( case expr: Expression => SortOrder(expr, Ascending) }) withTypedPlan { - RepartitionByExpression(sortOrder, logicalPlan, numPartitions) + RepartitionByExpression(sortOrder, planWithBarrier, numPartitions) } } @@ -2817,7 +2820,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + Repartition(numPartitions, shuffle = false, planWithBarrier) } /** @@ -2900,7 +2903,7 @@ class Dataset[T] private[sql]( // Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`. 
@transient private lazy val rddQueryExecution: QueryExecution = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserialized = CatalystSerde.deserialize[T](planWithBarrier) sparkSession.sessionState.executePlan(deserialized) } @@ -3026,7 +3029,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = logicalPlan, + child = planWithBarrier, allowExisting = false, replace = replace, viewType = viewType) @@ -3226,7 +3229,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, logicalPlan) + Sort(sortOrder, global = global, planWithBarrier) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 6e08df75b8a4c..f64e079539c4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -39,7 +39,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c25c90d0c70e2..b50642d275ba8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -241,7 +241,7 @@ class PlannerSuite extends SharedSQLContext { test("collapse adjacent repartitions") { val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length - assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.analyzed) === 3) assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2) doubleRepartitioned.queryExecution.optimizedPlan match { case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 3018c0642f062..a7961c757efa8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -87,7 +87,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. 
val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -114,7 +114,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -145,7 +145,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) From fb6a9227516f893815fb0b6d26b578e21badd664 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 6 Dec 2017 20:20:20 +0900 Subject: [PATCH 355/712] [SPARK-20728][SQL][FOLLOWUP] Use an actionable exception message ## What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/19871 to improve an exception message. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #19903 from dongjoon-hyun/orc_exception. --- .../spark/sql/execution/datasources/DataSource.scala | 8 ++++---- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 5f12d5f93a35c..b676672b38cd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -598,11 +598,11 @@ object DataSource extends Logging { // Found the data source using fully qualified path dataSource case Failure(error) => - if (provider1.toLowerCase(Locale.ROOT) == "orc" || - provider1.startsWith("org.apache.spark.sql.hive.orc")) { + if (provider1.startsWith("org.apache.spark.sql.hive.orc")) { throw new AnalysisException( - "Hive-based ORC data source must be used with Hive support enabled. " + - "Please use native ORC data source instead") + "Hive built-in ORC data source must be used with Hive support enabled. 
" + + "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " + + "'native'") } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || provider1 == "com.databricks.spark.avro") { throw new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 86bd9b95bca6a..8ddddbeee598f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1666,7 +1666,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { e = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.hive.orc`.`file_path`") } - assert(e.message.contains("Hive-based ORC data source must be used with Hive support")) + assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) e = intercept[AnalysisException] { sql(s"select id from `com.databricks.spark.avro`.`file_path`") @@ -2790,7 +2790,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { sql("CREATE TABLE spark_20728(a INT) USING ORC") } - assert(e.message.contains("Hive-based ORC data source must be used with Hive support")) + assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) } withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { From 6f41c593bbefa946d13b62ecf4e85074fd3c1541 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 6 Dec 2017 08:27:17 -0800 Subject: [PATCH 356/712] [SPARK-22690][ML] Imputer inherit HasOutputCols ## What changes were proposed in this pull request? make `Imputer` inherit `HasOutputCols` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #19889 from zhengruifeng/using_HasOutputCols. --- .../org/apache/spark/ml/feature/Imputer.scala | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 4663f16b5f5dc..730ee9fc08db8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.HasInputCols +import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ /** * Params for [[Imputer]] and [[ImputerModel]]. */ -private[feature] trait ImputerParams extends Params with HasInputCols { +private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCols { /** * The imputation strategy. Currently only "mean" and "median" are supported. @@ -63,16 +63,6 @@ private[feature] trait ImputerParams extends Params with HasInputCols { /** @group getParam */ def getMissingValue: Double = $(missingValue) - /** - * Param for output column names. 
- * @group param - */ - final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", - "output column names") - - /** @group getParam */ - final def getOutputCols: Array[String] = $(outputCols) - /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" + From 813c0f945d7f03800975eaed26b86a1f30e513c9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 7 Dec 2017 00:45:51 +0800 Subject: [PATCH 357/712] [SPARK-22704][SQL] Least and Greatest use less global variables ## What changes were proposed in this pull request? This PR accomplishes the following two items. 1. Reduce # of global variables from two to one 2. Make lifetime of global variable local within an operation Item 1. reduces # of constant pool entries in a Java class. Item 2. ensures that an variable is not passed to arguments in a method split by `CodegenContext.splitExpressions()`, which is addressed by #19865. ## How was this patch tested? Added new test into `ArithmeticExpressionSuite` Author: Kazuaki Ishizaki Closes #19899 from kiszk/SPARK-22704. --- .../sql/catalyst/expressions/arithmetic.scala | 94 ++++++++++++------- .../ArithmeticExpressionSuite.scala | 11 +++ 2 files changed, 73 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 739bd13c5078d..1893eec22b65d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,23 +602,38 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) - def updateEval(eval: ExprCode): String = { + val tmpIsNull = ctx.freshName("leastTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val evals = evalChildren.map(eval => s""" - ${eval.code} - if (!${eval.isNull} && (${ev.isNull} || - ${ctx.genGreater(dataType, ev.value, eval.value)})) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - """ - } - val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - $codes""") + |${eval.code} + |if (!${eval.isNull} && ($tmpIsNull || + | ${ctx.genGreater(dataType, ev.value, eval.value)})) { + | $tmpIsNull = false; + | ${ev.value} = ${eval.value}; + |} + """.stripMargin + ) + + val resultType = ctx.javaType(dataType) + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "least", + extraArguments = Seq(resultType -> ev.value), + returnType = resultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = + s""" + |$tmpIsNull = true; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$codes + |final boolean ${ev.isNull} = $tmpIsNull; + """.stripMargin) } } @@ -668,22 +683,37 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def 
doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) - def updateEval(eval: ExprCode): String = { + val tmpIsNull = ctx.freshName("greatestTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val evals = evalChildren.map(eval => s""" - ${eval.code} - if (!${eval.isNull} && (${ev.isNull} || - ${ctx.genGreater(dataType, eval.value, ev.value)})) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - """ - } - val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - $codes""") + |${eval.code} + |if (!${eval.isNull} && ($tmpIsNull || + | ${ctx.genGreater(dataType, eval.value, ev.value)})) { + | $tmpIsNull = false; + | ${ev.value} = ${eval.value}; + |} + """.stripMargin + ) + + val resultType = ctx.javaType(dataType) + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "greatest", + extraArguments = Seq(resultType -> ev.value), + returnType = resultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = + s""" + |$tmpIsNull = true; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$codes + |final boolean ${ev.isNull} = $tmpIsNull; + """.stripMargin) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index fb759eba6a9e2..be638d80e45d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -343,4 +344,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow) checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow) } + + test("SPARK-22704: Least and greatest use less global variables") { + val ctx1 = new CodegenContext() + Least(Seq(Literal(1), Literal(1))).genCode(ctx1) + assert(ctx1.mutableStates.size == 1) + + val ctx2 = new CodegenContext() + Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) + assert(ctx2.mutableStates.size == 1) + } } From e98f9647f44d1071a6b070db070841b8cda6bd7a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 7 Dec 2017 00:50:49 +0800 Subject: [PATCH 358/712] [SPARK-22695][SQL] ScalaUDF should not use global variables ## What changes were proposed in this pull request? ScalaUDF is using global variables which are not needed. This can generate some unneeded entries in the constant pool. The PR replaces the unneeded global variables with local variables. ## How was this patch tested? 
added UT Author: Marco Gaido Author: Marco Gaido Closes #19900 from mgaido91/SPARK-22695. --- .../sql/catalyst/expressions/ScalaUDF.scala | 88 ++++++++++--------- .../catalyst/expressions/ScalaUDFSuite.scala | 6 ++ 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 179853032035e..4d26d9819321b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -982,35 +982,28 @@ case class ScalaUDF( // scalastyle:on line.size.limit - // Generate codes used to convert the arguments to Scala type for user-defined functions - private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { - val converterClassName = classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - val expressionClassName = classOf[Expression].getName - val scalaUDFClassName = classOf[ScalaUDF].getName + private val converterClassName = classOf[Any => Any].getName + private val scalaUDFClassName = classOf[ScalaUDF].getName + private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + // Generate codes used to convert the arguments to Scala type for user-defined functions + private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = { val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 - ctx.addMutableState(converterClassName, converterTerm, - s"$converterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + - s"references[$expressionIdx]).getChildren().apply($index))).dataType());") - converterTerm + (converterTerm, + s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter(((Expression)((($scalaUDFClassName)" + + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") } override def doGenCode( ctx: CodegenContext, ev: ExprCode): ExprCode = { + val scalaUDF = ctx.freshName("scalaUDF") + val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName) - val scalaUDF = ctx.addReferenceObj("scalaUDF", this) - val converterClassName = classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - - // Generate codes used to convert the returned value of user-defined functions to Catalyst type + // Object to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - ctx.addMutableState(converterClassName, catalystConverterTerm, - s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1022,8 +1015,6 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - ctx.addMutableState(funcClassName, funcTerm, - s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1033,34 +1024,45 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is 
Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => - val eval = evals(i) - val argTerm = ctx.freshName("arg") - val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" - (convert, argTerm) + val (converters, funcArguments) = converterTerms.zipWithIndex.map { + case ((convName, convInit), i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = + s""" + |$convInit + |Object $argTerm = ${eval.isNull} ? null : $convName.apply(${eval.value}); + """.stripMargin + (convert, argTerm) }.unzip val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" val callFunc = s""" - ${ctx.boxedType(dataType)} $resultTerm = null; - try { - $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); - } catch (Exception e) { - throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); - } - """ + |${ctx.boxedType(dataType)} $resultTerm = null; + |$scalaUDFClassName $scalaUDF = $scalaUDFRef; + |try { + | $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc(); + | $converterClassName $catalystConverterTerm = ($converterClassName) + | $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType()); + | $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + |} catch (Exception e) { + | throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + |} + """.stripMargin - ev.copy(code = s""" - $evalCode - ${converters.mkString("\n")} - $callFunc - - boolean ${ev.isNull} = $resultTerm == null; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $resultTerm; - }""") + ev.copy(code = + s""" + |$evalCode + |${converters.mkString("\n")} + |$callFunc + | + |boolean ${ev.isNull} = $resultTerm == null; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + """.stripMargin) } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 13bd363c8b692..70dea4b39d55d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types.{IntegerType, StringType} class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -47,4 +48,9 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e2.getMessage.contains("Failed to execute user defined function")) } + test("SPARK-22695: ScalaUDF should not use global variables") { + val ctx = new CodegenContext + ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } From 4286cba7dacf4b457fff91da3743ac2518699945 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 6 Dec 2017 10:11:25 -0800 Subject: [PATCH 
359/712] [SPARK-22710] ConfigBuilder.fallbackConf should trigger onCreate function ## What changes were proposed in this pull request? I was looking at the config code today and found that configs defined using ConfigBuilder.fallbackConf didn't trigger onCreate function. This patch fixes it. This doesn't require backporting since we currently have no configs that use it. ## How was this patch tested? Added a test case for all the config final creator functions in ConfigEntrySuite. Author: Reynold Xin Closes #19905 from rxin/SPARK-22710. --- .../spark/internal/config/ConfigBuilder.scala | 4 +++- .../internal/config/ConfigEntrySuite.scala | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index 8f4c1b60920db..b0cd7110a3b47 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -235,7 +235,9 @@ private[spark] case class ConfigBuilder(key: String) { } def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = { - new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback) + val entry = new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback) + _onCreate.foreach(_(entry)) + entry } def regexConf: TypedConfigBuilder[Regex] = { diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index bf08276dbf971..02514dc7daef4 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -288,4 +288,24 @@ class ConfigEntrySuite extends SparkFunSuite { conf.remove(testKey("b")) assert(conf.get(iConf) === 3) } + + test("onCreate") { + var onCreateCalled = false + ConfigBuilder(testKey("oc1")).onCreate(_ => onCreateCalled = true).intConf.createWithDefault(1) + assert(onCreateCalled) + + onCreateCalled = false + ConfigBuilder(testKey("oc2")).onCreate(_ => onCreateCalled = true).intConf.createOptional + assert(onCreateCalled) + + onCreateCalled = false + ConfigBuilder(testKey("oc3")).onCreate(_ => onCreateCalled = true).intConf + .createWithDefaultString("1.0") + assert(onCreateCalled) + + val fallback = ConfigBuilder(testKey("oc4")).intConf.createWithDefault(1) + onCreateCalled = false + ConfigBuilder(testKey("oc5")).onCreate(_ => onCreateCalled = true).fallbackConf(fallback) + assert(onCreateCalled) + } } From 51066b437b750933da03bb401c8356f4b48aefac Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Wed, 6 Dec 2017 10:39:15 -0800 Subject: [PATCH 360/712] [SPARK-14228][CORE][YARN] Lost executor of RPC disassociated, and occurs exception: Could not find CoarseGrainedScheduler or it has been stopped ## What changes were proposed in this pull request? I see the two instances where the exception is occurring. **Instance 1:** ``` 17/11/10 15:49:32 ERROR util.Utils: Uncaught exception in thread driver-revive-thread org.apache.spark.SparkException: Could not find CoarseGrainedScheduler. 
at org.apache.spark.rpc.netty.Dispatcher.postMessage(Dispatcher.scala:160) at org.apache.spark.rpc.netty.Dispatcher.postOneWayMessage(Dispatcher.scala:140) at org.apache.spark.rpc.netty.NettyRpcEnv.send(NettyRpcEnv.scala:187) at org.apache.spark.rpc.netty.NettyRpcEndpointRef.send(NettyRpcEnv.scala:521) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint$$anon$1$$anonfun$run$1$$anonfun$apply$mcV$sp$1.apply(CoarseGrainedSchedulerBackend.scala:125) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint$$anon$1$$anonfun$run$1$$anonfun$apply$mcV$sp$1.apply(CoarseGrainedSchedulerBackend.scala:125) at scala.Option.foreach(Option.scala:257) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint$$anon$1$$anonfun$run$1.apply$mcV$sp(CoarseGrainedSchedulerBackend.scala:125) at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1344) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint$$anon$1.run(CoarseGrainedSchedulerBackend.scala:124) at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511) at java.util.concurrent.FutureTask.runAndReset(FutureTask.java:308) at java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.access$301(ScheduledThreadPoolExecutor.java:180) at java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:294) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` In CoarseGrainedSchedulerBackend.scala, driver-revive-thread starts with DriverEndpoint.onStart() and keeps sending the ReviveOffers messages periodically till it gets shutdown as part DriverEndpoint.onStop(). There is no proper coordination between the driver-revive-thread(shutdown) and the RpcEndpoint unregister, RpcEndpoint unregister happens first and then driver-revive-thread shuts down as part of DriverEndpoint.onStop(), In-between driver-revive-thread may try to send the ReviveOffers message which is leading to the above exception. To fix this issue, this PR moves the shutting down of driver-revive-thread to CoarseGrainedSchedulerBackend.stop() which executes before the DriverEndpoint unregister. **Instance 2:** ``` 17/11/10 16:31:38 ERROR cluster.YarnSchedulerBackend$YarnSchedulerEndpoint: Error requesting driver to remove executor 1 for reason Executor for container container_1508535467865_0226_01_000002 exited because of a YARN event (e.g., pre-emption) and not because of an error in the running job. org.apache.spark.SparkException: Could not find CoarseGrainedScheduler. 
at org.apache.spark.rpc.netty.Dispatcher.postMessage(Dispatcher.scala:160) at org.apache.spark.rpc.netty.Dispatcher.postLocalMessage(Dispatcher.scala:135) at org.apache.spark.rpc.netty.NettyRpcEnv.ask(NettyRpcEnv.scala:229) at org.apache.spark.rpc.netty.NettyRpcEndpointRef.ask(NettyRpcEnv.scala:516) at org.apache.spark.rpc.RpcEndpointRef.ask(RpcEndpointRef.scala:63) at org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint$$anonfun$receive$1.applyOrElse(YarnSchedulerBackend.scala:269) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:117) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:205) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:101) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:221) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` Here YarnDriverEndpoint tries to send remove executor messages after the Yarn scheduler backend service stop, which is leading to the above exception. To avoid the above exception, 1) We may add a condition(which checks whether service has stopped or not) before sending executor remove message 2) Add a warn log message in onFailure case when the service is already stopped In this PR, chosen the 2) option which adds a log message in the case of onFailure without the exception stack trace since the option 1) would need to to go through for every remove executor message. ## How was this patch tested? I verified it manually, I don't see these exceptions with the PR changes. Author: Devaraj K Closes #19741 from devaraj-kavali/SPARK-14228. --- .../CoarseGrainedExecutorBackend.scala | 6 +--- .../CoarseGrainedSchedulerBackend.scala | 30 ++++++++----------- ...bernetesClusterSchedulerBackendSuite.scala | 6 ++-- .../cluster/YarnSchedulerBackend.scala | 20 ++++--------- 4 files changed, 22 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 4c1f92a1bcbf2..9b62e4b1b7150 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -165,11 +165,7 @@ private[spark] class CoarseGrainedExecutorBackend( } if (notifyDriver && driver.nonEmpty) { - driver.get.ask[Boolean]( - RemoveExecutor(executorId, new ExecutorLossReason(reason)) - ).failed.foreach(e => - logWarning(s"Unable to notify the driver due to " + e.getMessage, e) - )(ThreadUtils.sameThread) + driver.get.send(RemoveExecutor(executorId, new ExecutorLossReason(reason))) } System.exit(code) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7bfb4d53c1834..e6982f333d521 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -95,6 +95,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 + private val reviveThread = + 
ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -103,9 +106,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp protected val addressToExecutorId = new HashMap[RpcAddress, String] - private val reviveThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") - override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") @@ -154,6 +154,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap.values.foreach { ed => ed.executorEndpoint.send(UpdateDelegationTokens(newDelegationTokens)) } + + case RemoveExecutor(executorId, reason) => + // We will remove the executor's state and cannot restore it. However, the connection + // between the driver and the executor may be still alive so that the executor won't exit + // automatically, so try to tell the executor to stop itself. See SPARK-13519. + executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) + removeExecutor(executorId, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -215,14 +222,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } context.reply(true) - case RemoveExecutor(executorId, reason) => - // We will remove the executor's state and cannot restore it. However, the connection - // between the driver and the executor may be still alive so that the executor won't exit - // automatically, so try to tell the executor to stop itself. See SPARK-13519. - executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) - removeExecutor(executorId, reason) - context.reply(true) - case RemoveWorker(workerId, host, message) => removeWorker(workerId, host, message) context.reply(true) @@ -373,10 +372,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp shouldDisable } - - override def onStop() { - reviveThread.shutdownNow() - } } var driverEndpoint: RpcEndpointRef = null @@ -417,6 +412,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def stop() { + reviveThread.shutdownNow() stopExecutors() try { if (driverEndpoint != null) { @@ -465,9 +461,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * at once. */ protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { - // Only log the failure since we don't care about the result. 
- driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).failed.foreach(t => - logError(t.getMessage, t))(ThreadUtils.sameThread) + driverEndpoint.send(RemoveExecutor(executorId, reason)) } protected def removeWorker(workerId: String, host: String, message: String): Unit = { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 3febb2f47cfd4..13c09033a50ee 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -282,7 +282,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn // No more deletion attempts of the executors. // This is graceful termination and should not be detected as a failure. verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).ask[Boolean]( + verify(driverEndpointRef, times(1)).send( RemoveExecutor("1", ExecutorExited( 0, exitCausedByApp = false, @@ -318,7 +318,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn requestExecutorRunnable.getValue.run() allocatorRunnable.getAllValues.asScala.last.run() verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).ask[Boolean]( + verify(driverEndpointRef).send( RemoveExecutor("1", ExecutorExited( 1, exitCausedByApp = true, @@ -356,7 +356,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) allocatorRunnable.getValue.run() verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).ask[Boolean]( + verify(driverEndpointRef).send( RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 415a29fd887e8..bb615c36cd97f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.atomic.{AtomicBoolean} import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} import scala.util.control.NonFatal @@ -245,14 +246,7 @@ private[spark] abstract class YarnSchedulerBackend( Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) } - removeExecutorMessage - .flatMap { message => - driverEndpoint.ask[Boolean](message) - }(ThreadUtils.sameThread) - .onFailure { - case NonFatal(e) => logError( - s"Error requesting driver to remove executor $executorId after disconnection.", e) - }(ThreadUtils.sameThread) + removeExecutorMessage.foreach { message => driverEndpoint.send(message) } } override def receive: PartialFunction[Any, Unit] = { @@ -265,12 +259,10 @@ private[spark] abstract class YarnSchedulerBackend( addWebUIFilter(filterName, filterParams, 
proxyBase) case r @ RemoveExecutor(executorId, reason) => - logWarning(reason.toString) - driverEndpoint.ask[Boolean](r).onFailure { - case e => - logError("Error requesting driver to remove executor" + - s" $executorId for reason $reason", e) - }(ThreadUtils.sameThread) + if (!stopped.get) { + logWarning(s"Requesting driver to remove executor $executorId for reason $reason") + driverEndpoint.send(r) + } } From effca9868e3feae16c5722c36878b23e616d01a2 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Dec 2017 13:11:38 -0800 Subject: [PATCH 361/712] [SPARK-22720][SS] Make EventTimeWatermark Extend UnaryNode ## What changes were proposed in this pull request? Our Analyzer and Optimizer have multiple rules for `UnaryNode`. After making `EventTimeWatermark` extend `UnaryNode`, we do not need a special handling for `EventTimeWatermark`. ## How was this patch tested? The existing tests Author: gatorsmile Closes #19913 from gatorsmile/eventtimewatermark. --- .../spark/sql/catalyst/plans/logical/EventTimeWatermark.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index 06196b5afb031..7a927e1e083b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -38,7 +38,7 @@ object EventTimeWatermark { case class EventTimeWatermark( eventTime: Attribute, delay: CalendarInterval, - child: LogicalPlan) extends LogicalPlan { + child: LogicalPlan) extends UnaryNode { // Update the metadata on the eventTime column to include the desired delay. override val output: Seq[Attribute] = child.output.map { a => @@ -60,6 +60,4 @@ case class EventTimeWatermark( a } } - - override val children: Seq[LogicalPlan] = child :: Nil } From 9948b860aca42bbc6478ddfbc0ff590adb00c2f3 Mon Sep 17 00:00:00 2001 From: smurakozi Date: Wed, 6 Dec 2017 13:22:08 -0800 Subject: [PATCH 362/712] [SPARK-22516][SQL] Bump up Univocity version to 2.5.9 ## What changes were proposed in this pull request? There was a bug in Univocity Parser that causes the issue in SPARK-22516. This was fixed by upgrading from 2.5.4 to 2.5.9 version of the library : **Executing** ``` spark.read.option("header","true").option("inferSchema", "true").option("multiLine", "true").option("comment", "g").csv("test_file_without_eof_char.csv").show() ``` **Before** ``` ERROR Executor: Exception in task 0.0 in stage 6.0 (TID 6) com.univocity.parsers.common.TextParsingException: java.lang.IllegalArgumentException - Unable to skip 1 lines from line 2. End of input reached ... Internal state when error was thrown: line=3, column=0, record=2, charIndex=31 at com.univocity.parsers.common.AbstractParser.handleException(AbstractParser.java:339) at com.univocity.parsers.common.AbstractParser.parseNext(AbstractParser.java:475) at org.apache.spark.sql.execution.datasources.csv.UnivocityParser$$anon$1.next(UnivocityParser.scala:281) at scala.collection.Iterator$$anon$11.next(Iterator.scala:409) ``` **After** ``` +-------+-------+ |column1|column2| +-------+-------+ | abc| def| +-------+-------+ ``` ## How was this patch tested? The already existing `CSVSuite.commented lines in CSV data` test was extended to parse the file also in multiline mode. 
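For reference, the scenario can be exercised in both parsing modes along the lines of the extended test; this is only an illustrative sketch, assuming an active `spark` session and the CSV file named in the report above:

```scala
// Illustrative only: read the problematic file with and without multiLine.
// With univocity-parsers 2.5.9 both runs should yield the same rows.
Seq("false", "true").foreach { multiLine =>
  spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .option("comment", "g")
    .option("multiLine", multiLine)
    .csv("test_file_without_eof_char.csv")
    .show()
}
```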
The test input file was modified to also include a comment in the last line. Author: smurakozi Closes #19906 from smurakozi/SPARK-22516. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- sql/core/pom.xml | 2 +- .../src/test/resources/test-data/comments.csv | 1 + .../execution/datasources/csv/CSVSuite.scala | 23 +++++++++++-------- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2c68b73095c4d..3b5a6945d8c4b 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -180,7 +180,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.4.jar +univocity-parsers-2.5.9.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 2aaac600b3ec3..64136baa01326 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -181,7 +181,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.4.jar +univocity-parsers-2.5.9.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 4db3fea008ee9..93010c606cf45 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.4 + 2.5.9 jar diff --git a/sql/core/src/test/resources/test-data/comments.csv b/sql/core/src/test/resources/test-data/comments.csv index 6275be7285b36..c0ace46db8c00 100644 --- a/sql/core/src/test/resources/test-data/comments.csv +++ b/sql/core/src/test/resources/test-data/comments.csv @@ -4,3 +4,4 @@ 6,7,8,9,0,2015-08-21 16:58:01 ~0,9,8,7,6,2015-08-22 17:59:02 1,2,3,4,5,2015-08-23 18:00:42 +~ comment in last line to test SPARK-22516 - do not add empty line at the end of this file! 
\ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e439699605abb..4fe45420b4e77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -483,18 +483,21 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("commented lines in CSV data") { - val results = spark.read - .format("csv") - .options(Map("comment" -> "~", "header" -> "false")) - .load(testFile(commentsFile)) - .collect() + Seq("false", "true").foreach { multiLine => - val expected = - Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"), - Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"), - Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42")) + val results = spark.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false", "multiLine" -> multiLine)) + .load(testFile(commentsFile)) + .collect() - assert(results.toSeq.map(_.toSeq) === expected) + val expected = + Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"), + Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"), + Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42")) + + assert(results.toSeq.map(_.toSeq) === expected) + } } test("inferring schema with commented lines in CSV data") { From f110a7f884cb09f01a20462038328ddc5662b46f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 6 Dec 2017 14:12:16 -0800 Subject: [PATCH 363/712] [SPARK-22693][SQL] CreateNamedStruct and InSet should not use global variables ## What changes were proposed in this pull request? CreateNamedStruct and InSet are using a global variable which is not needed. This can generate some unneeded entries in the constant pool. The PR removes the unnecessary mutable states and makes them local variables. ## How was this patch tested? added UT Author: Marco Gaido Author: Marco Gaido Closes #19896 from mgaido91/SPARK-22693. 
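The added tests follow a simple pattern: generate code for the expression against a fresh `CodegenContext` and assert that no mutable (global) state was registered. A minimal sketch of that check, assuming the 2.3-era `CodegenContext` API where `mutableStates` is directly accessible, as used by the tests in this patch:

```scala
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, InSet, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

// Hypothetical helper mirroring the added tests: codegen for these expressions
// must not register any mutable state on the context.
def assertNoGlobalState(gen: CodegenContext => Unit): Unit = {
  val ctx = new CodegenContext
  gen(ctx)
  assert(ctx.mutableStates.isEmpty, "expression registered unneeded mutable state")
}

assertNoGlobalState(ctx =>
  CreateNamedStruct(Seq(Literal("a"), Literal("x"), Literal("b"), Literal(2.0))).genCode(ctx))
assertNoGlobalState(ctx =>
  InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx))
```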
--- .../expressions/complexTypeCreator.scala | 27 ++++++++++--------- .../sql/catalyst/expressions/predicates.scala | 22 +++++++-------- .../expressions/ComplexTypeSuite.scala | 7 +++++ .../catalyst/expressions/PredicateSuite.scala | 7 +++++ 4 files changed, 40 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 087b21043b309..3dc2ee03a86e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -356,22 +356,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"$values = null;") + val valCodes = valExprs.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + |${eval.code} + |if (${eval.isNull}) { + | $values[$i] = null; + |} else { + | $values[$i] = ${eval.value}; + |} + """.stripMargin + } val valuesCode = ctx.splitExpressionsWithCurrentInputs( - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + expressions = valCodes, + funcName = "createNamedStruct", + extraArguments = "Object[]" -> values :: Nil) ev.copy(code = s""" - |$values = new Object[${valExprs.size}]; + |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 04e669492ec6d..a42dd7ecf57de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -344,17 +344,17 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } else { "" } - ctx.addMutableState(setName, setTerm, - s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();") - ev.copy(code = s""" - ${childGen.code} - boolean ${ev.isNull} = ${childGen.isNull}; - boolean ${ev.value} = false; - if (!${ev.isNull}) { - ${ev.value} = $setTerm.contains(${childGen.value}); - $setNull - } - """) + ev.copy(code = + s""" + |${childGen.code} + |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |if (!${ev.isNull}) { + | $setName $setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet(); + | ${ev.value} = $setTerm.contains(${childGen.value}); + | $setNull + |} + """.stripMargin) } override def sql: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b0eaad1c80f89..6dfca7d73a3df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ 
-20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -299,4 +300,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("=")) .checkInputDataTypes().isFailure) } + + test("SPARK-22693: CreateNamedStruct should not use global variables") { + val ctx = new CodegenContext + CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0079e4e8d6f74..95a0dfa057563 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -429,4 +430,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val infinity = Literal(Double.PositiveInfinity) checkEvaluation(EqualTo(infinity, infinity), true) } + + test("SPARK-22693: InSet should not use global variables") { + val ctx = new CodegenContext + InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } From 8ae004b4602266d1f210e4c1564246d590412c06 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 6 Dec 2017 16:15:25 -0800 Subject: [PATCH 364/712] [SPARK-22688][SQL] Upgrade Janino version to 3.0.8 ## What changes were proposed in this pull request? This PR upgrade Janino version to 3.0.8. [Janino 3.0.8](https://janino-compiler.github.io/janino/changelog.html) includes an important fix to reduce the number of constant pool entries by using 'sipush' java bytecode. * SIPUSH bytecode is not used for short integer constant [#33](https://github.com/janino-compiler/janino/issues/33). Please see detail in [this discussion thread](https://github.com/apache/spark/pull/19518#issuecomment-346674976). ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #19890 from kiszk/SPARK-22688. 
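Besides the dependency bump, the diff below switches the compile-failure handling from `JaninoRuntimeException` to Janino 3.0.8's `InternalCompilerException`. A sketch of the fallback pattern those catch clauses implement (illustrative only; `withCodegenFallback` is a hypothetical helper, not a Spark API):

```scala
import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.InternalCompilerException

// If Janino fails to compile the generated code, fall back to the interpreted path.
def withCodegenFallback[T](compiled: => T)(interpreted: => T): T =
  try compiled catch {
    case _: InternalCompilerException | _: CompileException => interpreted
  }
```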
--- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 2 +- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 6 +++--- .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 3b5a6945d8c4b..1831f3378e852 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.7.jar +commons-compiler-3.0.8.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.7.jar +janino-3.0.8.jar java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 64136baa01326..fe14c05987327 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -35,7 +35,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.7.jar +commons-compiler-3.0.8.jar commons-compress-1.4.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -96,7 +96,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.7.jar +janino-3.0.8.jar java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar diff --git a/pom.xml b/pom.xml index 07bca9d267da0..52db79eaf036b 100644 --- a/pom.xml +++ b/pom.xml @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.7 + 3.0.8 2.22.2 2.9.3 3.5.2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 670c82eff9286..5c9e604a8d293 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler} +import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler} import org.codehaus.janino.util.ClassFile import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException} @@ -1240,12 +1240,12 @@ object CodeGenerator extends Logging { evaluator.cook("generated.java", code.body) updateAndGetCompilationStats(evaluator) } catch { - case e: JaninoRuntimeException => + case e: InternalCompilerException => val msg = s"failed to compile: $e" logError(msg, e) val maxLines = SQLConf.get.loggingMaxLinesForCodegen logInfo(s"\n${CodeFormatter.format(code, maxLines)}") - throw new JaninoRuntimeException(msg, e) + throw new InternalCompilerException(msg, e) case e: CompileException => val msg = s"failed to compile: $e" logError(msg, e) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 657b265260135..787c1cfbfb3d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.JaninoRuntimeException +import org.codehaus.janino.InternalCompilerException import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging @@ -385,7 +385,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ try { GeneratePredicate.generate(expression, inputSchema) } catch { - case _ @ (_: JaninoRuntimeException | _: CompileException) if codeGenFallBack => + case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack => genInterpretedPredicate(expression, inputSchema) } } From d32337b1effd5359e1ce7e46893767c908d4b16a Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 7 Dec 2017 13:05:59 +0100 Subject: [PATCH 365/712] [SPARK-22721] BytesToBytesMap peak memory usage not accurate after reset() ## What changes were proposed in this pull request? BytesToBytesMap doesn't update peak memory usage before shrinking back to initial capacity in reset(), so after a disk spill one never knows what was the size of hash table was before spilling. ## How was this patch tested? Checked manually. Author: Juliusz Sompolski Closes #19915 from juliuszsompolski/SPARK-22721. --- .../main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 4fadfe36cd716..7fdcf22c45f73 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -879,6 +879,7 @@ public LongArray getArray() { * Reset this map to initialized state. */ public void reset() { + updatePeakMemoryUsed(); numKeys = 0; numValues = 0; freeArray(longArray); From c1e5688d1a44c152d27dcb9a04da22993d7ab826 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 7 Dec 2017 20:42:46 +0800 Subject: [PATCH 366/712] [SPARK-22672][SQL][TEST] Refactor ORC Tests ## What changes were proposed in this pull request? Since SPARK-20682, we have two `OrcFileFormat`s. This PR refactors ORC tests with three principles (with a few exceptions) 1. Move test suite into `sql/core`. 2. Create `HiveXXX` test suite in `sql/hive` by reusing `sql/core` test suite. 3. `OrcTest` will provide common helper functions and `val orcImp: String`. **Test Suites** *Native OrcFileFormat* - org.apache.spark.sql.hive.orc - OrcFilterSuite - OrcPartitionDiscoverySuite - OrcQuerySuite - OrcSourceSuite - o.a.s.sql.hive.orc - OrcHadoopFsRelationSuite *Hive built-in OrcFileFormat* - o.a.s.sql.hive.orc - HiveOrcFilterSuite - HiveOrcPartitionDiscoverySuite - HiveOrcQuerySuite - HiveOrcSourceSuite - HiveOrcHadoopFsRelationSuite **Hierarchy** ``` OrcTest -> OrcSuite -> OrcSourceSuite -> OrcQueryTest -> OrcQuerySuite -> OrcPartitionDiscoveryTest -> OrcPartitionDiscoverySuite -> OrcFilterSuite HadoopFsRelationTest -> OrcHadoopFsRelationSuite -> HiveOrcHadoopFsRelationSuite ``` Please note the followings. 
- Unlike the other test suites, `OrcHadoopFsRelationSuite` doesn't inherit `OrcTest`. It is inside `sql/hive` like `ParquetHadoopFsRelationSuite` due to the dependencies and follows the existing convention to use `val dataSourceName: String` - `OrcFilterSuite`s cannot reuse test cases due to the different function signatures using Hive 1.2.1 ORC classes and Apache ORC 1.4.1 classes. ## How was this patch tested? Pass the Jenkins tests with reorganized test suites. Author: Dongjoon Hyun Closes #19882 from dongjoon-hyun/SPARK-22672. --- .../org/apache/spark/sql/SQLQuerySuite.scala | 28 -- .../datasources}/orc/OrcFilterSuite.scala | 74 ++-- .../orc/OrcPartitionDiscoverySuite.scala | 48 +-- .../datasources}/orc/OrcQuerySuite.scala | 350 ++++++++-------- .../datasources}/orc/OrcSourceSuite.scala | 137 ++----- .../execution/datasources}/orc/OrcTest.scala | 37 +- .../sql/hive/orc/HiveOrcFilterSuite.scala | 387 ++++++++++++++++++ .../orc/HiveOrcPartitionDiscoverySuite.scala | 25 ++ .../sql/hive/orc/HiveOrcQuerySuite.scala | 165 ++++++++ .../sql/hive/orc/HiveOrcSourceSuite.scala | 107 +++++ .../hive/orc/OrcHadoopFsRelationSuite.scala | 8 +- 11 files changed, 971 insertions(+), 395 deletions(-) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql/execution/datasources}/orc/OrcFilterSuite.scala (87%) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql/execution/datasources}/orc/OrcPartitionDiscoverySuite.scala (82%) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql/execution/datasources}/orc/OrcQuerySuite.scala (68%) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql/execution/datasources}/orc/OrcSourceSuite.scala (63%) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql/execution/datasources}/orc/OrcTest.scala (72%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8ddddbeee598f..5e077285ade55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2775,32 +2775,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } - - test("SPARK-21791 ORC should support column names with dot") { - val orc = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName - withTempDir { dir => - val path = new File(dir, "orc").getCanonicalPath - Seq(Some(1), None).toDF("col.dots").write.format(orc).save(path) - assert(spark.read.format(orc).load(path).collect().length == 2) - } - } - - test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { - withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") { - val e = intercept[AnalysisException] { - sql("CREATE TABLE spark_20728(a INT) USING ORC") - } - assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) - } - - 
withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { - withTable("spark_20728") { - sql("CREATE TABLE spark_20728(a INT) USING ORC") - val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { - case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass - } - assert(fileFormat == Some(classOf[OrcFileFormat])) - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala similarity index 87% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index de6f0d67f1734..a5f6b68ee862e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -15,25 +15,32 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} -import org.apache.spark.sql.{Column, DataFrame, QueryTest} +import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ /** - * A test suite that tests ORC filter API based filter pushdown optimization. + * A test suite that tests Apache ORC filter API based filter pushdown optimization. + * OrcFilterSuite and HiveOrcFilterSuite is logically duplicated to provide the same test coverage. + * The difference are the packages containing 'Predicate' and 'SearchArgument' classes. + * - OrcFilterSuite uses 'org.apache.orc.storage.ql.io.sarg' package. + * - HiveOrcFilterSuite uses 'org.apache.hadoop.hive.ql.io.sarg' package. 
*/ -class OrcFilterSuite extends QueryTest with OrcTest { +class OrcFilterSuite extends OrcTest with SharedSQLContext { + private def checkFilterPredicate( df: DataFrame, predicate: Predicate, @@ -55,7 +62,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) assert(selectedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") checker(maybeFilter.get) } @@ -99,7 +106,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) assert(selectedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") } @@ -284,40 +291,27 @@ class OrcFilterSuite extends QueryTest with OrcTest { test("filter pushdown - combinations with logical operators") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => - // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked - // in string form in order to check filter creation including logical operators - // such as `and`, `or` or `not`. So, this function uses `SearchArgument.toString()` - // to produce string expression and then compare it to given string expression below. - // This might have to be changed after Hive version is upgraded. checkFilterPredicate( '_1.isNotNull, - """leaf-0 = (IS_NULL _1) - |expr = (not leaf-0)""".stripMargin.trim + "leaf-0 = (IS_NULL _1), expr = (not leaf-0)" ) checkFilterPredicate( '_1 =!= 1, - """leaf-0 = (IS_NULL _1) - |leaf-1 = (EQUALS _1 1) - |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (EQUALS _1 1), expr = (and (not leaf-0) (not leaf-1))" ) checkFilterPredicate( !('_1 < 4), - """leaf-0 = (IS_NULL _1) - |leaf-1 = (LESS_THAN _1 4) - |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 4), expr = (and (not leaf-0) (not leaf-1))" ) checkFilterPredicate( '_1 < 2 || '_1 > 3, - """leaf-0 = (LESS_THAN _1 2) - |leaf-1 = (LESS_THAN_EQUALS _1 3) - |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + "leaf-0 = (LESS_THAN _1 2), leaf-1 = (LESS_THAN_EQUALS _1 3), " + + "expr = (or leaf-0 (not leaf-1))" ) checkFilterPredicate( '_1 < 2 && '_1 > 3, - """leaf-0 = (IS_NULL _1) - |leaf-1 = (LESS_THAN _1 2) - |leaf-2 = (LESS_THAN_EQUALS _1 3) - |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 2), leaf-2 = (LESS_THAN_EQUALS _1 3), " + + "expr = (and (not leaf-0) leaf-1 (not leaf-2))" ) } } @@ -344,4 +338,30 @@ class OrcFilterSuite extends QueryTest with OrcTest { checkNoFilterPredicate('_1.isNotNull) } } + + test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + import org.apache.spark.sql.sources._ + // The `LessThan` should be converted while the `StringContains` shouldn't + val schema = new StructType( + Array( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true))) + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + 
OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala similarity index 82% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala index d1ce3f1e2f058..d1911ea7f32a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala @@ -15,19 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.io.File -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.SharedSQLContext // The data where the partitioning key exists only in the directory structure. case class OrcParData(intField: Int, stringField: String) @@ -35,28 +28,8 @@ case class OrcParData(intField: Int, stringField: String) // The data that also includes the partitioning key case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import spark._ - import spark.implicits._ - - val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal - - def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } - - def makeOrcFile[T <: Product: ClassTag: TypeTag]( - data: Seq[T], path: File): Unit = { - data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) - } - - - def makeOrcFile[T <: Product: ClassTag: TypeTag]( - df: DataFrame, path: File): Unit = { - df.write.mode("overwrite").orc(path.getCanonicalPath) - } +abstract class OrcPartitionDiscoveryTest extends OrcTest { + val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" protected def withTempTable(tableName: String)(f: => Unit): Unit = { try f finally spark.catalog.dropTempView(tableName) @@ -90,7 +63,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.orc(base.getCanonicalPath).createOrReplaceTempView("t") + spark.read.orc(base.getCanonicalPath).createOrReplaceTempView("t") withTempTable("t") { checkAnswer( @@ -137,7 +110,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.orc(base.getCanonicalPath).createOrReplaceTempView("t") + 
spark.read.orc(base.getCanonicalPath).createOrReplaceTempView("t") withTempTable("t") { checkAnswer( @@ -186,8 +159,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read - .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + spark.read + .option("hive.exec.default.partition.name", defaultPartitionName) .orc(base.getCanonicalPath) .createOrReplaceTempView("t") @@ -228,8 +201,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read - .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + spark.read + .option("hive.exec.default.partition.name", defaultPartitionName) .orc(base.getCanonicalPath) .createOrReplaceTempView("t") @@ -253,3 +226,4 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B } } +class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala similarity index 68% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 1ffaf30311037..e00e057a18cc6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -15,24 +15,27 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc +import java.io.File import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc.{OrcConf, OrcFile} import org.apache.orc.OrcConf.COMPRESS -import org.scalatest.BeforeAndAfterAll +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} -import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.util.Utils @@ -57,7 +60,8 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { +abstract class OrcQueryTest extends OrcTest { + import testImplicits._ test("Read/write All Types") { val data = (0 to 255).map { i => @@ -73,7 +77,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Read/write binary data") 
{ withOrcFile(BinaryData("test".getBytes(StandardCharsets.UTF_8)) :: Nil) { file => - val bytes = read.orc(file).head().getAs[Array[Byte]](0) + val bytes = spark.read.orc(file).head().getAs[Array[Byte]](0) assert(new String(bytes, StandardCharsets.UTF_8) === "test") } } @@ -91,7 +95,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.orc(file), + spark.read.orc(file), data.toDF().collect()) } } @@ -172,7 +176,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.orc(file), + spark.read.orc(file), Row(Seq.fill(5)(null): _*)) } } @@ -183,9 +187,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { spark.range(0, 10).write .option(COMPRESS.getAttribute, "ZLIB") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } // `compression` overrides `orc.compress`. @@ -194,9 +202,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { .option("compression", "ZLIB") .option(COMPRESS.getAttribute, "SNAPPY") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } } @@ -206,39 +218,39 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { spark.range(0, 10).write .option("compression", "ZLIB") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } withTempPath { file => spark.range(0, 10).write .option("compression", "SNAPPY") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("SNAPPY" === expectedCompressionKind.name()) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".snappy.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("SNAPPY" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } withTempPath { file => spark.range(0, 10).write .option("compression", "NONE") .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("NONE" === 
expectedCompressionKind.name()) - } - } - // Following codec is not supported in Hive 1.2.1, ignore it now - ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") { - withTempPath { file => - spark.range(0, 10).write - .option("compression", "LZO") - .orc(file.getCanonicalPath) - val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression - assert("LZO" === expectedCompressionKind.name()) + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("NONE" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) } } @@ -256,22 +268,28 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") + spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") withOrcTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple)) } - sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false) + spark.sessionState.catalog.dropTable( + TableIdentifier("tmp"), + ignoreIfNotExists = true, + purge = false) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") + spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") withOrcTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), data.map(Row.fromTuple)) + checkAnswer(spark.table("t"), data.map(Row.fromTuple)) } - sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false) + spark.sessionState.catalog.dropTable( + TableIdentifier("tmp"), + ignoreIfNotExists = true, + purge = false) } test("self-join") { @@ -334,60 +352,16 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { dir => val path = dir.getCanonicalPath - spark.range(0, 10).select('id as "Acol").write.format("orc").save(path) - spark.read.format("orc").load(path).schema("Acol") + spark.range(0, 10).select('id as "Acol").write.orc(path) + spark.read.orc(path).schema("Acol") intercept[IllegalArgumentException] { - spark.read.format("orc").load(path).schema("acol") + spark.read.orc(path).schema("acol") } - checkAnswer(spark.read.format("orc").load(path).select("acol").sort("acol"), + checkAnswer(spark.read.orc(path).select("acol").sort("acol"), (0 until 10).map(Row(_))) } } - test("SPARK-8501: Avoids discovery schema from empty ORC files") { - withTempPath { dir => - val path = dir.getCanonicalPath - - withTable("empty_orc") { - withTempView("empty", "single") { - spark.sql( - s"""CREATE TABLE empty_orc(key INT, value STRING) - |STORED AS ORC - |LOCATION '${dir.toURI}' - """.stripMargin) - - val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) - emptyDF.createOrReplaceTempView("empty") - - // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because - // Spark SQL ORC data source always avoids write empty ORC files. 
- spark.sql( - s"""INSERT INTO TABLE empty_orc - |SELECT key, value FROM empty - """.stripMargin) - - val errorMessage = intercept[AnalysisException] { - spark.read.orc(path) - }.getMessage - - assert(errorMessage.contains("Unable to infer schema for ORC")) - - val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) - singleRowDF.createOrReplaceTempView("single") - - spark.sql( - s"""INSERT INTO TABLE empty_orc - |SELECT key, value FROM single - """.stripMargin) - - val df = spark.read.orc(path) - assert(df.schema === singleRowDF.schema.asNullable) - checkAnswer(df, singleRowDF) - } - } - } - } - test("SPARK-10623 Enable ORC PPD") { withTempPath { dir => withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { @@ -405,7 +379,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } // It needs to repartition data so that we can have several ORC files // in order to skip stripes in ORC. - createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path) + spark.createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path) val df = spark.read.orc(path) def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { @@ -440,77 +414,6 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC") { - withTempView("single") { - val singleRowDF = Seq((0, "foo")).toDF("key", "value") - singleRowDF.createOrReplaceTempView("single") - - Seq("true", "false").foreach { orcConversion => - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> orcConversion) { - withTable("dummy_orc") { - withTempPath { dir => - val path = dir.getCanonicalPath - spark.sql( - s""" - |CREATE TABLE dummy_orc(key INT, value STRING) - |STORED AS ORC - |LOCATION '${dir.toURI}' - """.stripMargin) - - spark.sql( - s""" - |INSERT INTO TABLE dummy_orc - |SELECT key, value FROM single - """.stripMargin) - - val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0") - checkAnswer(df, singleRowDF) - - val queryExecution = df.queryExecution - if (orcConversion == "true") { - queryExecution.analyzed.collectFirst { - case _: LogicalRelation => () - }.getOrElse { - fail(s"Expecting the query plan to convert orc to data sources, " + - s"but got:\n$queryExecution") - } - } else { - queryExecution.analyzed.collectFirst { - case _: HiveTableRelation => () - }.getOrElse { - fail(s"Expecting no conversion from orc to data sources, " + - s"but got:\n$queryExecution") - } - } - } - } - } - } - } - } - - test("converted ORC table supports resolving mixed case field") { - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { - withTable("dummy_orc") { - withTempPath { dir => - val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue") - df.write - .partitionBy("partitionValue") - .mode("overwrite") - .orc(dir.getAbsolutePath) - - spark.sql(s""" - |create external table dummy_orc (id long, valueField long) - |partitioned by (partitionValue int) - |stored as orc - |location "${dir.toURI}"""".stripMargin) - spark.sql(s"msck repair table dummy_orc") - checkAnswer(spark.sql("select * from dummy_orc"), df) - } - } - } - } - test("SPARK-14962 Produce correct results on array type with isnotnull") { withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { val data = (0 until 10).map(i => Tuple1(Array(i))) @@ -544,7 +447,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => // It needs to repartition data so that we can have several 
ORC files // in order to skip stripes in ORC. - createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + spark.createDataFrame(data).toDF("a").repartition(10) + .write.orc(file.getCanonicalPath) val df = spark.read.orc(file.getCanonicalPath).where("a == 2") val actual = stripSparkFilter(df).count() @@ -563,7 +467,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => // It needs to repartition data so that we can have several ORC files // in order to skip stripes in ORC. - createDataFrame(data).toDF("a").repartition(10).write.orc(file.getCanonicalPath) + spark.createDataFrame(data).toDF("a").repartition(10) + .write.orc(file.getCanonicalPath) val df = spark.read.orc(file.getCanonicalPath).where(s"a == '$timeString'") val actual = stripSparkFilter(df).count() @@ -596,14 +501,18 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Empty schema does not read data from ORC file") { val data = Seq((1, 1), (2, 2)) withOrcFile(data) { path => - val requestedSchema = StructType(Nil) val conf = new Configuration() - val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) - val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) - assert(maybeOrcReader.isDefined) - val orcRecordReader = new SparkOrcNewRecordReader( - maybeOrcReader.get, conf, 0, maybeOrcReader.get.getContentLength) + conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, "") + conf.setBoolean("hive.io.file.read.all.columns", false) + + val orcRecordReader = { + val file = new File(path).listFiles().find(_.getName.endsWith(".snappy.orc")).head + val split = new FileSplit(new Path(file.toURI), 0, file.length, Array.empty[String]) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val oif = new OrcInputFormat[OrcStruct] + oif.createRecordReader(split, hadoopAttemptContext) + } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) try { @@ -614,27 +523,88 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("read from multiple orc input paths") { - val path1 = Utils.createTempDir() - val path2 = Utils.createTempDir() - makeOrcFile((1 to 10).map(Tuple1.apply), path1) - makeOrcFile((1 to 10).map(Tuple1.apply), path2) - assertResult(20)(read.orc(path1.getCanonicalPath, path2.getCanonicalPath).count()) - } + test("read from multiple orc input paths") { + val path1 = Utils.createTempDir() + val path2 = Utils.createTempDir() + makeOrcFile((1 to 10).map(Tuple1.apply), path1) + makeOrcFile((1 to 10).map(Tuple1.apply), path2) + val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath) + assert(df.count() == 20) + } +} + +class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { + import testImplicits._ + + test("LZO compression options for writing to an ORC file") { + withTempPath { file => + spark.range(0, 10).write + .option("compression", "LZO") + .orc(file.getCanonicalPath) + + val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".lzo.orc")) + assert(maybeOrcFile.isDefined) + + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val conf = OrcFile.readerOptions(new Configuration()) + assert("LZO" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name) + } + } + + test("Schema discovery on empty ORC files") { + // SPARK-8501 is 
fixed. + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempView("empty", "single") { + spark.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |USING ORC + |LOCATION '${dir.toURI}' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.createOrReplaceTempView("empty") + + // This creates 1 empty ORC file with ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. + spark.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val df = spark.read.orc(path) + assert(df.schema === emptyDF.schema.asNullable) + checkAnswer(df, emptyDF) + } + } + } + } + + test("SPARK-21791 ORC should support column names with dot") { + withTempDir { dir => + val path = new File(dir, "orc").getCanonicalPath + Seq(Some(1), None).toDF("col.dots").write.orc(path) + assert(spark.read.orc(path).collect().length == 2) + } + } test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { - Seq( - ("native", classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat]), - ("hive", classOf[org.apache.spark.sql.hive.orc.OrcFileFormat])).foreach { case (i, format) => - - withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> i) { - withTable("spark_20728") { - sql("CREATE TABLE spark_20728(a INT) USING ORC") - val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { - case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass - } - assert(fileFormat == Some(format)) + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") { + val e = intercept[AnalysisException] { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + } + assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) + } + + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "native") { + withTable("spark_20728") { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { + case l: LogicalRelation => l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass } + assert(fileFormat == Some(classOf[OrcFileFormat])) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala similarity index 63% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 2a086be57f517..6f5f2fd795f74 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.io.File import java.util.Locale @@ -23,50 +23,30 @@ import java.util.Locale import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.execution.datasources.orc.OrcOptions -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.Row import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import spark._ +abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { + import testImplicits._ var orcTableDir: File = null var orcTableAsDir: File = null - override def beforeAll(): Unit = { + protected override def beforeAll(): Unit = { super.beforeAll() orcTableAsDir = Utils.createTempDir("orctests", "sparksql") - - // Hack: to prepare orc data files using hive external tables orcTableDir = Utils.createTempDir("orctests", "sparksql") - import org.apache.spark.sql.hive.test.TestHive.implicits._ sparkContext .makeRDD(1 to 10) .map(i => OrcData(i, s"part-$i")) .toDF() - .createOrReplaceTempView(s"orc_temp_table") - - sql( - s"""CREATE EXTERNAL TABLE normal_orc( - | intField INT, - | stringField STRING - |) - |STORED AS ORC - |LOCATION '${orcTableAsDir.toURI}' - """.stripMargin) - - sql( - s"""INSERT INTO TABLE normal_orc - |SELECT intField, stringField FROM orc_temp_table - """.stripMargin) + .createOrReplaceTempView("orc_temp_table") } test("create temporary orc table") { @@ -152,56 +132,13 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val conf = sqlContext.sessionState.conf + val conf = spark.sessionState.conf val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) assert(option.compressionCodec == "NONE") } - test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { - val location = Utils.createTempDir() - val uri = location.toURI - try { - hiveClient.runSqlHive("USE default") - hiveClient.runSqlHive( - """ - |CREATE EXTERNAL TABLE hive_orc( - | a STRING, - | b CHAR(10), - | c VARCHAR(10), - | d ARRAY) - |STORED AS orc""".stripMargin) - // Hive throws an exception if I assign the location in the create table statement. - hiveClient.runSqlHive( - s"ALTER TABLE hive_orc SET LOCATION '$uri'") - hiveClient.runSqlHive( - """ - |INSERT INTO TABLE hive_orc - |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) - |FROM (SELECT 1) t""".stripMargin) - - // We create a different table in Spark using the same schema which points to - // the same location. 
- spark.sql( - s""" - |CREATE EXTERNAL TABLE spark_orc( - | a STRING, - | b CHAR(10), - | c VARCHAR(10), - | d ARRAY) - |STORED AS orc - |LOCATION '$uri'""".stripMargin) - val result = Row("a", "b ", "c", Seq("d ")) - checkAnswer(spark.table("hive_orc"), result) - checkAnswer(spark.table("spark_orc"), result) - } finally { - hiveClient.runSqlHive("DROP TABLE IF EXISTS hive_orc") - hiveClient.runSqlHive("DROP TABLE IF EXISTS spark_orc") - Utils.deleteRecursively(location) - } - } - test("SPARK-21839: Add SQL config for ORC compression") { - val conf = sqlContext.sessionState.conf + val conf = spark.sessionState.conf // Test if the default of spark.sql.orc.compression.codec is snappy assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "SNAPPY") @@ -225,13 +162,28 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } } -class OrcSourceSuite extends OrcSuite { - override def beforeAll(): Unit = { +class OrcSourceSuite extends OrcSuite with SharedSQLContext { + + protected override def beforeAll(): Unit = { super.beforeAll() + sql( + s"""CREATE TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |USING ORC + |LOCATION '${orcTableAsDir.toURI}' + """.stripMargin) + + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) + spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_source - |USING org.apache.spark.sql.hive.orc + |USING ORC |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) @@ -239,43 +191,10 @@ class OrcSourceSuite extends OrcSuite { spark.sql( s"""CREATE TEMPORARY VIEW normal_orc_as_source - |USING org.apache.spark.sql.hive.orc + |USING ORC |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' |) """.stripMargin) } - - test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { - // The `LessThan` should be converted while the `StringContains` shouldn't - val schema = new StructType( - Array( - StructField("a", IntegerType, nullable = true), - StructField("b", StringType, nullable = true))) - assertResult( - """leaf-0 = (LESS_THAN a 10) - |expr = leaf-0 - """.stripMargin.trim - ) { - OrcFilters.createFilter(schema, Array( - LessThan("a", 10), - StringContains("b", "prefix") - )).get.toString - } - - // The `LessThan` should be converted while the whole inner `And` shouldn't - assertResult( - """leaf-0 = (LESS_THAN a 10) - |expr = leaf-0 - """.stripMargin.trim - ) { - OrcFilters.createFilter(schema, Array( - LessThan("a", 10), - Not(And( - GreaterThan("a", 1), - StringContains("b", "prefix") - )) - )).get.toString - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala similarity index 72% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index a2f08c5ba72c6..d94cb850ed2a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -15,20 +15,51 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.orc +package org.apache.spark.sql.execution.datasources.orc import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils -private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { +/** + * OrcTest + * -> OrcSuite + * -> OrcSourceSuite + * -> HiveOrcSourceSuite + * -> OrcQueryTests + * -> OrcQuerySuite + * -> HiveOrcQuerySuite + * -> OrcPartitionDiscoveryTest + * -> OrcPartitionDiscoverySuite + * -> HiveOrcPartitionDiscoverySuite + * -> OrcFilterSuite + * -> HiveOrcFilterSuite + */ +abstract class OrcTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { import testImplicits._ + val orcImp: String = "native" + + private var originalConfORCImplementation = "native" + + protected override def beforeAll(): Unit = { + super.beforeAll() + originalConfORCImplementation = conf.getConf(ORC_IMPLEMENTATION) + conf.setConf(ORC_IMPLEMENTATION, orcImp) + } + + protected override def afterAll(): Unit = { + conf.setConf(ORC_IMPLEMENTATION, originalConfORCImplementation) + super.afterAll() + } + /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` * returns. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala new file mode 100644 index 0000000000000..283037caf4a9b --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} + +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.orc.OrcTest +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.types._ + +/** + * A test suite that tests Hive ORC filter API based filter pushdown optimization. 
+ */ +class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton { + + override val orcImp: String = "hive" + + private def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + checker: (SearchArgument) => Unit): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") + checker(maybeFilter.get) + } + + private def checkFilterPredicate + (predicate: Predicate, filterOperator: PredicateLeaf.Operator) + (implicit df: DataFrame): Unit = { + def checkComparisonOperator(filter: SearchArgument) = { + val operator = filter.getLeaves.asScala + assert(operator.map(_.getOperator).contains(filterOperator)) + } + checkFilterPredicate(df, predicate, checkComparisonOperator) + } + + private def checkFilterPredicate + (predicate: Predicate, stringExpr: String) + (implicit df: DataFrame): Unit = { + def checkLogicalOperator(filter: SearchArgument) = { + assert(filter.toString == stringExpr) + } + checkFilterPredicate(df, predicate, checkLogicalOperator) + } + + private def checkNoFilterPredicate + (predicate: Predicate) + (implicit df: DataFrame): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") + } + + test("filter pushdown - integer") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + 
checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - long") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - float") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - double") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + 
checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - string") { + withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - boolean") { + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= false, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(false) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - decimal") { + withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(2)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(BigDecimal.valueOf(3)) < '_1, 
PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(4)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - timestamp") { + val timeString = "2015-08-20 14:57:00" + val timestamps = (1 to 4).map { i => + val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 + new Timestamp(milliseconds) + } + withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(timestamps(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(timestamps(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - combinations with logical operators") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked + // in string form in order to check filter creation including logical operators + // such as `and`, `or` or `not`. So, this function uses `SearchArgument.toString()` + // to produce string expression and then compare it to given string expression below. + // This might have to be changed after Hive version is upgraded. 
+ checkFilterPredicate( + '_1.isNotNull, + """leaf-0 = (IS_NULL _1) + |expr = (not leaf-0)""".stripMargin.trim + ) + checkFilterPredicate( + '_1 =!= 1, + """leaf-0 = (IS_NULL _1) + |leaf-1 = (EQUALS _1 1) + |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + !('_1 < 4), + """leaf-0 = (IS_NULL _1) + |leaf-1 = (LESS_THAN _1 4) + |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 || '_1 > 3, + """leaf-0 = (LESS_THAN _1 2) + |leaf-1 = (LESS_THAN_EQUALS _1 3) + |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 && '_1 > 3, + """leaf-0 = (IS_NULL _1) + |leaf-1 = (LESS_THAN _1 2) + |leaf-2 = (LESS_THAN_EQUALS _1 3) + |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim + ) + } + } + + test("no filter pushdown - non-supported types") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) + } + // ArrayType + withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df => + checkNoFilterPredicate('_1.isNull) + } + // BinaryType + withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkNoFilterPredicate('_1 <=> 1.b) + } + // DateType + val stringDate = "2015-01-01" + withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df => + checkNoFilterPredicate('_1 === Date.valueOf(stringDate)) + } + // MapType + withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df => + checkNoFilterPredicate('_1.isNotNull) + } + } + + test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + import org.apache.spark.sql.sources._ + // The `LessThan` should be converted while the `StringContains` shouldn't + val schema = new StructType( + Array( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true))) + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala new file mode 100644 index 0000000000000..ab9b492f347c6 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcPartitionDiscoverySuite.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.spark.sql.execution.datasources.orc.OrcPartitionDiscoveryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveOrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with TestHiveSingleton { + override val orcImp: String = "hive" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala new file mode 100644 index 0000000000000..7244c369bd3f4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.orc.OrcQueryTest +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { + import testImplicits._ + + override val orcImp: String = "hive" + + test("SPARK-8501: Avoids discovery schema from empty ORC files") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempView("empty", "single") { + spark.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '${dir.toURI}' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.createOrReplaceTempView("empty") + + // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. 
+ spark.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val errorMessage = intercept[AnalysisException] { + spark.read.orc(path) + }.getMessage + + assert(errorMessage.contains("Unable to infer schema for ORC")) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.createOrReplaceTempView("single") + + spark.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = spark.read.orc(path) + assert(df.schema === singleRowDF.schema.asNullable) + checkAnswer(df, singleRowDF) + } + } + } + } + + test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC") { + withTempView("single") { + val singleRowDF = Seq((0, "foo")).toDF("key", "value") + singleRowDF.createOrReplaceTempView("single") + + Seq("true", "false").foreach { orcConversion => + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> orcConversion) { + withTable("dummy_orc") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.sql( + s""" + |CREATE TABLE dummy_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '${dir.toURI}' + """.stripMargin) + + spark.sql( + s""" + |INSERT INTO TABLE dummy_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0") + checkAnswer(df, singleRowDF) + + val queryExecution = df.queryExecution + if (orcConversion == "true") { + queryExecution.analyzed.collectFirst { + case _: LogicalRelation => () + }.getOrElse { + fail(s"Expecting the query plan to convert orc to data sources, " + + s"but got:\n$queryExecution") + } + } else { + queryExecution.analyzed.collectFirst { + case _: HiveTableRelation => () + }.getOrElse { + fail(s"Expecting no conversion from orc to data sources, " + + s"but got:\n$queryExecution") + } + } + } + } + } + } + } + } + + test("converted ORC table supports resolving mixed case field") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { + withTable("dummy_orc") { + withTempPath { dir => + val df = spark.range(5).selectExpr("id", "id as valueField", "id as partitionValue") + df.write + .partitionBy("partitionValue") + .mode("overwrite") + .orc(dir.getAbsolutePath) + + spark.sql(s""" + |create external table dummy_orc (id long, valueField long) + |partitioned by (partitionValue int) + |stored as orc + |location "${dir.toURI}"""".stripMargin) + spark.sql(s"msck repair table dummy_orc") + checkAnswer(spark.sql("select * from dummy_orc"), df) + } + } + } + } + + test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") { + Seq( + ("native", classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat]), + ("hive", classOf[org.apache.spark.sql.hive.orc.OrcFileFormat])).foreach { + case (orcImpl, format) => + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { + withTable("spark_20728") { + sql("CREATE TABLE spark_20728(a INT) USING ORC") + val fileFormat = sql("SELECT * FROM spark_20728").queryExecution.analyzed.collectFirst { + case l: LogicalRelation => + l.relation.asInstanceOf[HadoopFsRelation].fileFormat.getClass + } + assert(fileFormat == Some(format)) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala new file mode 100644 index 0000000000000..17b7d8cfe127e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.orc.OrcSuite +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.util.Utils + +class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { + + override val orcImp: String = "hive" + + override def beforeAll(): Unit = { + super.beforeAll() + + sql( + s"""CREATE EXTERNAL TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |STORED AS ORC + |LOCATION '${orcTableAsDir.toURI}' + """.stripMargin) + + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) + + spark.sql( + s"""CREATE TEMPORARY VIEW normal_orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + + spark.sql( + s"""CREATE TEMPORARY VIEW normal_orc_as_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + } + + test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { + val location = Utils.createTempDir() + val uri = location.toURI + try { + hiveClient.runSqlHive("USE default") + hiveClient.runSqlHive( + """ + |CREATE EXTERNAL TABLE hive_orc( + | a STRING, + | b CHAR(10), + | c VARCHAR(10), + | d ARRAY) + |STORED AS orc""".stripMargin) + // Hive throws an exception if I assign the location in the create table statement. + hiveClient.runSqlHive( + s"ALTER TABLE hive_orc SET LOCATION '$uri'") + hiveClient.runSqlHive( + """ + |INSERT INTO TABLE hive_orc + |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) + |FROM (SELECT 1) t""".stripMargin) + + // We create a different table in Spark using the same schema which points to + // the same location. 
+ spark.sql( + s""" + |CREATE EXTERNAL TABLE spark_orc( + | a STRING, + | b CHAR(10), + | c VARCHAR(10), + | d ARRAY) + |STORED AS orc + |LOCATION '$uri'""".stripMargin) + val result = Row("a", "b ", "c", Seq("d ")) + checkAnswer(spark.table("hive_orc"), result) + checkAnswer(spark.table("spark_orc"), result) + } finally { + hiveClient.runSqlHive("DROP TABLE IF EXISTS hive_orc") + hiveClient.runSqlHive("DROP TABLE IF EXISTS spark_orc") + Utils.deleteRecursively(location) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index ba0a7605da71c..f87162f94c01a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ - override val dataSourceName: String = classOf[OrcFileFormat].getCanonicalName + override val dataSourceName: String = + classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName // ORC does not play well with NullType and UDT. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { @@ -116,3 +117,8 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { } } } + +class HiveOrcHadoopFsRelationSuite extends OrcHadoopFsRelationSuite { + override val dataSourceName: String = + classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName +} From e103adf45ae3e80626554370691987d8297dda0c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Dec 2017 20:45:11 +0800 Subject: [PATCH 367/712] [SPARK-22703][SQL] make ColumnarRow an immutable view ## What changes were proposed in this pull request? Similar to https://github.com/apache/spark/pull/19842 , we should also make `ColumnarRow` an immutable view, and move forward to make `ColumnVector` public. ## How was this patch tested? Existing tests. The performance concern should be same as https://github.com/apache/spark/pull/19842 . Author: Wenchen Fan Closes #19898 from cloud-fan/row-id. 
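[Editor's note] For readers scanning the large diff below, the core distinction is between a row view whose position is fixed at construction (safe to hand out, one small allocation per row) and a row that is recycled by mutating its position (no allocation, but callers must copy data out). A minimal sketch of that distinction, using made-up toy classes rather than the actual `ColumnarRow`/`MutableColumnarRow` APIs:

```scala
// Illustrative toy classes only; not the real Spark vectorized row classes.
object RowViewSketch extends App {
  // One int column stands in for a ColumnVector.
  final class IntColumn(values: Array[Int]) {
    def getInt(rowId: Int): Int = values(rowId)
  }

  // Immutable view: the row position is fixed at construction, so instances
  // may be handed out and retained safely.
  final class ImmutableRowView(column: IntColumn, rowId: Int) {
    def value: Int = column.getInt(rowId)
  }

  // Mutable, reusable row: one instance is recycled by updating rowId, so
  // callers must copy data out before the position advances again.
  final class ReusableRow(column: IntColumn) {
    var rowId: Int = 0
    def value: Int = column.getInt(rowId)
  }

  val column = new IntColumn(Array(10, 20, 30))
  println(new ImmutableRowView(column, 1).value) // always 20

  val reused = new ReusableRow(column)
  (0 until 3).foreach { i => reused.rowId = i; print(s"${reused.value} ") } // 10 20 30
  println()
}
```

The patch keeps both flavours: `ColumnarRow` becomes the immutable view, while `MutableColumnarRow` takes over the reuse-heavy paths such as `ColumnarBatch.rowIterator` and the vectorized aggregate hash map.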
--- .../vectorized/AggregateHashMap.java | 9 +-- .../vectorized/ArrowColumnVector.java | 1 - .../execution/vectorized/ColumnVector.java | 15 ++--- .../execution/vectorized/ColumnarBatch.java | 22 +++---- .../sql/execution/vectorized/ColumnarRow.java | 66 ++++++++++--------- .../vectorized/MutableColumnarRow.java | 47 +++++++------ .../vectorized/OffHeapColumnVector.java | 2 +- .../vectorized/OnHeapColumnVector.java | 2 +- .../vectorized/WritableColumnVector.java | 5 -- .../aggregate/HashAggregateExec.scala | 9 +-- .../VectorizedHashMapGenerator.scala | 5 +- .../vectorized/ColumnarBatchSuite.scala | 5 -- 12 files changed, 89 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index 9467435435d1f..24260b05194a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -41,7 +41,7 @@ public class AggregateHashMap { private OnHeapColumnVector[] columnVectors; - private ColumnarBatch batch; + private MutableColumnarRow aggBufferRow; private int[] buckets; private int numBuckets; private int numRows = 0; @@ -63,7 +63,7 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int this.maxSteps = maxSteps; numBuckets = (int) (capacity / loadFactor); columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema); - batch = new ColumnarBatch(schema, columnVectors, capacity); + aggBufferRow = new MutableColumnarRow(columnVectors); buckets = new int[numBuckets]; Arrays.fill(buckets, -1); } @@ -72,14 +72,15 @@ public AggregateHashMap(StructType schema) { this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS); } - public ColumnarRow findOrInsert(long key) { + public MutableColumnarRow findOrInsert(long key) { int idx = find(key); if (idx != -1 && buckets[idx] == -1) { columnVectors[0].putLong(numRows, key); columnVectors[1].putLong(numRows, 0); buckets[idx] = numRows++; } - return batch.getRow(buckets[idx]); + aggBufferRow.rowId = buckets[idx]; + return aggBufferRow; } @VisibleForTesting diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 0071bd66760be..1f1347ccd315e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -323,7 +323,6 @@ public ArrowColumnVector(ValueVector vector) { for (int i = 0; i < childColumns.length; ++i) { childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); } - resultStruct = new ColumnarRow(childColumns); } else { throw new UnsupportedOperationException(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index cca14911fbb28..e6b87519239dd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -157,18 +157,16 @@ public abstract class ColumnVector implements AutoCloseable { /** * Returns a utility object to get structs. 
*/ - public ColumnarRow getStruct(int rowId) { - resultStruct.rowId = rowId; - return resultStruct; + public final ColumnarRow getStruct(int rowId) { + return new ColumnarRow(this, rowId); } /** * Returns a utility object to get structs. * provided to keep API compatibility with InternalRow for code generation */ - public ColumnarRow getStruct(int rowId, int size) { - resultStruct.rowId = rowId; - return resultStruct; + public final ColumnarRow getStruct(int rowId, int size) { + return getStruct(rowId); } /** @@ -216,11 +214,6 @@ public MapData getMap(int ordinal) { */ protected DataType type; - /** - * Reusable Struct holder for getStruct(). - */ - protected ColumnarRow resultStruct; - /** * The Dictionary for this column. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 2f5fb360b226f..a9d09aa679726 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -18,6 +18,7 @@ import java.util.*; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.StructType; /** @@ -40,10 +41,10 @@ public final class ColumnarBatch { private final StructType schema; private final int capacity; private int numRows; - final ColumnVector[] columns; + private final ColumnVector[] columns; - // Staging row returned from getRow. - final ColumnarRow row; + // Staging row returned from `getRow`. + private final MutableColumnarRow row; /** * Called to close all the columns in this batch. It is not valid to access the data after @@ -58,10 +59,10 @@ public void close() { /** * Returns an iterator over the rows in this batch. This skips rows that are filtered out. */ - public Iterator rowIterator() { + public Iterator rowIterator() { final int maxRows = numRows; - final ColumnarRow row = new ColumnarRow(columns); - return new Iterator() { + final MutableColumnarRow row = new MutableColumnarRow(columns); + return new Iterator() { int rowId = 0; @Override @@ -70,7 +71,7 @@ public boolean hasNext() { } @Override - public ColumnarRow next() { + public InternalRow next() { if (rowId >= maxRows) { throw new NoSuchElementException(); } @@ -133,9 +134,8 @@ public void setNumRows(int numRows) { /** * Returns the row in this batch at `rowId`. Returned row is reused across calls. */ - public ColumnarRow getRow(int rowId) { - assert(rowId >= 0); - assert(rowId < numRows); + public InternalRow getRow(int rowId) { + assert(rowId >= 0 && rowId < numRows); row.rowId = rowId; return row; } @@ -144,6 +144,6 @@ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.schema = schema; this.columns = columns; this.capacity = capacity; - this.row = new ColumnarRow(columns); + this.row = new MutableColumnarRow(columns); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index cabb7479525d9..95c0d09873d67 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -28,30 +28,32 @@ * to be reused, callers should copy the data out if it needs to be stored. 
*/ public final class ColumnarRow extends InternalRow { - protected int rowId; - private final ColumnVector[] columns; - - // Ctor used if this is a struct. - ColumnarRow(ColumnVector[] columns) { - this.columns = columns; + // The data for this row. E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + private final ColumnVector data; + private final int rowId; + private final int numFields; + + ColumnarRow(ColumnVector data, int rowId) { + assert (data.dataType() instanceof StructType); + this.data = data; + this.rowId = rowId; + this.numFields = ((StructType) data.dataType()).size(); } - public ColumnVector[] columns() { return columns; } - @Override - public int numFields() { return columns.length; } + public int numFields() { return numFields; } /** * Revisit this. This is expensive. This is currently only used in test paths. */ @Override public InternalRow copy() { - GenericInternalRow row = new GenericInternalRow(columns.length); + GenericInternalRow row = new GenericInternalRow(numFields); for (int i = 0; i < numFields(); i++) { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = columns[i].dataType(); + DataType dt = data.getChildColumn(i).dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); } else if (dt instanceof ByteType) { @@ -91,65 +93,65 @@ public boolean anyNull() { } @Override - public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } @Override - public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } @Override - public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } @Override - public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } @Override - public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } @Override - public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } @Override - public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } @Override - public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getDecimal(rowId, precision, scale); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getUTF8String(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if 
(columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getBinary(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); + final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); return new CalendarInterval(months, microseconds); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getStruct(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; - return columns[ordinal].getArray(rowId); + if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; + return data.getChildColumn(ordinal).getArray(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index f272cc163611b..06602c147dfe9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -28,17 +28,24 @@ /** * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash - * aggregate. + * aggregate, and {@link ColumnarBatch} to save object creation. * * Note that this class intentionally has a lot of duplicated code with {@link ColumnarRow}, to * avoid java polymorphism overhead by keeping {@link ColumnarRow} and this class final classes. 
*/ public final class MutableColumnarRow extends InternalRow { public int rowId; - private final WritableColumnVector[] columns; + private final ColumnVector[] columns; + private final WritableColumnVector[] writableColumns; - public MutableColumnarRow(WritableColumnVector[] columns) { + public MutableColumnarRow(ColumnVector[] columns) { this.columns = columns; + this.writableColumns = null; + } + + public MutableColumnarRow(WritableColumnVector[] writableColumns) { + this.columns = writableColumns; + this.writableColumns = writableColumns; } @Override @@ -225,54 +232,54 @@ public void update(int ordinal, Object value) { @Override public void setNullAt(int ordinal) { - columns[ordinal].putNull(rowId); + writableColumns[ordinal].putNull(rowId); } @Override public void setBoolean(int ordinal, boolean value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putBoolean(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putBoolean(rowId, value); } @Override public void setByte(int ordinal, byte value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putByte(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putByte(rowId, value); } @Override public void setShort(int ordinal, short value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putShort(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putShort(rowId, value); } @Override public void setInt(int ordinal, int value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putInt(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putInt(rowId, value); } @Override public void setLong(int ordinal, long value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putLong(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putLong(rowId, value); } @Override public void setFloat(int ordinal, float value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putFloat(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putFloat(rowId, value); } @Override public void setDouble(int ordinal, double value) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putDouble(rowId, value); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putDouble(rowId, value); } @Override public void setDecimal(int ordinal, Decimal value, int precision) { - columns[ordinal].putNotNull(rowId); - columns[ordinal].putDecimal(rowId, value, precision); + writableColumns[ordinal].putNotNull(rowId); + writableColumns[ordinal].putDecimal(rowId, value, precision); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 806d0291a6c49..5f1b9885334b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -547,7 +547,7 @@ protected void reserveInternal(int newCapacity) { } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); - } else if (resultStruct != null) { + } else if (childColumns != null) { // Nothing to store. 
} else { throw new RuntimeException("Unhandled " + type); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 6e7f74ce12f16..f12772ede575d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -558,7 +558,7 @@ protected void reserveInternal(int newCapacity) { if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity); doubleData = newData; } - } else if (resultStruct != null) { + } else if (childColumns != null) { // Nothing to store. } else { throw new RuntimeException("Unhandled " + type); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 0bea4cc97142d..7c053b579442c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -74,7 +74,6 @@ public void close() { dictionaryIds = null; } dictionary = null; - resultStruct = null; } public void reserve(int requiredCapacity) { @@ -673,23 +672,19 @@ protected WritableColumnVector(int capacity, DataType type) { } this.childColumns = new WritableColumnVector[1]; this.childColumns[0] = reserveNewColumn(childCapacity, childType); - this.resultStruct = null; } else if (type instanceof StructType) { StructType st = (StructType)type; this.childColumns = new WritableColumnVector[st.fields().length]; for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } - this.resultStruct = new ColumnarRow(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. 
this.childColumns = new WritableColumnVector[2]; this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); - this.resultStruct = new ColumnarRow(this.childColumns); } else { this.childColumns = null; - this.resultStruct = null; } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 26d8cd7278353..9cadd13999e72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -595,9 +595,7 @@ case class HashAggregateExec( ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName();") - ctx.addMutableState( - s"java.util.Iterator<${classOf[ColumnarRow].getName}>", - iterTermForFastHashMap) + ctx.addMutableState(s"java.util.Iterator", iterTermForFastHashMap) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() @@ -674,7 +672,7 @@ case class HashAggregateExec( """.stripMargin } - // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow + // Iterate over the aggregate rows and convert them from InternalRow to UnsafeRow def outputFromVectorizedMap: String = { val row = ctx.freshName("fastHashMapRow") ctx.currentVars = null @@ -687,10 +685,9 @@ case class HashAggregateExec( bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) => BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) - val columnarRowCls = classOf[ColumnarRow].getName s""" |while ($iterTermForFastHashMap.hasNext()) { - | $columnarRowCls $row = ($columnarRowCls) $iterTermForFastHashMap.next(); + | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next(); | ${generateKeyRow.code} | ${generateBufferRow.code} | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value}); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 44ba539ebf7c2..f04cd48072f17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnarRow, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ /** @@ -231,7 +232,7 @@ class VectorizedHashMapGenerator( protected def generateRowIterator(): String = { s""" - |public java.util.Iterator<${classOf[ColumnarRow].getName}> rowIterator() { + |public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() { | batch.setNumRows(numRows); | return batch.rowIterator(); |} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0ae4f2d117609..c9c6bee513b53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -751,11 +751,6 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(1, 5.67) val s = column.getStruct(0) - assert(s.columns()(0).getInt(0) == 123) - assert(s.columns()(0).getInt(1) == 456) - assert(s.columns()(1).getDouble(0) == 3.45) - assert(s.columns()(1).getDouble(1) == 5.67) - assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) From ea2fbf41973e62b28b9d5fa772512ddba5c00d96 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 7 Dec 2017 20:55:35 +0800 Subject: [PATCH 368/712] [SPARK-22705][SQL] Case, Coalesce, and In use less global variables ## What changes were proposed in this pull request? This PR accomplishes the following two items. 1. Reduce # of global variables from two to one for generated code of `Case` and `Coalesce` and remove global variables for generated code of `In`. 2. Make lifetime of global variable local within an operation Item 1. reduces # of constant pool entries in a Java class. Item 2. ensures that an variable is not passed to arguments in a method split by `CodegenContext.splitExpressions()`, which is addressed by #19865. ## How was this patch tested? Added new tests into `PredicateSuite`, `NullExpressionsSuite`, and `ConditionalExpressionSuite`. Author: Kazuaki Ishizaki Closes #19901 from kiszk/SPARK-22705. --- .../expressions/conditionalExpressions.scala | 70 +++++++++++-------- .../expressions/nullExpressions.scala | 19 +++-- .../sql/catalyst/expressions/predicates.scala | 31 +++++--- .../ConditionalExpressionSuite.scala | 7 ++ .../expressions/NullExpressionsSuite.scala | 7 ++ .../catalyst/expressions/PredicateSuite.scala | 6 ++ 6 files changed, 91 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ae5f7140847db..53c3b226895ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -180,13 +180,18 @@ case class CaseWhen( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // This variable represents whether the first successful condition is met or not. - // It is initialized to `false` and it is set to `true` when the first condition which - // evaluates to `true` is met and therefore is not needed to go on anymore on the computation - // of the following conditions. - val conditionMet = ctx.freshName("caseWhenConditionMet") - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + // This variable holds the state of the result: + // -1 means the condition is not met yet and the result is unknown. + val NOT_MATCHED = -1 + // 0 means the condition is met and result is not null. + val HAS_NONNULL = 0 + // 1 means the condition is met and result is null. 
+ val HAS_NULL = 1 + // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, + // We won't go on anymore on the computation. + val resultState = ctx.freshName("caseWhenResultState") + val tmpResult = ctx.freshName("caseWhenTmpResult") + ctx.addMutableState(ctx.javaType(dataType), tmpResult) // these blocks are meant to be inside a // do { @@ -200,9 +205,8 @@ case class CaseWhen( |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} - | ${ev.isNull} = ${res.isNull}; - | ${ev.value} = ${res.value}; - | $conditionMet = true; + | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); + | $tmpResult = ${res.value}; | continue; |} """.stripMargin @@ -212,59 +216,63 @@ case class CaseWhen( val res = elseExpr.genCode(ctx) s""" |${res.code} - |${ev.isNull} = ${res.isNull}; - |${ev.value} = ${res.value}; + |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); + |$tmpResult = ${res.value}; """.stripMargin } val allConditions = cases ++ elseCode // This generates code like: - // conditionMet = caseWhen_1(i); - // if(conditionMet) { + // caseWhenResultState = caseWhen_1(i); + // if(caseWhenResultState != -1) { // continue; // } - // conditionMet = caseWhen_2(i); - // if(conditionMet) { + // caseWhenResultState = caseWhen_2(i); + // if(caseWhenResultState != -1) { // continue; // } // ... // and the declared methods are: - // private boolean caseWhen_1234() { - // boolean conditionMet = false; + // private byte caseWhen_1234() { + // byte caseWhenResultState = -1; // do { // // here the evaluation of the conditions // } while (false); - // return conditionMet; + // return caseWhenResultState; // } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BOOLEAN, + returnType = ctx.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BOOLEAN} $conditionMet = false; + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); - |return $conditionMet; + |return $resultState; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$conditionMet = $funcCall; - |if ($conditionMet) { + |$resultState = $funcCall; + |if ($resultState != $NOT_MATCHED) { | continue; |} """.stripMargin }.mkString) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.JAVA_BOOLEAN} $conditionMet = false; - do { - $codes - } while (false);""") + ev.copy(code = + s""" + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |$tmpResult = ${ctx.defaultValue(dataType)}; + |do { + | $codes + |} while (false); + |// TRUE if any condition is met and the result is null, or no any condition is met. 
+ |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); + |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; + """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 26c9a41efc9f9..294cdcb2e9546 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + val tmpIsNull = ctx.freshName("coalesceTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -81,26 +81,30 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | ${ev.isNull} = false; + | $tmpIsNull = false; | ${ev.value} = ${eval.value}; | continue; |} """.stripMargin } + val resultType = ctx.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", + returnType = resultType, makeSplitFunction = func => s""" + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $func |} while (false); + |return ${ev.value}; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$funcCall; - |if (!${ev.isNull}) { + |${ev.value} = $funcCall; + |if (!$tmpIsNull) { | continue; |} """.stripMargin @@ -109,11 +113,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" - |${ev.isNull} = true; - |${ev.value} = ${ctx.defaultValue(dataType)}; + |$tmpIsNull = true; + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); + |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a42dd7ecf57de..8eb41addaf689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -237,8 +237,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + // inTmpResult has 3 possible values: + // -1 means no matches found and there is at least one value in the list evaluated to null + val HAS_NULL = -1 + // 0 means no matches found and all values in the list are not null + val NOT_MATCHED = 0 + // 1 means one value in the list is matched + val MATCHED = 1 + val tmpResult = ctx.freshName("inTmpResult") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. 
@@ -246,10 +252,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { s""" |${x.code} |if (${x.isNull}) { - | ${ev.isNull} = true; + | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | ${ev.isNull} = false; - | ${ev.value} = true; + | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; | continue; |} """.stripMargin) @@ -257,17 +262,19 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: Nil, + extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, + returnType = ctx.JAVA_BYTE, makeSplitFunction = body => s""" |do { | $body |} while (false); + |return $tmpResult; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$funcCall; - |if (${ev.value}) { + |$tmpResult = $funcCall; + |if ($tmpResult == $MATCHED) { | continue; |} """.stripMargin @@ -276,14 +283,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { ev.copy(code = s""" |${valueGen.code} - |${ev.value} = false; - |${ev.isNull} = ${valueGen.isNull}; - |if (!${ev.isNull}) { + |byte $tmpResult = $HAS_NULL; + |if (!${valueGen.isNull}) { + | $tmpResult = 0; | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes | } while (false); |} + |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL); + |final boolean ${ev.value} = ($tmpResult == $MATCHED); """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 3e11c3d2d4fe3..60d84aae1fa3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper IndexedSeq((Literal(12) === Literal(1), Literal(42)), (Literal(12) === Literal(42), Literal(1)))) } + + test("SPARK-22705: case when should use less global variables") { + val ctx = new CodegenContext() + CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 40ef7770da33f..a23cd95632770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types._ @@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(inputs), "x_1") } + test("SPARK-22705: Coalesce should use less global variables") { + val ctx = new CodegenContext() + Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } + test("AtLeastNNonNulls should not throw 64kb exception") { val inputs = (1 to 4000).map(x => Literal(s"x_$x")) checkEvaluation(AtLeastNNonNulls(1, inputs), true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 95a0dfa057563..15cb0bea08f17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -246,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal(1.0D), sets), true) } + test("SPARK-22705: In should use less global variables") { + val ctx = new CodegenContext() + In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null From 2be448260d8115969e3d89b913c733943f151915 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 7 Dec 2017 20:59:47 +0800 Subject: [PATCH 369/712] [SPARK-22452][SQL] Add getInt, getLong, getBoolean to DataSourceV2Options - Implemented methods getInt, getLong, getBoolean for DataSourceV2Options - Added new unit tests to exercise these methods Author: Sunitha Kambhampati Closes #19902 from skambha/spark22452. --- .../sql/sources/v2/DataSourceV2Options.java | 31 +++++++++++++++++++ .../sources/v2/DataSourceV2OptionsSuite.scala | 31 +++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java index 9a89c8193dd6e..b2c908dc73a61 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java @@ -49,4 +49,35 @@ public DataSourceV2Options(Map originalMap) { public Optional get(String key) { return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); } + + /** + * Returns the boolean value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public boolean getBoolean(String key, boolean defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the integer value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public int getInt(String key, int defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? 
+ Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + + /** + * Returns the long value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public long getLong(String key, long defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala index 933f4075bcc8a..752d3c193cc74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala @@ -37,4 +37,35 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) assert(options.get("foo").get == "bAr") } + + test("getInt") { + val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava) + assert(options.getInt("numFOO", 10) == 1) + assert(options.getInt("numFOO2", 10) == 10) + + intercept[NumberFormatException]{ + options.getInt("foo", 1) + } + } + + test("getBoolean") { + val options = new DataSourceV2Options( + Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) + assert(options.getBoolean("isFoo", false)) + assert(!options.getBoolean("isFoo2", true)) + assert(options.getBoolean("isBar", true)) + assert(!options.getBoolean("isBar", false)) + assert(!options.getBoolean("FOO", true)) + } + + test("getLong") { + val options = new DataSourceV2Options(Map("numFoo" -> "9223372036854775807", + "foo" -> "bar").asJava) + assert(options.getLong("numFOO", 0L) == 9223372036854775807L) + assert(options.getLong("numFoo2", -1L) == -1L) + + intercept[NumberFormatException]{ + options.getLong("foo", 0L) + } + } } From beb717f64802de03836a0fed8cce4d813be1769a Mon Sep 17 00:00:00 2001 From: Brad Kaiser Date: Thu, 7 Dec 2017 21:04:09 +0800 Subject: [PATCH 370/712] [SPARK-22618][CORE] Catch exception in removeRDD to stop jobs from dying ## What changes were proposed in this pull request? I propose that BlockManagerMasterEndpoint.removeRdd() should catch and log any IOExceptions it receives. As it is now, the exception can bubble up to the main thread and kill user applications when called from RDD.unpersist(). I think this change is a better experience for the end user. I chose to catch the exception in BlockManagerMasterEndpoint.removeRdd() instead of RDD.unpersist() because this way the RDD.unpersist() blocking option will still work correctly. Otherwise, blocking will get short circuited by the first error. ## How was this patch tested? This patch was tested with a job that shows the job killing behavior mentioned above. rxin, it looks like you originally wrote this method, I would appreciate it if you took a look. Thanks. This contribution is my original work and is licensed under the project's open source license. Author: Brad Kaiser Closes #19836 from brad-kaiser/catch-unpersist-exception. 
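[Editor's note] The heart of this fix is visible in the hunk that follows: instead of building `Future.sequence` directly over the raw `ask` futures, each per-block-manager future gets its own `recover`, so one failing executor no longer fails (and can no longer kill) the whole removal. A self-contained sketch of the pattern, with hypothetical host names rather than Spark's RPC code:

```scala
import java.io.IOException

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Recover each participant's future individually so Future.sequence still
// succeeds when a single participant throws.
object RecoverPerFuture extends App {
  def removeBlocksFrom(host: String): Future[Int] = Future {
    if (host == "bad-host") throw new IOException(s"cannot reach $host")
    1 // pretend one block was removed
  }

  val futures = Seq("host-1", "bad-host", "host-2").map { host =>
    removeBlocksFrom(host).recover {
      case e: IOException =>
        println(s"Error removing blocks from $host: ${e.getMessage}")
        0 // zero blocks were removed from this host
    }
  }

  // Without the per-future recover, the aggregate would fail as soon as any
  // single host errored out.
  println(Await.result(Future.sequence(futures), 10.seconds)) // List(1, 0, 1)
}
```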
--- .../storage/BlockManagerMasterEndpoint.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 56d0266b8edad..89a6a71a589a1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.IOException import java.util.{HashMap => JHashMap} import scala.collection.JavaConverters._ @@ -159,11 +160,16 @@ class BlockManagerMasterEndpoint( // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. val removeMsg = RemoveRdd(rddId) - Future.sequence( - blockManagerInfo.values.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg) - }.toSeq - ) + + val futures = blockManagerInfo.values.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg).recover { + case e: IOException => + logWarning(s"Error trying to remove RDD $rddId", e) + 0 // zero blocks were removed + } + }.toSeq + + Future.sequence(futures) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { From dd59a4be3641d0029ba6b1759837e1d1b9f8a840 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 7 Dec 2017 21:08:15 +0800 Subject: [PATCH 371/712] [SPARK-22712][SQL] Use `buildReaderWithPartitionValues` in native OrcFileFormat ## What changes were proposed in this pull request? To support vectorization in native OrcFileFormat later, we need to use `buildReaderWithPartitionValues` instead of `buildReader` like ParquetFileFormat. This PR replaces `buildReader` with `buildReaderWithPartitionValues`. ## How was this patch tested? Pass the Jenkins with the existing test cases. Author: Dongjoon Hyun Closes #19907 from dongjoon-hyun/SPARK-ORC-BUILD-READER. 
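[Editor's note] For context on why the method swap matters: `buildReaderWithPartitionValues` is expected to hand back rows that already carry the partition columns, which the ORC file itself does not store; the diff below does this by joining each deserialized row with `file.partitionValues` and projecting onto the full schema. A toy sketch of that appending step, using plain collections instead of Spark's `InternalRow`/`JoinedRow` machinery (column names and values are made up for illustration):

```scala
object PartitionValuesSketch extends App {
  // Columns physically stored in each ORC file.
  val dataSchema = Seq("key", "value")
  // Column encoded in the directory layout, e.g. .../dt=2017-12-07/part-0.orc
  val partitionSchema = Seq("dt")
  val partitionValues = Seq("2017-12-07")

  // Rows deserialized from the file contain only the data columns...
  val rowsFromFile = Iterator(Seq("1", "a"), Seq("2", "b"))

  // ...so the reader appends the partition values to every row, which is what
  // projecting the joined row onto dataSchema ++ partitionSchema amounts to.
  val fullRows = rowsFromFile.map(_ ++ partitionValues)
  fullRows.foreach(row => println((dataSchema ++ partitionSchema).zip(row)))
}
```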
--- .../execution/datasources/orc/OrcFileFormat.scala | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 75c42213db3c8..f7471cd7debce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -124,7 +125,7 @@ class OrcFileFormat true } - override def buildReader( + override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ -167,9 +168,17 @@ class OrcFileFormat val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - val unsafeProjection = UnsafeProjection.create(requiredSchema) + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) - iter.map(value => unsafeProjection(deserializer.deserialize(value))) + + if (partitionSchema.length == 0) { + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } else { + val joinedRow = new JoinedRow() + iter.map(value => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + } } } } From fc294463007c184715d604c718c0e913c83969f3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 7 Dec 2017 21:18:27 +0800 Subject: [PATCH 372/712] [SPARK-22699][SQL] GenerateSafeProjection should not use global variables for struct ## What changes were proposed in this pull request? GenerateSafeProjection is defining a mutable state for each struct, which is not needed. This is bad for the well known issues related to constant pool limits. The PR replace the global variable with a local one. ## How was this patch tested? added UT Author: Marco Gaido Closes #19914 from mgaido91/SPARK-22699. 
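[Editor's note] SPARK-22699, like SPARK-22705 earlier and SPARK-22696 just after it, trades a generated-class field for a local variable: every field registered through `ctx.addMutableState` adds entries to the generated class and counts against the JVM's per-class constant pool limit, while a local passed explicitly to split functions does not. Roughly the shape of the generated code before and after this change (simplified sketch, not the exact codegen output):

```scala
object GeneratedShape extends App {
  // Before: the struct buffer is a class-level field, one per struct in the
  // projection, and split functions reach it implicitly.
  val before =
    """private Object[] values;
      |...
      |values = new Object[2];
      |writeFields(i);
      |final InternalRow safeRow = new GenericInternalRow(values);
      |values = null;
      |""".stripMargin

  // After: the buffer is a local, passed explicitly to any split functions,
  // so nothing is added to the generated class itself.
  val after =
    """final Object[] values = new Object[2];
      |writeFields(i, values);
      |final InternalRow safeRow = new GenericInternalRow(values);
      |""".stripMargin

  println(before); println(after)
}
```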
--- .../codegen/GenerateSafeProjection.scala | 18 ++++++++---------- .../codegen/GeneratedProjectionSuite.scala | 11 +++++++++++ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 44e7148e5d98f..3dcbb518ba42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,8 +49,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") - // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -66,15 +64,15 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val allFields = ctx.splitExpressions( expressions = fieldWriters, funcName = "writeFields", - arguments = Seq("InternalRow" -> tmpInput) + arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) - val code = s""" - final InternalRow $tmpInput = $input; - $values = new Object[${schema.length}]; - $allFields - final InternalRow $output = new $rowClass($values); - $values = null; - """ + val code = + s""" + |final InternalRow $tmpInput = $input; + |final Object[] $values = new Object[${schema.length}]; + |$allFields + |final InternalRow $output = new $rowClass($values); + """.stripMargin ExprCode(code, "false", output) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 0cd0d8859145f..6031bdf19e957 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -208,4 +208,15 @@ class GeneratedProjectionSuite extends SparkFunSuite { unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) assert(row.getStruct(0, 1).getString(0).toString == "a") } + + test("SPARK-22699: GenerateSafeProjection should not use global variables for struct") { + val safeProj = GenerateSafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", IntegerType), true))) + val globalVariables = safeProj.getClass.getDeclaredFields + // We need always 3 variables: + // - one is a reference to this + // - one is the references object + // - one is the mutableRow + assert(globalVariables.length == 3) + } } From b79071910e61a851931b8b96d4acb66e0576caa7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 7 Dec 2017 21:24:36 +0800 Subject: [PATCH 373/712] [SPARK-22696][SQL] objects functions should not use unneeded global variables ## What changes were proposed in this pull request? Some objects functions are using global variables which are not needed. This can generate some unneeded entries in the constant pool. The PR replaces the unneeded global variables with local variables. ## How was this patch tested? 
added UTs Author: Marco Gaido Author: Marco Gaido Closes #19908 from mgaido91/SPARK-22696. --- .../expressions/objects/objects.scala | 61 +++++++++++-------- .../expressions/CodeGenerationSuite.scala | 16 ++++- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 730b2ff96da6c..349afece84d5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1106,27 +1106,31 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values) val childrenCodes = children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - } - """ + s""" + |${eval.code} + |if (${eval.isNull}) { + | $values[$i] = null; + |} else { + | $values[$i] = ${eval.value}; + |} + """.stripMargin } - val childrenCode = ctx.splitExpressionsWithCurrentInputs(childrenCodes) - val schemaField = ctx.addReferenceObj("schema", schema) + val childrenCode = ctx.splitExpressionsWithCurrentInputs( + expressions = childrenCodes, + funcName = "createExternalRow", + extraArguments = "Object[]" -> values :: Nil) + val schemaField = ctx.addReferenceMinorObj(schema) - val code = s""" - $values = new Object[${children.size}]; - $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); - """ + val code = + s""" + |Object[] $values = new Object[${children.size}]; + |$childrenCode + |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); + """.stripMargin ev.copy(code = code, isNull = "false") } } @@ -1244,25 +1248,28 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val javaBeanInstance = ctx.freshName("javaBean") val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) - ctx.addMutableState(beanInstanceJavaType, javaBeanInstance) val initialize = setters.map { case (setterMethod, fieldValue) => val fieldGen = fieldValue.genCode(ctx) s""" - ${fieldGen.code} - ${javaBeanInstance}.$setterMethod(${fieldGen.value}); - """ + |${fieldGen.code} + |$javaBeanInstance.$setterMethod(${fieldGen.value}); + """.stripMargin } - val initializeCode = ctx.splitExpressionsWithCurrentInputs(initialize.toSeq) + val initializeCode = ctx.splitExpressionsWithCurrentInputs( + expressions = initialize.toSeq, + funcName = "initializeJavaBean", + extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = s""" - ${instanceGen.code} - ${javaBeanInstance} = ${instanceGen.value}; - if (!${instanceGen.isNull}) { - $initializeCode - } - """ + val code = + s""" + |${instanceGen.code} + |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; + |if (!${instanceGen.isNull}) { + | $initializeCode + |} + """.stripMargin ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index a4198f826cedb..40bf29bb3b573 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -380,4 +380,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") } } + + test("SPARK-22696: CreateExternalRow should not use global variables") { + val ctx = new CodegenContext + val schema = new StructType().add("a", IntegerType).add("b", StringType) + CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } + + test("SPARK-22696: InitializeJavaBean should not use global variables") { + val ctx = new CodegenContext + InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } From f41c0a93fd3913ad93e55ddbfd875229872ecc97 Mon Sep 17 00:00:00 2001 From: kellyzly Date: Thu, 7 Dec 2017 10:04:04 -0600 Subject: [PATCH 374/712] [SPARK-22660][BUILD] Use position() and limit() to fix ambiguity issue in scala-2.12 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …a-2.12 and JDK9 ## What changes were proposed in this pull request? Some compile error after upgrading to scala-2.12 ```javascript spark_source/core/src/main/scala/org/apache/spark/executor/Executor.scala:455: ambiguous reference to overloaded definition, method limit in class ByteBuffer of type (x$1: Int)java.nio.ByteBuffer method limit in class Buffer of type ()Int match expected type ? val resultSize = serializedDirectResult.limit error ``` The limit method was moved from ByteBuffer to the superclass Buffer and it can no longer be called without (). The same reason for position method. ```javascript /home/zly/prj/oss/jdk9_HOS_SOURCE/spark_source/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala:427: ambiguous reference to overloaded definition, [error] both method putAll in class Properties of type (x$1: java.util.Map[_, _])Unit [error] and method putAll in class Hashtable of type (x$1: java.util.Map[_ <: Object, _ <: Object])Unit [error] match argument types (java.util.Map[String,String]) [error] props.putAll(outputSerdeProps.toMap.asJava) [error] ^ ``` This is because the key type is Object instead of String which is unsafe. ## How was this patch tested? running tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: kellyzly Closes #19854 from kellyzly/SPARK-22660. 
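For readers skimming the diff below, here is a minimal self-contained sketch of the two call patterns the changes converge on. It is illustrative only; the object and method names are invented for the example and do not appear in the patch.

```scala
import java.nio.ByteBuffer
import java.util.Properties

object Scala212CallPatterns {
  // Writing position()/limit() with explicit parentheses picks the parameterless
  // accessors; the bare `buf.limit` is ambiguous under Scala 2.12 / JDK 9 because
  // of the Int-argument overloads.
  def remainingBytes(buf: ByteBuffer): Int = buf.limit() - buf.position()

  // Copy entries one by one instead of props.putAll(map.asJava), which no longer
  // resolves cleanly in Scala 2.12 (see https://github.com/scala/bug/issues/10418).
  def toProperties(map: Map[String, String]): Properties = {
    val props = new Properties()
    map.foreach { case (k, v) => props.put(k, v) }
    props
  }
}
```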
--- .../org/apache/spark/broadcast/TorrentBroadcast.scala | 3 ++- .../main/scala/org/apache/spark/executor/Executor.scala | 2 +- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 4 ++-- .../org/apache/spark/util/ByteBufferInputStream.scala | 2 +- .../org/apache/spark/util/io/ChunkedByteBuffer.scala | 4 ++-- .../org/apache/spark/serializer/KryoSerializerSuite.scala | 2 +- .../scala/org/apache/spark/storage/DiskStoreSuite.scala | 2 +- .../org/apache/spark/sql/kafka010/KafkaTestUtils.scala | 4 +++- .../sql/execution/columnar/NullableColumnAccessor.scala | 2 +- .../columnar/compression/compressionSchemes.scala | 2 +- .../sql/hive/execution/ScriptTransformationExec.scala | 8 ++++++-- .../apache/spark/streaming/dstream/RawInputDStream.scala | 2 +- 12 files changed, 22 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 67e993c7f02e2..7aecd3c9668ea 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -99,7 +99,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def calcChecksum(block: ByteBuffer): Int = { val adler = new Adler32() if (block.hasArray) { - adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position) + adler.update(block.array, block.arrayOffset + block.position(), block.limit() + - block.position()) } else { val bytes = new Array[Byte](block.remaining()) block.duplicate.get(bytes) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e3e555eaa0277..af0a0ab656564 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -452,7 +452,7 @@ private[spark] class Executor( // TODO: do not serialize value twice val directResult = new DirectTaskResult(valueBytes, accumUpdates) val serializedDirectResult = ser.serialize(directResult) - val resultSize = serializedDirectResult.limit + val resultSize = serializedDirectResult.limit() // directSend = sending directly back to the driver val serializedResult: ByteBuffer = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index e6982f333d521..4d75063fbf1c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -287,13 +287,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { val serializedTask = TaskDescription.encode(task) - if (serializedTask.limit >= maxRpcMessageSize) { + if (serializedTask.limit() >= maxRpcMessageSize) { scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + "spark.rpc.message.maxSize or using broadcast variables for large values." 
- msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize) + msg = msg.format(task.taskId, task.index, serializedTask.limit(), maxRpcMessageSize) taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index a938cb07724c7..a5ee0ff16b5df 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -54,7 +54,7 @@ class ByteBufferInputStream(private var buffer: ByteBuffer) override def skip(bytes: Long): Long = { if (buffer != null) { val amountToSkip = math.min(bytes, buffer.remaining).toInt - buffer.position(buffer.position + amountToSkip) + buffer.position(buffer.position() + amountToSkip) if (buffer.remaining() == 0) { cleanUp() } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index c28570fb24560..7367af7888bd8 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -65,7 +65,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { for (bytes <- getChunks()) { while (bytes.remaining() > 0) { val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position + ioSize) + bytes.limit(bytes.position() + ioSize) channel.write(bytes) } } @@ -206,7 +206,7 @@ private[spark] class ChunkedByteBufferInputStream( override def skip(bytes: Long): Long = { if (currentChunk != null) { val amountToSkip = math.min(bytes, currentChunk.remaining).toInt - currentChunk.position(currentChunk.position + amountToSkip) + currentChunk.position(currentChunk.position() + amountToSkip) if (currentChunk.remaining() == 0) { if (chunks.hasNext) { currentChunk = chunks.next() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index eaec098b8d785..fc78655bf52ec 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -199,7 +199,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time - assert(ser.serialize(t).limit < 100) + assert(ser.serialize(t).limit() < 100) } check(1 to 1000000) check(1 to 1000000 by 2) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 7258fdf5efc0d..efdd02fff7871 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -118,7 +118,7 @@ class DiskStoreSuite extends SparkFunSuite { val chunks = chunkedByteBuffer.chunks assert(chunks.size === 2) for (chunk <- chunks) { - assert(chunk.limit === 10 * 1024) + assert(chunk.limit() === 10 * 1024) } val e = intercept[IllegalArgumentException]{ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala 
b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 2df8352f48660..08db4d827e400 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -296,7 +296,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") props.put("offsets.topic.num.partitions", "1") - props.putAll(withBrokerProps.asJava) + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + withBrokerProps.foreach { case (k, v) => props.put(k, v) } props } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 2f09757aa341c..341ade1a5c613 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -35,7 +35,7 @@ private[columnar] trait NullableColumnAccessor extends ColumnAccessor { nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 pos = 0 - underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4) + underlyingBuffer.position(underlyingBuffer.position() + 4 + nullCount * 4) super.initialize() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index bf00ad997c76e..79dcf3a6105ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -112,7 +112,7 @@ private[columnar] case object PassThrough extends CompressionScheme { var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity var pos = 0 var seenNulls = 0 - var bufferPos = buffer.position + var bufferPos = buffer.position() while (pos < capacity) { if (pos != nextNullIndex) { val len = nextNullIndex - pos diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index d786a610f1535..3328400b214fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -412,7 +412,9 @@ case class HiveScriptIOSchema ( propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) val properties = new Properties() - properties.putAll(propsMap.asJava) + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + propsMap.foreach { case (k, v) => properties.put(k, v) } serde.initialize(null, properties) serde @@ -424,7 +426,9 @@ case class HiveScriptIOSchema ( recordReaderClass.map { klass => val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] val props = new Properties() - props.putAll(outputSerdeProps.toMap.asJava) + // Can not use 
props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } instance.initialize(inputStream, conf, props) instance } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index b2ec33e82ddaa..b22bbb79a5cc9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -98,7 +98,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) /** Read a buffer fully from a given Channel */ private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) { - while (dest.position < dest.limit) { + while (dest.position() < dest.limit()) { if (channel.read(dest) == -1) { throw new EOFException("End of channel") } From 18b75d465b7563de926c5690094086a72a75c09f Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 7 Dec 2017 10:24:49 -0800 Subject: [PATCH 375/712] [SPARK-22719][SQL] Refactor ConstantPropagation ## What changes were proposed in this pull request? The current time complexity of ConstantPropagation is O(n^2), which can be slow when the query is complex. Refactor the implementation with O( n ) time complexity, and some pruning to avoid traversing the whole `Condition` ## How was this patch tested? Unit test. Also simple benchmark test in ConstantPropagationSuite ``` val condition = (1 to 500).map{_ => Rand(0) === Rand(0)}.reduce(And) val query = testRelation .select(columnA) .where(condition) val start = System.currentTimeMillis() (1 to 40).foreach { _ => Optimize.execute(query.analyze) } val end = System.currentTimeMillis() println(end - start) ``` Run time before changes: 18989ms (474ms per loop) Run time after changes: 1275 ms (32ms per loop) Author: Wang Gengliang Closes #19912 from gengliangwang/ConstantPropagation. --- .../sql/catalyst/optimizer/expressions.scala | 106 ++++++++++++------ 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 785e815b41185..6305b6c84bae3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] { * }}} * * Approach used: - * - Start from AND operator as the root - * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they - * don't have a `NOT` or `OR` operator in them * - Populate a mapping of attribute => constant value by looking at all the equals predicates * - Using this mapping, replace occurrence of the attributes with the corresponding constant values * in the AND node. 
*/ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { - private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find { - case _: Not | _: Or => true - case _ => false - }.isDefined - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f: Filter => f transformExpressionsUp { - case and: And => - val conjunctivePredicates = - splitConjunctivePredicates(and) - .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe]) - .filterNot(expr => containsNonConjunctionPredicates(expr)) - - val equalityPredicates = conjunctivePredicates.collect { - case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e) - case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e) - } + case f: Filter => + val (newCondition, _) = traverse(f.condition, replaceChildren = true) + if (newCondition.isDefined) { + f.copy(condition = newCondition.get) + } else { + f + } + } - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet + type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] - def replaceConstants(expression: Expression) = expression transform { - case a: AttributeReference => - constantsMap.get(a) match { - case Some(literal) => literal - case None => a - } + /** + * Traverse a condition as a tree and replace attributes with constant values. + * - On matching [[And]], recursively traverse each children and get propagated mappings. + * If the current node is not child of another [[And]], replace all occurrences of the + * attributes with the corresponding constant values. + * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping + * of attribute => constant. + * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping. + * - Otherwise, stop traversal and propagate empty mapping. + * @param condition condition to be traversed + * @param replaceChildren whether to replace attributes with constant values in children + * @return A tuple including: + * 1. Option[Expression]: optional changed condition after traversal + * 2. 
EqualityPredicates: propagated mapping of attribute => constant + */ + private def traverse(condition: Expression, replaceChildren: Boolean) + : (Option[Expression], EqualityPredicates) = + condition match { + case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e))) + case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e))) + case e @ EqualNullSafe(left: AttributeReference, right: Literal) => + (None, Seq(((left, right), e))) + case e @ EqualNullSafe(left: Literal, right: AttributeReference) => + (None, Seq(((right, left), e))) + case a: And => + val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false) + val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false) + val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight + val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { + Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), + replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) + } else { + if (newLeft.isDefined || newRight.isDefined) { + Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) + } else { + None + } } - - and transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e) + (newSelf, equalityPredicates) + case o: Or => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = traverse(o.left, replaceChildren = true) + val (newRight, _) = traverse(o.right, replaceChildren = true) + val newSelf = if (newLeft.isDefined || newRight.isDefined) { + Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) + } else { + None } + (newSelf, Seq.empty) + case n: Not => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newChild, _) = traverse(n.child, replaceChildren = true) + (newChild.map(Not), Seq.empty) + case _ => (None, Seq.empty) + } + + private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) + : Expression = { + val constantsMap = AttributeMap(equalityPredicates.map(_._1)) + val predicates = equalityPredicates.map(_._2).toSet + def replaceConstants0(expression: Expression) = expression transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } + condition transform { + case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) + case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) } } } From aa1764ba1addbe7ad79344d5640bf6426267a38c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 7 Dec 2017 15:45:23 -0800 Subject: [PATCH 376/712] [SPARK-22279][SQL] Turn on spark.sql.hive.convertMetastoreOrc by default ## What changes were proposed in this pull request? Like Parquet, this PR aims to turn on `spark.sql.hive.convertMetastoreOrc` by default. ## How was this patch tested? Pass all the existing test cases. Author: Dongjoon Hyun Closes #19499 from dongjoon-hyun/SPARK-22279. 
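As a usage note (not part of this patch): queries that still depend on the Hive serde path for ORC tables should be able to opt back out per session, e.g. from spark-shell:

```scala
// Hypothetical opt-out sketch: restore the previous behaviour by disabling
// the built-in ORC conversion explicitly for the current session.
spark.conf.set("spark.sql.hive.convertMetastoreOrc", "false")
```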
--- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index f5e6720f6a510..c489690af8cd1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -109,7 +109,7 @@ private[spark] object HiveUtils extends Logging { .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + From 0ba8f4b21167fe98a9f06036c4fc4ccd8e3287b4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 8 Dec 2017 09:52:16 +0800 Subject: [PATCH 377/712] [SPARK-21787][SQL] Support for pushing down filters for DateType in native OrcFileFormat ## What changes were proposed in this pull request? This PR support for pushing down filters for DateType in ORC ## How was this patch tested? Pass the Jenkins with newly add and updated test cases. Author: Dongjoon Hyun Closes #18995 from dongjoon-hyun/SPARK-21787. --- .../datasources/orc/OrcFilters.scala | 3 +- .../datasources/orc/OrcFilterSuite.scala | 29 +++++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index cec256cc1b498..4f44ae4fa1d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -82,8 +82,7 @@ private[orc] object OrcFilters { * Both CharType and VarcharType are cleaned at AstBuilder. 
*/ private def isSearchableType(dataType: DataType) = dataType match { - // TODO: SPARK-21787 Support for pushing down filters for DateType in ORC - case BinaryType | DateType => false + case BinaryType => false case _: AtomicType => true case _ => false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index a5f6b68ee862e..8680b86517b19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -316,6 +316,30 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { } } + test("filter pushdown - date") { + val dates = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day => + Date.valueOf(day) + } + withOrcDataFrame(dates.map(Tuple1(_))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === dates(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < dates(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= dates(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(dates(0)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(dates(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(dates(1)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(dates(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(dates(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(dates(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + test("no filter pushdown - non-supported types") { implicit class IntToBinary(int: Int) { def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) @@ -328,11 +352,6 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkNoFilterPredicate('_1 <=> 1.b) } - // DateType - val stringDate = "2015-01-01" - withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df => - checkNoFilterPredicate('_1 === Date.valueOf(stringDate)) - } // MapType withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df => checkNoFilterPredicate('_1.isNotNull) From b11869bc3b10b1ee50b1752cfcf4e736ef110cb1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 7 Dec 2017 22:02:51 -0800 Subject: [PATCH 378/712] [SPARK-22187][SS][REVERT] Revert change in state row format for mapGroupsWithState ## What changes were proposed in this pull request? #19416 changed the format in which rows were encoded in the state store. However, this can break existing streaming queries with the old format in unpredictable ways (potentially crashing the JVM). Hence I am reverting this for now. This will be re-applied in the future after we start saving more metadata in checkpoints to signify which version of state row format the existing streaming query is running. Then we can decode old and new formats accordingly. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #19924 from tdas/SPARK-22187-1. 
--- .../FlatMapGroupsWithStateExec.scala | 135 +++++++++++++--- .../FlatMapGroupsWithState_StateManager.scala | 153 ------------------ .../FlatMapGroupsWithStateSuite.scala | 130 +++++++-------- 3 files changed, 171 insertions(+), 247 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 29f38fab3f896..80769d728b8f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -60,8 +62,27 @@ case class FlatMapGroupsWithStateExec( import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout - val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) - val watermarkPresent = child.output.exists { + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val encoderSerializer = stateEncoder.namedExpressions + if (isTimeoutEnabled) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. + private val stateDeserializer = stateEncoder.resolveAndBind().deserializer + + private val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } @@ -92,11 +113,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateManager.stateSchema, + stateAttributes.toStructType, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val processor = new InputProcessor(store) + val updater = new StateStoreUpdater(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -111,7 +132,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. 
This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - processor.processNewData(filteredIter) ++ processor.processTimedOutState() + updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -126,7 +147,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class InputProcessor(store: StateStore) { + class StateStoreUpdater(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -135,6 +156,14 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + // Converters for translating state between rows and Java objects + private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateAttributes) + private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Index of the additional metadata fields in the state row + private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) + // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -143,19 +172,20 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - stateManager.getState(store, keyUnsafeRow), + keyUnsafeRow, valueRowIter, + store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def processTimedOutState(): Iterator[InternalRow] = { + def updateStateForTimedOutKeys(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -164,11 +194,12 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = stateManager.getAllState(store).filter { state => - state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) + timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { stateData => - callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) + timingOutKeys.flatMap { rowPair => + callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty } @@ -178,19 +209,22 @@ case class FlatMapGroupsWithStateExec( * iterator. 
Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. * - * @param stateData All the data related to the state to be updated + * @param keyRow Row representing the key, cannot be null * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty + * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - stateData: FlatMapGroupsWithState_StateData, + keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow], + prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(stateData.keyRow) // convert key to objects + val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val groupState = GroupStateImpl.createForStreaming( - Option(stateData.stateObj), + val stateObj = getStateObj(prevStateRow) + val keyedState = GroupStateImpl.createForStreaming( + Option(stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -198,24 +232,50 @@ case class FlatMapGroupsWithStateExec( watermarkPresent) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { - stateManager.removeState(store, stateData.keyRow) + + val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp + // If the state has not yet been set but timeout has been set, then + // we have to generate a row to save the timeout. However, attempting serialize + // null using case class encoder throws - + // java.lang.NullPointerException: Null value appeared in non-nullable field: + // If the schema is inferred from a Scala tuple / case class, or a Java bean, please + // try to use scala.Option[_] or other nullable types. 
+ if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { + throw new IllegalStateException( + "Cannot set timeout when state is not defined, that is, state has not been" + + "initialized or has been removed") + } + + if (keyedState.hasRemoved) { + store.remove(keyRow) numUpdatedStateRows += 1 + } else { - val currentTimeoutTimestamp = groupState.getTimeoutTimestamp - val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp - val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged + val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) + val stateRowToWrite = if (keyedState.hasUpdated) { + getStateRow(keyedState.get) + } else { + prevStateRow + } + + val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp + val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged if (shouldWriteState) { - val updatedStateObj = if (groupState.exists) groupState.get else null - stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) + if (stateRowToWrite == null) { + // This should never happen because checks in GroupStateImpl should avoid cases + // where empty state would need to be written + throw new IllegalStateException("Attempting to write empty state") + } + setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) + store.put(keyRow, stateRowToWrite) numUpdatedStateRows += 1 } } @@ -224,5 +284,28 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } + + /** Returns the state as Java object if defined */ + def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow != null) getStateObjFromRow(stateRow) else null + } + + /** Returns the row for an updated state */ + def getStateRow(obj: Any): UnsafeRow = { + assert(obj != null) + getStateRowFromObj(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { + if (isTimeoutEnabled && stateRow != null) { + stateRow.getLong(timeoutTimestampIndex) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala deleted file mode 100644 index e49546830286b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.state - -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow} -import org.apache.spark.sql.execution.ObjectOperator -import org.apache.spark.sql.execution.streaming.GroupStateImpl -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP -import org.apache.spark.sql.types.{IntegerType, LongType, StructType} - - -/** - * Class to serialize/write/read/deserialize state for - * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]]. - */ -class FlatMapGroupsWithState_StateManager( - stateEncoder: ExpressionEncoder[Any], - shouldStoreTimestamp: Boolean) extends Serializable { - - /** Schema of the state rows saved in the state store */ - val stateSchema = { - val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) - if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema - } - - /** Get deserialized state and corresponding timeout timestamp for a key */ - def getState(store: StateStore, keyRow: UnsafeRow): FlatMapGroupsWithState_StateData = { - val stateRow = store.get(keyRow) - stateDataForGets.withNew( - keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) - } - - /** Put state and timeout timestamp for a key */ - def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { - val stateRow = getStateRow(state) - setTimestamp(stateRow, timestamp) - store.put(keyRow, stateRow) - } - - /** Removed all information related to a key */ - def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { - store.remove(keyRow) - } - - /** Get all the keys and corresponding state rows in the state store */ - def getAllState(store: StateStore): Iterator[FlatMapGroupsWithState_StateData] = { - val stateDataForGetAllState = FlatMapGroupsWithState_StateData() - store.getRange(None, None).map { pair => - stateDataForGetAllState.withNew( - pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) - } - } - - // Ordinals of the information stored in the state row - private lazy val nestedStateOrdinal = 0 - private lazy val timeoutTimestampOrdinal = 1 - - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val nestedStateExpr = CreateNamedStruct( - stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) - if (shouldStoreTimestamp) { - Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) - } else { - Seq(nestedStateExpr) - } - } - - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. 
- private val stateDeserializer = { - val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true) - val deser = stateEncoder.resolveAndBind().deserializer.transformUp { - case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) - } - CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser) - } - - // Converters for translating state between rows and Java objects - private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateSchema.toAttributes) - private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Reusable instance for returning state information - private lazy val stateDataForGets = FlatMapGroupsWithState_StateData() - - /** Returns the state as Java object if defined */ - private def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow == null) null - else getStateObjFromRow(stateRow) - } - - /** Returns the row for an updated state */ - private def getStateRow(obj: Any): UnsafeRow = { - val row = getStateRowFromObj(obj) - if (obj == null) { - row.setNullAt(nestedStateOrdinal) - } - row - } - - /** Returns the timeout timestamp of a state row is set */ - private def getTimestamp(stateRow: UnsafeRow): Long = { - if (shouldStoreTimestamp && stateRow != null) { - stateRow.getLong(timeoutTimestampOrdinal) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) - } -} - -/** - * Class to capture deserialized state and timestamp return by the state manager. - * This is intended for reuse. - */ -case class FlatMapGroupsWithState_StateData( - var keyRow: UnsafeRow = null, - var stateRow: UnsafeRow = null, - var stateObj: Any = null, - var timeoutTimestamp: Long = -1) { - def withNew( - newKeyRow: UnsafeRow, - newStateRow: UnsafeRow, - newStateObj: Any, - newTimeout: Long): this.type = { - keyRow = newKeyRow - stateRow = newStateRow - stateObj = newStateObj - timeoutTimestamp = newTimeout - this - } -} - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b906393a379ae..de2b51678cea6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -360,13 +360,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } } - // Values used for testing InputProcessor + // Values used for testing StateStoreUpdater val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for InputProcessor.processNewData() when timeout = NoTimeout + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -397,7 +397,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // should be removed } - // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != 
NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -444,18 +444,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // state should be removed } - // Tests with ProcessingTimeTimeout - if (priorState == None) { - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = None, - expectedTimeoutTimestamp = currentBatchTimestamp + 5000) - } - testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -466,36 +454,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - timeout updated after state removed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = priorState, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = None, - expectedTimeoutTimestamp = currentBatchTimestamp + 5000) - - // Tests with EventTimeTimeout - - if (priorState == None) { - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { - state.setTimeoutTimestamp(10000) - }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = None, - expectedTimeoutTimestamp = 10000) - } - testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: GroupState[Int]) => { - state.update(5); state.setTimeoutTimestamp(5000) - }, + (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -514,23 +476,50 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { - state.remove(); state.setTimeoutTimestamp(10000) - }, - timeoutConf = EventTimeTimeout, - priorState = priorState, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = None, - expectedTimeoutTimestamp = 10000) + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update } } - // Tests for InputProcessor.processTimedOutState() + // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), + // Try to remove these cases in the future + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + val testName = + if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + 
timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + } + + // Tests for StateStoreUpdater.updateStateForTimedOutKeys() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -1034,7 +1023,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"InputProcessor - process new data - $testName") { + test(s"StateStoreUpdater - updates with data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -1056,7 +1045,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"InputProcessor - process timed out state - $testName") { + test(s"StateStoreUpdater - updates for timeout - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -1083,20 +1072,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) - val stateManager = mapGroupsSparkPlan.stateManager + val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { - stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) + if (priorState.nonEmpty) { + val row = updater.getStateRow(priorState.get) + updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) + store.put(key.copy(), row.copy()) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - inputProcessor.processTimedOutState() + updater.updateStateForTimedOutKeys() } else { - inputProcessor.processNewData(Iterator(key)) + updater.updateStateForKeysWithData(Iterator(key)) } returnedIter.size // 
consume the iterator to force state updates } @@ -1107,11 +1097,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } else { // Call function to update and verify updated state in store callFunction() - val updatedState = stateManager.getState(store, key) - assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, + val updatedStateRow = store.get(key) + assert( + Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, "final state not as expected") - assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") + if (updatedStateRow != null) { + assert( + updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") + } } } From f88a67bf08a66bd0291c11776a0c7c3928738ed9 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Fri, 8 Dec 2017 14:48:19 +0800 Subject: [PATCH 379/712] [SPARK-22452][SQL] Add getDouble to DataSourceV2Options - Implemented getDouble method in DataSourceV2Options - Add unit test Author: Sunitha Kambhampati Closes #19921 from skambha/ds2. --- .../spark/sql/sources/v2/DataSourceV2Options.java | 9 +++++++++ .../sql/sources/v2/DataSourceV2OptionsSuite.scala | 11 +++++++++++ 2 files changed, 20 insertions(+) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java index b2c908dc73a61..e98c04517c3db 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java @@ -80,4 +80,13 @@ public long getLong(String key, long defaultValue) { Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; } + /** + * Returns the double value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive + */ + public double getDouble(String key, double defaultValue) { + String lcaseKey = toLowerCase(key); + return keyLowerCasedMap.containsKey(lcaseKey) ? + Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala index 752d3c193cc74..90d92864b26fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala @@ -68,4 +68,15 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { options.getLong("foo", 0L) } } + + test("getDouble") { + val options = new DataSourceV2Options(Map("numFoo" -> "922337.1", + "foo" -> "bar").asJava) + assert(options.getDouble("numFOO", 0d) == 922337.1d) + assert(options.getDouble("numFoo2", -1.02d) == -1.02d) + + intercept[NumberFormatException]{ + options.getDouble("foo", 0.1d) + } + } } From f28b1a4c41d41e069b882bac06a68b2c45eb156c Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 8 Dec 2017 12:19:45 +0100 Subject: [PATCH 380/712] [SPARK-22721] BytesToBytesMap peak memory not updated. ## What changes were proposed in this pull request? Follow-up to earlier commit. The peak memory of BytesToBytesMap is not updated in more places - spill() and destructiveIterator(). ## How was this patch tested? Manually. 
Author: Juliusz Sompolski Closes #19923 from juliuszsompolski/SPARK-22721cd. --- .../main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 7fdcf22c45f73..5f0045507aaab 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -347,6 +347,8 @@ public long spill(long numBytes) throws IOException { return 0L; } + updatePeakMemoryUsed(); + // TODO: use existing ShuffleWriteMetrics ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); @@ -424,6 +426,7 @@ public MapIterator iterator() { * `lookup()`, the behavior of the returned iterator is undefined. */ public MapIterator destructiveIterator() { + updatePeakMemoryUsed(); return new MapIterator(numValues, loc, true); } From 26e66453decf40ed6d590498aadbbf442bb90622 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 8 Dec 2017 20:44:21 +0900 Subject: [PATCH 381/712] [SPARK-22655][PYSPARK] Throw exception rather than exit silently in PythonRunner when Spark session is stopped ## What changes were proposed in this pull request? During Spark shutdown, if there are some active tasks, sometimes they will complete with incorrect results. The issue is in PythonRunner where it is returning partial result instead of throwing exception during Spark shutdown. This patch makes it so that these tasks fail instead of complete with partial results. ## How was this patch tested? Existing tests. Author: Li Jin Closes #19852 from icexelloss/python-runner-shutdown. --- .../main/scala/org/apache/spark/api/python/PythonRunner.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f524de68fbce0..93d508c28ebba 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -317,10 +317,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( logDebug("Exception thrown after task interruption", e) throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null.asInstanceOf[OUT] // exit silently - case e: Exception if writerThread.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) logError("This may have been caused by a prior exception:", writerThread.exception.get) From e4639fa68f479ccd7c365d3cf2bc2e93e24f1a00 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Fri, 8 Dec 2017 14:17:50 -0800 Subject: [PATCH 382/712] =?UTF-8?q?[SPARK-21672][CORE]=20Remove=20SHS-spec?= =?UTF-8?q?ific=20application=20/=20attempt=20data=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …structures ## What changes were proposed in this pull request? In general, the SHS pages now use the public API types to represent applications. Some internal code paths still used its own view of what applications and attempts look like (`ApplicationHistoryInfo` and `ApplicationAttemptInfo`), declared in ApplicationHistoryProvider.scala. 
This pull request removes these classes and updates the rest of the code to use `status.api.v1.ApplicationInfo` and `status.api.v1.ApplicationAttemptInfo` instead. Furthermore `status.api.v1.ApplicationInfo` and `status.api.v1.ApplicationAttemptInfo` were changed to case class to - facilitate copying instances - equality checking in test code - nicer toString() To simplify the code a bit `v1.` prefixes were also removed from occurrences of v1.ApplicationInfo and v1.ApplicationAttemptInfo as there is no more ambiguity between classes in history and status.api.v1. ## How was this patch tested? By running existing automated tests. Author: Sandor Murakozi Closes #19920 from smurakozi/SPARK-21672. --- .../history/ApplicationHistoryProvider.scala | 30 +++------------- .../deploy/history/FsHistoryProvider.scala | 32 ++++++----------- .../spark/deploy/history/HistoryPage.scala | 8 ++++- .../spark/deploy/history/HistoryServer.scala | 8 ++--- .../spark/status/AppStatusListener.scala | 8 ++--- .../api/v1/ApplicationListResource.scala | 32 ----------------- .../org/apache/spark/status/api/v1/api.scala | 34 +++++++++---------- .../history/FsHistoryProviderSuite.scala | 13 ++++--- .../deploy/history/HistoryServerSuite.scala | 9 +++-- 9 files changed, 63 insertions(+), 111 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 38f0d6f2afa5e..f1c06205bf04c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -23,31 +23,9 @@ import java.util.zip.ZipOutputStream import scala.xml.Node import org.apache.spark.SparkException +import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.ui.SparkUI -private[spark] case class ApplicationAttemptInfo( - attemptId: Option[String], - startTime: Long, - endTime: Long, - lastUpdated: Long, - sparkUser: String, - completed: Boolean = false, - appSparkVersion: String) - -private[spark] case class ApplicationHistoryInfo( - id: String, - name: String, - attempts: List[ApplicationAttemptInfo]) { - - /** - * Has this application completed? - * @return true if the most recent attempt has completed - */ - def completed: Boolean = { - attempts.nonEmpty && attempts.head.completed - } -} - /** * A loaded UI for a Spark application. * @@ -119,7 +97,7 @@ private[history] abstract class ApplicationHistoryProvider { * * @return List of all know applications. */ - def getListing(): Iterator[ApplicationHistoryInfo] + def getListing(): Iterator[ApplicationInfo] /** * Returns the Spark UI for a specific application. @@ -152,9 +130,9 @@ private[history] abstract class ApplicationHistoryProvider { def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit /** - * @return the [[ApplicationHistoryInfo]] for the appId if it exists. + * @return the [[ApplicationInfo]] for the appId if it exists. 
*/ - def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] + def getApplicationInfo(appId: String): Option[ApplicationInfo] /** * @return html text to display when the application list is empty diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6a83c106f6d84..a8e1becc56ab7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -43,7 +43,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.status._ import org.apache.spark.status.KVUtils._ -import org.apache.spark.status.api.v1 +import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} import org.apache.spark.util.kvstore._ @@ -252,19 +252,19 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - override def getListing(): Iterator[ApplicationHistoryInfo] = { + override def getListing(): Iterator[ApplicationInfo] = { // Return the listing in end time descending order. listing.view(classOf[ApplicationInfoWrapper]) .index("endTime") .reverse() .iterator() .asScala - .map(_.toAppHistoryInfo()) + .map(_.toApplicationInfo()) } - override def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] = { + override def getApplicationInfo(appId: String): Option[ApplicationInfo] = { try { - Some(load(appId).toAppHistoryInfo()) + Some(load(appId).toApplicationInfo()) } catch { case e: NoSuchElementException => None @@ -795,24 +795,16 @@ private[history] case class LogInfo( fileSize: Long) private[history] class AttemptInfoWrapper( - val info: v1.ApplicationAttemptInfo, + val info: ApplicationAttemptInfo, val logPath: String, val fileSize: Long, val adminAcls: Option[String], val viewAcls: Option[String], val adminAclsGroups: Option[String], - val viewAclsGroups: Option[String]) { - - def toAppAttemptInfo(): ApplicationAttemptInfo = { - ApplicationAttemptInfo(info.attemptId, info.startTime.getTime(), - info.endTime.getTime(), info.lastUpdated.getTime(), info.sparkUser, - info.completed, info.appSparkVersion) - } - -} + val viewAclsGroups: Option[String]) private[history] class ApplicationInfoWrapper( - val info: v1.ApplicationInfo, + val info: ApplicationInfo, val attempts: List[AttemptInfoWrapper]) { @JsonIgnore @KVIndexParam @@ -824,9 +816,7 @@ private[history] class ApplicationInfoWrapper( @JsonIgnore @KVIndexParam("oldestAttempt") def oldestAttempt(): Long = attempts.map(_.info.lastUpdated.getTime()).min - def toAppHistoryInfo(): ApplicationHistoryInfo = { - ApplicationHistoryInfo(info.id, info.name, attempts.map(_.toAppAttemptInfo())) - } + def toApplicationInfo(): ApplicationInfo = info.copy(attempts = attempts.map(_.info)) } @@ -883,7 +873,7 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends var memoryPerExecutorMB: Option[Int] = None def toView(): ApplicationInfoWrapper = { - val apiInfo = new v1.ApplicationInfo(id, name, coresGranted, maxCores, coresPerExecutor, + val apiInfo = ApplicationInfo(id, name, coresGranted, maxCores, coresPerExecutor, memoryPerExecutorMB, Nil) new ApplicationInfoWrapper(apiInfo, List(attempt.toView())) } @@ -906,7 +896,7 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends var 
viewAclsGroups: Option[String] = None def toView(): AttemptInfoWrapper = { - val apiInfo = new v1.ApplicationAttemptInfo( + val apiInfo = ApplicationAttemptInfo( attemptId, startTime, endTime, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6399dccc1676a..d3dd58996a0bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -21,6 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node +import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { @@ -30,7 +31,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val requestedIncomplete = Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean - val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val allAppsSize = parent.getApplicationList() + .count(isApplicationCompleted(_) != requestedIncomplete) val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() @@ -88,4 +90,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") private def makePageLink(showIncomplete: Boolean): String = { UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) } + + private def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { + appInfo.attempts.nonEmpty && appInfo.attempts.head.completed + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index b822a48e98e91..75484f5c9f30f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.{ShutdownHookManager, SystemClock, Utils} @@ -175,7 +175,7 @@ class HistoryServer( * * @return List of all known applications. 
*/ - def getApplicationList(): Iterator[ApplicationHistoryInfo] = { + def getApplicationList(): Iterator[ApplicationInfo] = { provider.getListing() } @@ -188,11 +188,11 @@ class HistoryServer( } def getApplicationInfoList: Iterator[ApplicationInfo] = { - getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + getApplicationList() } def getApplicationInfo(appId: String): Option[ApplicationInfo] = { - provider.getApplicationInfo(appId).map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + provider.getApplicationInfo(appId) } override def writeEventLogs( diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 9c23d9d8c923a..6da44cbc44c4d 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -77,7 +77,7 @@ private[spark] class AppStatusListener( override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { assert(event.appId.isDefined, "Application without IDs are not supported.") - val attempt = new v1.ApplicationAttemptInfo( + val attempt = v1.ApplicationAttemptInfo( event.appAttemptId, new Date(event.time), new Date(-1), @@ -87,7 +87,7 @@ private[spark] class AppStatusListener( false, sparkVersion) - appInfo = new v1.ApplicationInfo( + appInfo = v1.ApplicationInfo( event.appId.get, event.appName, None, @@ -122,7 +122,7 @@ private[spark] class AppStatusListener( override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { val old = appInfo.attempts.head - val attempt = new v1.ApplicationAttemptInfo( + val attempt = v1.ApplicationAttemptInfo( old.attemptId, old.startTime, new Date(event.time), @@ -132,7 +132,7 @@ private[spark] class AppStatusListener( true, old.appSparkVersion) - appInfo = new v1.ApplicationInfo( + appInfo = v1.ApplicationInfo( appInfo.id, appInfo.name, None, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 91660a524ca93..69054f2b771f1 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -20,8 +20,6 @@ import java.util.{Date, List => JList} import javax.ws.rs.{DefaultValue, GET, Produces, QueryParam} import javax.ws.rs.core.MediaType -import org.apache.spark.deploy.history.ApplicationHistoryInfo - @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class ApplicationListResource extends ApiRequestContext { @@ -67,33 +65,3 @@ private[v1] class ApplicationListResource extends ApiRequestContext { startTimeOk && endTimeOk } } - -private[spark] object ApplicationsListResource { - def appHistoryInfoToPublicAppInfo(app: ApplicationHistoryInfo): ApplicationInfo = { - new ApplicationInfo( - id = app.id, - name = app.name, - coresGranted = None, - maxCores = None, - coresPerExecutor = None, - memoryPerExecutorMB = None, - attempts = app.attempts.map { internalAttemptInfo => - new ApplicationAttemptInfo( - attemptId = internalAttemptInfo.attemptId, - startTime = new Date(internalAttemptInfo.startTime), - endTime = new Date(internalAttemptInfo.endTime), - duration = - if (internalAttemptInfo.endTime > 0) { - internalAttemptInfo.endTime - internalAttemptInfo.startTime - } else { - 0 - }, - lastUpdated = new Date(internalAttemptInfo.lastUpdated), - 
sparkUser = internalAttemptInfo.sparkUser, - completed = internalAttemptInfo.completed, - appSparkVersion = internalAttemptInfo.appSparkVersion - ) - } - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 14280099f6422..45eaf935fb083 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -24,27 +24,27 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus -class ApplicationInfo private[spark]( - val id: String, - val name: String, - val coresGranted: Option[Int], - val maxCores: Option[Int], - val coresPerExecutor: Option[Int], - val memoryPerExecutorMB: Option[Int], - val attempts: Seq[ApplicationAttemptInfo]) +case class ApplicationInfo private[spark]( + id: String, + name: String, + coresGranted: Option[Int], + maxCores: Option[Int], + coresPerExecutor: Option[Int], + memoryPerExecutorMB: Option[Int], + attempts: Seq[ApplicationAttemptInfo]) @JsonIgnoreProperties( value = Array("startTimeEpoch", "endTimeEpoch", "lastUpdatedEpoch"), allowGetters = true) -class ApplicationAttemptInfo private[spark]( - val attemptId: Option[String], - val startTime: Date, - val endTime: Date, - val lastUpdated: Date, - val duration: Long, - val sparkUser: String, - val completed: Boolean = false, - val appSparkVersion: String) { +case class ApplicationAttemptInfo private[spark]( + attemptId: Option[String], + startTime: Date, + endTime: Date, + lastUpdated: Date, + duration: Long, + sparkUser: String, + completed: Boolean = false, + appSparkVersion: String) { def getStartTimeEpoch: Long = startTime.getTime diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 86c8cdf43258c..84ee01c7f5aaf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.history import java.io._ import java.nio.charset.StandardCharsets +import java.util.Date import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} @@ -42,6 +43,7 @@ import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.security.GroupMappingServiceProvider import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo} import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -114,9 +116,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc end: Long, lastMod: Long, user: String, - completed: Boolean): ApplicationHistoryInfo = { - ApplicationHistoryInfo(id, name, - List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed, ""))) + completed: Boolean): ApplicationInfo = { + + val duration = if (end > 0) end - start else 0 + new ApplicationInfo(id, name, None, None, None, None, + List(ApplicationAttemptInfo(None, new Date(start), + new Date(end), new Date(lastMod), duration, user, completed, ""))) } // For completed files, lastUpdated would be lastModified time. 
@@ -667,7 +672,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc * } */ private def updateAndCheck(provider: FsHistoryProvider) - (checkFn: Seq[ApplicationHistoryInfo] => Unit): Unit = { + (checkFn: Seq[ApplicationInfo] => Unit): Unit = { provider.checkForLogs() provider.cleanLogs() checkFn(provider.getListing().toSeq) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index d22a19e8af74a..4e4395d0fb959 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -45,6 +45,7 @@ import org.scalatest.selenium.WebBrowser import org.apache.spark._ import org.apache.spark.deploy.history.config._ +import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -505,6 +506,10 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getAppUI.store.jobsList(List(JobExecutionStatus.RUNNING).asJava) } + def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { + appInfo.attempts.nonEmpty && appInfo.attempts.head.completed + } + activeJobs() should have size 0 completedJobs() should have size 1 getNumJobs("") should be (1) @@ -537,7 +542,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(4 === getNumJobsRestful(), s"two jobs back-to-back not updated, server=$server\n") } val jobcount = getNumJobs("/jobs") - assert(!provider.getListing().next.completed) + assert(!isApplicationCompleted(provider.getListing().next)) listApplications(false) should contain(appId) @@ -545,7 +550,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers resetSparkContext() // check the app is now found as completed eventually(stdTimeout, stdInterval) { - assert(provider.getListing().next.completed, + assert(isApplicationCompleted(provider.getListing().next), s"application never completed, server=$server\n") } From acf7ef3154e094875fa89f30a78ab111b267db91 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sat, 9 Dec 2017 11:53:15 +0900 Subject: [PATCH 383/712] [SPARK-12297][SQL] Adjust timezone for int96 data from impala ## What changes were proposed in this pull request? Int96 data written by impala vs data written by hive & spark is stored slightly differently -- they use a different offset for the timezone. This adds an option "spark.sql.parquet.int96TimestampConversion" (false by default) to adjust timestamps if and only if the writer is impala (or more precisely, if the parquet file's "createdBy" metadata does not start with "parquet-mr"). This matches the existing behavior in hive from HIVE-9482. ## How was this patch tested? Unit test added, existing tests run via jenkins. Author: Imran Rashid Author: Henry Robinson Closes #19769 from squito/SPARK-12297_skip_conversion. 
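For illustration only (a sketch, not part of this patch): reading a table that mixes Impala-written and Spark-written Parquet files with the new option enabled might look like the following. The table path and session setup are assumptions; only the config key and its semantics come from this change.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("int96-conversion-sketch").getOrCreate()

// Off by default. When enabled, INT96 timestamps in files whose "createdBy"
// metadata does not start with "parquet-mr" (i.e. Impala-written files) are
// adjusted on read; files written by Spark or Hive are left untouched.
spark.conf.set("spark.sql.parquet.int96TimestampConversion", "true")

// Hypothetical path to a table that contains Impala-written files.
val df = spark.read.parquet("/warehouse/events_mixed_writers")
df.select("event_ts").show(truncate = false)
```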
--- .../sql/catalyst/util/DateTimeUtils.scala | 1 + .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../parquet/VectorizedColumnReader.java | 59 +++++++-- .../VectorizedParquetRecordReader.java | 18 ++- .../parquet/ParquetFileFormat.scala | 30 ++++- .../parquet/ParquetReadSupport.scala | 15 ++- .../parquet/ParquetRecordMaterializer.scala | 7 +- .../parquet/ParquetRowConverter.scala | 11 +- .../resources/test-data/impala_timestamp.parq | Bin 0 -> 374 bytes .../ParquetInteroperabilitySuite.scala | 114 ++++++++++++++++++ 10 files changed, 237 insertions(+), 27 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/impala_timestamp.parq diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 746c3e8950f7b..b1ed25645b36c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -61,6 +61,7 @@ object DateTimeUtils { final val YearZero = -17999 final val toYearZero = to2001 + 7304850 final val TimeZoneGMT = TimeZone.getTimeZone("GMT") + final val TimeZoneUTC = TimeZone.getTimeZone("UTC") final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12) val TIMEZONE_OPTION = "timeZone" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ce9cc562b220f..1121444cc938a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -291,6 +291,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_INT96_TIMESTAMP_CONVERSION = buildConf("spark.sql.parquet.int96TimestampConversion") + .doc("This controls whether timestamp adjustments should be applied to INT96 data when " + + "converting to timestamps, for data written by Impala. 
This is necessary because Impala " + + "stores INT96 data with a different timezone offset than Hive & Spark.") + .booleanConf + .createWithDefault(false) + object ParquetOutputTimestampType extends Enumeration { val INT96, TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value } @@ -1206,6 +1213,8 @@ class SQLConf extends Serializable with Logging { def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + def isParquetINT96TimestampConversion: Boolean = getConf(PARQUET_INT96_TIMESTAMP_CONVERSION) + def isParquetINT64AsTimestampMillis: Boolean = getConf(PARQUET_INT64_AS_TIMESTAMP_MILLIS) def parquetOutputTimestampType: ParquetOutputTimestampType.Value = { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 71ca8b1b96a98..b7646969bcf3d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.TimeZone; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; @@ -93,13 +94,18 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; private final OriginalType originalType; + // The timezone conversion to apply to int96 timestamps. Null if no conversion. + private final TimeZone convertTz; + private final static TimeZone UTC = DateTimeUtils.TimeZoneUTC(); public VectorizedColumnReader( ColumnDescriptor descriptor, OriginalType originalType, - PageReader pageReader) throws IOException { + PageReader pageReader, + TimeZone convertTz) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; + this.convertTz = convertTz; this.originalType = originalType; this.maxDefLevel = descriptor.getMaxDefinitionLevel(); @@ -222,6 +228,10 @@ void readBatch(int total, WritableColumnVector column) throws IOException { } } + private boolean shouldConvertTimestamps() { + return convertTz != null && !convertTz.equals(UTC); + } + /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. 
*/ @@ -294,11 +304,21 @@ private void decodeDictionaryIds( break; case INT96: if (column.dataType() == DataTypes.TimestampType) { - for (int i = rowId; i < rowId + num; ++i) { - // TODO: Convert dictionary of Binaries to dictionary of Longs - if (!column.isNullAt(i)) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); - column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + if (!shouldConvertTimestamps()) { + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); + column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + } + } + } else { + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); + long rawTime = ParquetRowConverter.binaryToSQLTimestamp(v); + long adjTime = DateTimeUtils.convertTz(rawTime, convertTz, UTC); + column.putLong(i, adjTime); + } } } } else { @@ -428,13 +448,26 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { - for (int i = 0; i < num; i++) { - if (defColumn.readInteger() == maxDefLevel) { - column.putLong(rowId + i, - // Read 12 bytes for INT96 - ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12))); - } else { - column.putNull(rowId + i); + if (!shouldConvertTimestamps()) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + // Read 12 bytes for INT96 + long rawTime = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12)); + column.putLong(rowId + i, rawTime); + } else { + column.putNull(rowId + i); + } + } + } else { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + // Read 12 bytes for INT96 + long rawTime = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12)); + long adjTime = DateTimeUtils.convertTz(rawTime, convertTz, UTC); + column.putLong(rowId + i, adjTime); + } else { + column.putNull(rowId + i); + } } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 669d71e60779d..14f2a58d638f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.TimeZone; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; @@ -77,6 +78,12 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private boolean[] missingColumns; + /** + * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to workaround + * incompatibilities between different engines when writing timestamp values. + */ + private TimeZone convertTz = null; + /** * columnBatch object that is used for batch decoding. This is created on first use and triggers * batched decoding. 
It is not valid to interleave calls to the batched interface with the row @@ -105,10 +112,15 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private final MemoryMode MEMORY_MODE; - public VectorizedParquetRecordReader(boolean useOffHeap) { + public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap) { + this.convertTz = convertTz; MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; } + public VectorizedParquetRecordReader(boolean useOffHeap) { + this(null, useOffHeap); + } + /** * Implementation of RecordReader API. */ @@ -291,8 +303,8 @@ private void checkEndOfRowGroup() throws IOException { columnReaders = new VectorizedColumnReader[columns.size()]; for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; - columnReaders[i] = new VectorizedColumnReader( - columns.get(i), types.get(i).getOriginalType(), pages.getPageReader(columns.get(i))); + columnReaders[i] = new VectorizedColumnReader(columns.get(i), types.get(i).getOriginalType(), + pages.getPageReader(columns.get(i)), convertTz); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2b1064955a777..45bedf70f975c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf @@ -307,6 +308,9 @@ class ParquetFileFormat hadoopConf.set( ParquetWriteSupport.SPARK_ROW_SCHEMA, requiredSchema.json) + hadoopConf.set( + SQLConf.SESSION_LOCAL_TIMEZONE.key, + sparkSession.sessionState.conf.sessionLocalTimeZone) ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) @@ -345,6 +349,8 @@ class ParquetFileFormat resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) val enableRecordFilter: Boolean = sparkSession.sessionState.conf.parquetRecordFilterEnabled + val timestampConversion: Boolean = + sparkSession.sessionState.conf.isParquetINT96TimestampConversion // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -363,6 +369,22 @@ class ParquetFileFormat fileSplit.getLocations, null) + val sharedConf = broadcastedHadoopConf.value.value + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' + // *only* if the file was created by something other than "parquet-mr", so check the actual + // writer here for this file. We have to do this per-file, as each file in the table may + // have different writers. 
+ def isCreatedByParquetMr(): Boolean = { + val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS) + footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr") + } + val convertTz = + if (timestampConversion && !isCreatedByParquetMr()) { + Some(DateTimeUtils.getTimeZone(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId) @@ -374,8 +396,8 @@ class ParquetFileFormat } val taskContext = Option(TaskContext.get()) val parquetReader = if (enableVectorizedReader) { - val vectorizedReader = - new VectorizedParquetRecordReader(enableOffHeapColumnVector && taskContext.isDefined) + val vectorizedReader = new VectorizedParquetRecordReader( + convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) @@ -388,9 +410,9 @@ class ParquetFileFormat // ParquetRecordReader returns UnsafeRow val reader = if (pushed.isDefined && enableRecordFilter) { val parquetFilter = FilterCompat.get(pushed.get, null) - new ParquetRecordReader[UnsafeRow](new ParquetReadSupport, parquetFilter) + new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz), parquetFilter) } else { - new ParquetRecordReader[UnsafeRow](new ParquetReadSupport) + new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz)) } reader.initialize(split, hadoopAttemptContext) reader diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 2854cb1bc0c25..40ce5d5e0564e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.{Map => JMap} +import java.util.{Map => JMap, TimeZone} import scala.collection.JavaConverters._ @@ -48,9 +48,17 @@ import org.apache.spark.sql.types._ * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]] * to [[prepareForRead()]], but use a private `var` for simplicity. */ -private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Logging { +private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) + extends ReadSupport[UnsafeRow] with Logging { private var catalystRequestedSchema: StructType = _ + def this() { + // We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only + // used in the vectorized reader, where we get the convertTz value directly, and the value here + // is ignored. + this(None) + } + /** * Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record * readers. Responsible for figuring out Parquet requested schema used for column pruning. 
@@ -95,7 +103,8 @@ private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Lo new ParquetRecordMaterializer( parquetRequestedSchema, ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetToSparkSchemaConverter(conf)) + new ParquetToSparkSchemaConverter(conf), + convertTz) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index 793755e9aaeb5..b2459dd0e8bba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.TimeZone + import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType @@ -33,11 +35,12 @@ import org.apache.spark.sql.types.StructType private[parquet] class ParquetRecordMaterializer( parquetSchema: MessageType, catalystSchema: StructType, - schemaConverter: ParquetToSparkSchemaConverter) + schemaConverter: ParquetToSparkSchemaConverter, + convertTz: Option[TimeZone]) extends RecordMaterializer[UnsafeRow] { private val rootConverter = - new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, NoopUpdater) + new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, convertTz, NoopUpdater) override def getCurrentRecord: UnsafeRow = rootConverter.currentRecord diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 10f6c3b4f15e3..1199725941842 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder +import java.util.TimeZone import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -117,12 +118,14 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * @param parquetType Parquet schema of Parquet records * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined * types should have been expanded. + * @param convertTz the optional time zone to convert to for int96 data * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class ParquetRowConverter( schemaConverter: ParquetToSparkSchemaConverter, parquetType: GroupType, catalystType: StructType, + convertTz: Option[TimeZone], updater: ParentContainerUpdater) extends ParquetGroupConverter(updater) with Logging { @@ -151,6 +154,8 @@ private[parquet] class ParquetRowConverter( |${catalystType.prettyJson} """.stripMargin) + private val UTC = DateTimeUtils.TimeZoneUTC + /** * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates * converted filed values to the `ordinal`-th cell in `currentRow`. 
@@ -279,7 +284,9 @@ private[parquet] class ParquetRowConverter( val buf = value.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) val timeOfDayNanos = buf.getLong val julianDay = buf.getInt - updater.setLong(DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)) + val rawTime = DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + val adjTime = convertTz.map(DateTimeUtils.convertTz(rawTime, _, UTC)).getOrElse(rawTime) + updater.setLong(adjTime) } } @@ -309,7 +316,7 @@ private[parquet] class ParquetRowConverter( case t: StructType => new ParquetRowConverter( - schemaConverter, parquetType.asGroupType(), t, new ParentContainerUpdater { + schemaConverter, parquetType.asGroupType(), t, convertTz, new ParentContainerUpdater { override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) diff --git a/sql/core/src/test/resources/test-data/impala_timestamp.parq b/sql/core/src/test/resources/test-data/impala_timestamp.parq new file mode 100644 index 0000000000000000000000000000000000000000..21e5318db98c9d0e328bd7c00999a2d1d4bce86d GIT binary patch literal 374 zcmWG=3^EjD5%m!D@eyScWno}Y>0vlDDB$vr3=ARJK(^Zj-M^d+ z4EJ(W8AKUGMMNcZKw5wpsDMj_iGhKEjgg62g+Xlvql_qs2UdB$a07 zq$pUJ8z!ctrJ1FerJ9;0nVP4h8XFs!TBf8X8>N`1r5anL8X23JC8cOe%E$n{avF#O HfFTY5EiXyU literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 9dc56292c3720..e3edafa9c84e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -19,7 +19,15 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName + import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext { @@ -87,4 +95,110 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS Row(Seq(2, 3)))) } } + + test("parquet timestamp conversion") { + // Make a table with one parquet file written by impala, and one parquet file written by spark. 
+ // We should only adjust the timestamps in the impala file, and only if the conf is set + val impalaFile = "test-data/impala_timestamp.parq" + + // here are the timestamps in the impala file, as they were saved by impala + val impalaFileData = + Seq( + "2001-01-01 01:01:01", + "2002-02-02 02:02:02", + "2003-03-03 03:03:03" + ).map(java.sql.Timestamp.valueOf) + val impalaPath = Thread.currentThread().getContextClassLoader.getResource(impalaFile) + .toURI.getPath + withTempPath { tableDir => + val ts = Seq( + "2004-04-04 04:04:04", + "2005-05-05 05:05:05", + "2006-06-06 06:06:06" + ).map { s => java.sql.Timestamp.valueOf(s) } + import testImplicits._ + // match the column names of the file from impala + val df = spark.createDataset(ts).toDF().repartition(1).withColumnRenamed("value", "ts") + df.write.parquet(tableDir.getAbsolutePath) + FileUtils.copyFile(new File(impalaPath), new File(tableDir, "part-00001.parq")) + + Seq(false, true).foreach { int96TimestampConversion => + Seq(false, true).foreach { vectorized => + withSQLConf( + (SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + SQLConf.ParquetOutputTimestampType.INT96.toString), + (SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key, int96TimestampConversion.toString()), + (SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized.toString()) + ) { + val readBack = spark.read.parquet(tableDir.getAbsolutePath).collect() + assert(readBack.size === 6) + // if we apply the conversion, we'll get the "right" values, as saved by impala in the + // original file. Otherwise, they're off by the local timezone offset, set to + // America/Los_Angeles in tests + val impalaExpectations = if (int96TimestampConversion) { + impalaFileData + } else { + impalaFileData.map { ts => + DateTimeUtils.toJavaTimestamp(DateTimeUtils.convertTz( + DateTimeUtils.fromJavaTimestamp(ts), + DateTimeUtils.TimeZoneUTC, + DateTimeUtils.getTimeZone(conf.sessionLocalTimeZone))) + } + } + val fullExpectations = (ts ++ impalaExpectations).map(_.toString).sorted.toArray + val actual = readBack.map(_.getTimestamp(0).toString).sorted + withClue( + s"int96TimestampConversion = $int96TimestampConversion; vectorized = $vectorized") { + assert(fullExpectations === actual) + + // Now test that the behavior is still correct even with a filter which could get + // pushed down into parquet. We don't need extra handling for pushed down + // predicates because (a) in ParquetFilters, we ignore TimestampType and (b) parquet + // does not read statistics from int96 fields, as they are unsigned. See + // scalastyle:off line.size.limit + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L419 + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L348 + // scalastyle:on line.size.limit + // + // Just to be defensive in case anything ever changes in parquet, this test checks + // the assumption on column stats, and also the end-to-end behavior. + + val hadoopConf = sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val parts = fs.listStatus(new Path(tableDir.getAbsolutePath), new PathFilter { + override def accept(path: Path): Boolean = !path.getName.startsWith("_") + }) + // grab the meta data from the parquet file. The next section of asserts just make + // sure the test is configured correctly. 
+ assert(parts.size == 2) + parts.foreach { part => + val oneFooter = + ParquetFileReader.readFooter(hadoopConf, part.getPath, NO_FILTER) + assert(oneFooter.getFileMetaData.getSchema.getColumns.size === 1) + assert(oneFooter.getFileMetaData.getSchema.getColumns.get(0).getType() === + PrimitiveTypeName.INT96) + val oneBlockMeta = oneFooter.getBlocks().get(0) + val oneBlockColumnMeta = oneBlockMeta.getColumns().get(0) + val columnStats = oneBlockColumnMeta.getStatistics + // This is the important assert. Column stats are written, but they are ignored + // when the data is read back as mentioned above, b/c int96 is unsigned. This + // assert makes sure this holds even if we change parquet versions (if eg. there + // were ever statistics even on unsigned columns). + assert(columnStats.isEmpty) + } + + // These queries should return the entire dataset with the conversion applied, + // but if the predicates were applied to the raw values in parquet, they would + // incorrectly filter data out. + val query = spark.read.parquet(tableDir.getAbsolutePath) + .where("ts > '2001-01-01 01:00:00'") + val countWithFilter = query.count() + val exp = if (int96TimestampConversion) 6 else 5 + assert(countWithFilter === exp, query) + } + } + } + } + } + } } From 251b2c03b4b2846c9577390f87db3db511be7721 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 9 Dec 2017 20:20:28 +0900 Subject: [PATCH 384/712] [SPARK-22672][SQL][TEST][FOLLOWUP] Fix to use `spark.conf` ## What changes were proposed in this pull request? During https://github.com/apache/spark/pull/19882, `conf` is mistakenly used to switch ORC implementation between `native` and `hive`. To affect `OrcTest` correctly, `spark.conf` should be used. ## How was this patch tested? Pass the tests. Author: Dongjoon Hyun Closes #19931 from dongjoon-hyun/SPARK-22672-2. --- .../spark/sql/execution/datasources/orc/OrcTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index d94cb850ed2a2..38b34a03e3e4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -51,12 +51,12 @@ abstract class OrcTest extends QueryTest with SQLTestUtils with BeforeAndAfterAl protected override def beforeAll(): Unit = { super.beforeAll() - originalConfORCImplementation = conf.getConf(ORC_IMPLEMENTATION) - conf.setConf(ORC_IMPLEMENTATION, orcImp) + originalConfORCImplementation = spark.conf.get(ORC_IMPLEMENTATION) + spark.conf.set(ORC_IMPLEMENTATION.key, orcImp) } protected override def afterAll(): Unit = { - conf.setConf(ORC_IMPLEMENTATION, originalConfORCImplementation) + spark.conf.set(ORC_IMPLEMENTATION.key, originalConfORCImplementation) super.afterAll() } From ab1b6ee73157fa5401b189924263795bfad919c9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sat, 9 Dec 2017 09:28:46 -0600 Subject: [PATCH 385/712] [BUILD] update release scripts ## What changes were proposed in this pull request? Change to dist.apache.org instead of home directory sha512 should have .sha512 extension. From ASF release signing doc: "The checksum SHOULD be generated using SHA-512. A .sha file SHOULD contain a SHA-1 checksum, for historical reasons." NOTE: I *think* should require some changes to work with Jenkins' release build ## How was this patch tested? 
manually Author: Felix Cheung Closes #19754 from felixcheung/releasescript. --- dev/create-release/release-build.sh | 122 ++++++++++++---------------- dev/create-release/release-tag.sh | 8 +- 2 files changed, 60 insertions(+), 70 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5b43f9bab7505..c71137468054f 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -23,23 +23,19 @@ usage: release-build.sh Creates build deliverables from a Spark commit. Top level targets are - package: Create binary packages and copy them to home.apache - docs: Build docs and copy them to home.apache + package: Create binary packages and commit them to dist.apache.org/repos/dist/dev/spark/ + docs: Build docs and commit them to dist.apache.org/repos/dist/dev/spark/ publish-snapshot: Publish snapshot release to Apache snapshots publish-release: Publish a release to Apache release repo All other inputs are environment variables GIT_REF - Release tag or commit to build from -SPARK_VERSION - Version of Spark being built (e.g. 2.1.2) SPARK_PACKAGE_VERSION - Release identifier in top level package directory (e.g. 2.1.2-rc1) -REMOTE_PARENT_DIR - Parent in which to create doc or release builds. -REMOTE_PARENT_MAX_LENGTH - If set, parent directory will be cleaned to only - have this number of subdirectories (by deleting old ones). WARNING: This deletes data. +SPARK_VERSION - (optional) Version of Spark being built (e.g. 2.1.2) ASF_USERNAME - Username of ASF committer account ASF_PASSWORD - Password of ASF committer account -ASF_RSA_KEY - RSA private key file for ASF committer account GPG_KEY - GPG key used to sign release artifacts GPG_PASSPHRASE - Passphrase for GPG key @@ -57,7 +53,20 @@ if [[ $@ == *"help"* ]]; then exit_with_usage fi -for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do +if [[ -z "$ASF_PASSWORD" ]]; then + echo 'The environment variable ASF_PASSWORD is not set. Enter the password.' + echo + stty -echo && printf "ASF password: " && read ASF_PASSWORD && printf '\n' && stty echo +fi + +if [[ -z "$GPG_PASSPHRASE" ]]; then + echo 'The environment variable GPG_PASSPHRASE is not set. Enter the passphrase to' + echo 'unlock the GPG signing key that will be used to sign the release!' 
+ echo + stty -echo && printf "GPG passphrase: " && read GPG_PASSPHRASE && printf '\n' && stty echo +fi + +for env in ASF_USERNAME GPG_PASSPHRASE GPG_KEY; do if [ -z "${!env}" ]; then echo "ERROR: $env must be set to run this script" exit_with_usage @@ -71,8 +80,7 @@ export LC_ALL=C # Commit ref to checkout when building GIT_REF=${GIT_REF:-master} -# Destination directory parent on remote server -REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} +RELEASE_STAGING_LOCATION="https://dist.apache.org/repos/dist/dev/spark" GPG="gpg -u $GPG_KEY --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging @@ -142,42 +150,16 @@ if [ -z "$SPARK_PACKAGE_VERSION" ]; then SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" fi -DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" - -function LFTP { - SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" - COMMANDS=$(cat < \ spark-$SPARK_VERSION.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha + SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION # Updated for each binary build @@ -238,7 +220,7 @@ if [[ "$1" == "package" ]]; then $R_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $R_DIST_NAME > \ - $R_DIST_NAME.sha + $R_DIST_NAME.sha512 else echo "Creating distribution with PIP package" ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --pip $FLAGS \ @@ -257,7 +239,7 @@ if [[ "$1" == "package" ]]; then $PYTHON_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $PYTHON_DIST_NAME > \ - $PYTHON_DIST_NAME.sha + $PYTHON_DIST_NAME.sha512 fi echo "Copying and signing regular binary distribution" @@ -270,7 +252,7 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ - spark-$SPARK_VERSION-bin-$NAME.tgz.sha + spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } # TODO: Check exit codes of children here: @@ -284,22 +266,20 @@ if [[ "$1" == "package" ]]; then wait rm -rf spark-$SPARK_VERSION-bin-*/ - # Copy data - dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" - echo "Copying release tarballs to $dest_dir" - # Put to new directory: - LFTP mkdir -p $dest_dir || true - LFTP mput -O $dest_dir 'spark-*' - LFTP mput -O $dest_dir 'pyspark-*' - LFTP mput -O $dest_dir 'SparkR_*' - # Delete /latest directory and rename new upload to /latest - LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" - LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" - # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir || true - LFTP mput -O $dest_dir 'spark-*' - LFTP mput -O $dest_dir 'pyspark-*' - LFTP mput -O $dest_dir 'SparkR_*' + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-bin" + mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + + echo "Copying release tarballs" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + svn add "svn-spark/${DEST_DIR_NAME}-bin" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + cd .. 
+ rm -rf svn-spark exit 0 fi @@ -307,21 +287,24 @@ if [[ "$1" == "docs" ]]; then # Documentation cd spark echo "Building Spark docs" - dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-docs" cd docs # TODO: Make configurable to add this: PRODUCTION=1 PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build - echo "Copying release documentation to $dest_dir" - # Put to new directory: - LFTP mkdir -p $dest_dir || true - LFTP mirror -R _site $dest_dir - # Delete /latest directory and rename new upload to /latest - LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" - LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" - # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir || true - LFTP mirror -R _site $dest_dir cd .. + cd .. + + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-docs" + mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" + + echo "Copying release documentation" + cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" + svn add "svn-spark/${DEST_DIR_NAME}-docs" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + cd .. + rm -rf svn-spark exit 0 fi @@ -399,6 +382,7 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" + # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index 370a62ce15bc4..a05716a5f66bb 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -41,6 +41,12 @@ if [[ $@ == *"help"* ]]; then exit_with_usage fi +if [[ -z "$ASF_PASSWORD" ]]; then + echo 'The environment variable ASF_PASSWORD is not set. Enter the password.' + echo + stty -echo && printf "ASF password: " && read ASF_PASSWORD && printf '\n' && stty echo +fi + for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GIT_EMAIL GIT_NAME GIT_BRANCH; do if [ -z "${!env}" ]; then echo "$env must be set to run this script" @@ -52,7 +58,7 @@ ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" MVN="build/mvn --force" rm -rf spark -git clone https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO -b $GIT_BRANCH +git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH cd spark git config user.name "$GIT_NAME" From 4289ac9d8dbbc45fc2ee6d0250a2113107bf08d0 Mon Sep 17 00:00:00 2001 From: zouchenjun Date: Sun, 10 Dec 2017 20:36:14 -0800 Subject: [PATCH 386/712] [SPARK-22496][SQL] thrift server adds operation logs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? since hive 2.0+ upgrades log4j to log4j2,a lot of [changes](https://issues.apache.org/jira/browse/HIVE-11304) are made working on it. as spark is not to ready to update its inner hive version(1.2.1) , so I manage to make little changes. the function registerCurrentOperationLog is moved from SQLOperstion to its parent class ExecuteStatementOperation so spark can use it. ## How was this patch tested? manual test Author: zouchenjun Closes #19721 from ChenjunZou/operation-log. 
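For illustration only, a simplified Scala sketch of the refactor pattern (the class and method names below are stand-ins, not the real Hive classes): hoisting the helper into the shared parent lets any statement-operation subclass attach the per-operation log itself before executing.

```scala
// Stand-ins for ExecuteStatementOperation / SparkExecuteStatementOperation; the real
// classes live in the Hive ThriftServer code touched by this patch.
abstract class StatementOperationSketch {
  protected var operationLogEnabled: Boolean = true

  // Previously private to a single Java subclass; after this change any subclass can call it.
  protected def registerCurrentOperationLogSketch(): Unit = {
    if (operationLogEnabled) {
      println("attach the per-operation log to the current thread")
    }
  }
}

class SparkStatementOperationSketch extends StatementOperationSketch {
  def run(): Unit = {
    registerCurrentOperationLogSketch() // the extra call this patch adds before execution
    println("execute the statement")
  }
}

object OperationLogDemo extends App {
  new SparkStatementOperationSketch().run()
}
```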
--- .../cli/operation/ExecuteStatementOperation.java | 13 +++++++++++++ .../hive/service/cli/operation/SQLOperation.java | 12 ------------ .../SparkExecuteStatementOperation.scala | 1 + 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java index 3f2de108f069a..dc7de3c49554b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessor; import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory; +import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hive.service.cli.HiveSQLException; import org.apache.hive.service.cli.OperationType; import org.apache.hive.service.cli.session.HiveSession; @@ -67,4 +68,16 @@ protected void setConfOverlay(Map confOverlay) { this.confOverlay = confOverlay; } } + + protected void registerCurrentOperationLog() { + if (isOperationLogEnabled) { + if (operationLog == null) { + LOG.warn("Failed to get current OperationLog object of Operation: " + + getHandle().getHandleIdentifier()); + isOperationLogEnabled = false; + return; + } + OperationLog.setCurrentOperationLog(operationLog); + } + } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index 5014cedd870b6..fd9108eb53ca9 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -274,18 +274,6 @@ private Hive getSessionHive() throws HiveSQLException { } } - private void registerCurrentOperationLog() { - if (isOperationLogEnabled) { - if (operationLog == null) { - LOG.warn("Failed to get current OperationLog object of Operation: " + - getHandle().getHandleIdentifier()); - isOperationLogEnabled = false; - return; - } - OperationLog.setCurrentOperationLog(operationLog); - } - } - private void cleanup(OperationState state) throws HiveSQLException { setState(state); if (shouldRunAsync()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index f5191fa9132bd..664bc20601eaa 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -170,6 +170,7 @@ private[hive] class SparkExecuteStatementOperation( override def run(): Unit = { val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { + registerCurrentOperationLog() try { execute() } catch { From ec873a4fd20a47cf0791456bfb301f25a34ae014 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 11 Dec 2017 06:35:31 -0600 Subject: [PATCH 387/712] [SPARK-14516][FOLLOWUP] Adding ClusteringEvaluator to examples ## What changes were proposed in this pull request? 
In SPARK-14516 we have introduced ClusteringEvaluator, but we didn't put any reference in the documentation and the examples were still relying on the sum of squared errors to show a way to evaluate the clustering model. The PR adds the ClusteringEvaluator in the examples. ## How was this patch tested? Manual runs of the examples. Author: Marco Gaido Closes #19676 from mgaido91/SPARK-14516_examples. --- .../apache/spark/examples/ml/JavaKMeansExample.java | 12 +++++++++--- examples/src/main/python/ml/kmeans_example.py | 12 +++++++++--- .../org/apache/spark/examples/ml/KMeansExample.scala | 12 +++++++++--- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index d8f948ae38cb3..dc4b0bcb59657 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -20,6 +20,7 @@ // $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; +import org.apache.spark.ml.evaluation.ClusteringEvaluator; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -51,9 +52,14 @@ public static void main(String[] args) { KMeans kmeans = new KMeans().setK(2).setSeed(1L); KMeansModel model = kmeans.fit(dataset); - // Evaluate clustering by computing Within Set Sum of Squared Errors. - double WSSSE = model.computeCost(dataset); - System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + // Make predictions + Dataset predictions = model.transform(dataset); + + // Evaluate clustering by computing Silhouette score + ClusteringEvaluator evaluator = new ClusteringEvaluator(); + + double silhouette = evaluator.evaluate(predictions); + System.out.println("Silhouette with squared euclidean distance = " + silhouette); // Shows the result. Vector[] centers = model.clusterCenters(); diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index 6846ec4599714..5f77843e3743a 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -19,6 +19,7 @@ # $example on$ from pyspark.ml.clustering import KMeans +from pyspark.ml.evaluation import ClusteringEvaluator # $example off$ from pyspark.sql import SparkSession @@ -45,9 +46,14 @@ kmeans = KMeans().setK(2).setSeed(1) model = kmeans.fit(dataset) - # Evaluate clustering by computing Within Set Sum of Squared Errors. - wssse = model.computeCost(dataset) - print("Within Set Sum of Squared Errors = " + str(wssse)) + # Make predictions + predictions = model.transform(dataset) + + # Evaluate clustering by computing Silhouette score + evaluator = ClusteringEvaluator() + + silhouette = evaluator.evaluate(predictions) + print("Silhouette with squared euclidean distance = " + str(silhouette)) # Shows the result. 
centers = model.clusterCenters() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index a1d19e138dedb..2bc8184e623ff 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -21,6 +21,7 @@ package org.apache.spark.examples.ml // $example on$ import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.ml.evaluation.ClusteringEvaluator // $example off$ import org.apache.spark.sql.SparkSession @@ -47,9 +48,14 @@ object KMeansExample { val kmeans = new KMeans().setK(2).setSeed(1L) val model = kmeans.fit(dataset) - // Evaluate clustering by computing Within Set Sum of Squared Errors. - val WSSSE = model.computeCost(dataset) - println(s"Within Set Sum of Squared Errors = $WSSSE") + // Make predictions + val predictions = model.transform(dataset) + + // Evaluate clustering by computing Silhouette score + val evaluator = new ClusteringEvaluator() + + val silhouette = evaluator.evaluate(predictions) + println(s"Silhouette with squared euclidean distance = $silhouette") // Shows the result. println("Cluster Centers: ") From 6cc7021a40b64c41a51f337ec4be9545a25e838c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 11 Dec 2017 21:52:57 +0800 Subject: [PATCH 388/712] [SPARK-22267][SQL][TEST] Spark SQL incorrectly reads ORC files when column order is different ## What changes were proposed in this pull request? Until 2.2.1, with the default configuration, Apache Spark returns incorrect results when ORC file schema is different from metastore schema order. This is due to Hive 1.2.1 library and some issues on `convertMetastoreOrc` option. ```scala scala> Seq(1 -> 2).toDF("c1", "c2").write.format("orc").mode("overwrite").save("/tmp/o") scala> sql("CREATE EXTERNAL TABLE o(c2 INT, c1 INT) STORED AS orc LOCATION '/tmp/o'") scala> spark.table("o").show // This is wrong. +---+---+ | c2| c1| +---+---+ | 1| 2| +---+---+ scala> spark.read.orc("/tmp/o").show // This is correct. +---+---+ | c1| c2| +---+---+ | 1| 2| +---+---+ ``` After [SPARK-22279](https://github.com/apache/spark/pull/19499), the default configuration doesn't have this bug. Although Hive 1.2.1 library code path still has the problem, we had better have a test coverage on what we have now in order to prevent future regression on it. ## How was this patch tested? Pass the Jenkins with a newly added test test. Author: Dongjoon Hyun Closes #19928 from dongjoon-hyun/SPARK-22267. 
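For reference, a spark-shell style sketch of the scenario the new test locks in, assuming a Hive-enabled session and the 2.3.0 defaults; the expectations below reflect the corrected behavior (columns matched by name), not the Hive 1.2.1 positional reading shown above:

```scala
import org.apache.spark.sql.SparkSession

// In spark-shell the Hive-enabled session is already available as `spark`.
val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
import spark.implicits._

// The ORC file's own schema is (c1, c2) = (1, 2); the table declares the columns swapped.
Seq(1 -> 2).toDF("c1", "c2").write.mode("overwrite").orc("/tmp/o")
spark.sql("CREATE EXTERNAL TABLE o(c2 INT, c1 INT) STORED AS ORC LOCATION '/tmp/o'")

spark.table("o").show()         // expected: c2 = 2, c1 = 1 (resolved by name)
spark.read.orc("/tmp/o").show() // always follows the file schema: c1 = 1, c2 = 2
```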
--- .../sql/hive/execution/SQLQuerySuite.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c11e37a516646..f2562c33e2a6e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2153,4 +2153,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-22267 Spark SQL incorrectly reads ORC files when column order is different") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { + withTempPath { f => + val path = f.getCanonicalPath + Seq(1 -> 2).toDF("c1", "c2").write.orc(path) + checkAnswer(spark.read.orc(path), Row(1, 2)) + + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // default since 2.3.0 + withTable("t") { + sql(s"CREATE EXTERNAL TABLE t(c2 INT, c1 INT) STORED AS ORC LOCATION '$path'") + checkAnswer(spark.table("t"), Row(2, 1)) + } + } + } + } + } + } } From bf20abb2dc084b743289ea24fc8d6e47d6e6c0dd Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Mon, 11 Dec 2017 13:36:15 -0600 Subject: [PATCH 389/712] [SPARK-22642][SQL] the createdTempDir will not be deleted if an exception occurs, should delete it with try-finally. ## What changes were proposed in this pull request? We found staging directories will not be dropped sometimes in our production environment. The createdTempDir will not be deleted if an exception occurs, we should delete createdTempDir with try-finally. This PR is follow-up SPARK-18703. ## How was this patch tested? exist tests Author: zuotingbing Closes #19841 from zuotingbing/SPARK-stagedir. --- .../hive/execution/InsertIntoHiveTable.scala | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 56e10bc457a00..b46addb6aa85b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.hive.execution +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} -import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalog} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.CommandUtils @@ -91,6 +92,34 @@ case class InsertIntoHiveTable( ) val tableLocation = hiveQlTable.getDataLocation val tmpLocation = getExternalTmpPath(sparkSession, hadoopConf, tableLocation) + + try { + processInsert(sparkSession, externalCatalog, hadoopConf, tableDesc, tmpLocation) + } finally { + // Attempt to delete the staging directory and the inclusive files. If failed, the files are + // expected to be dropped at the normal termination of VM since deleteOnExit is used. 
+ deleteExternalTmpPath(hadoopConf) + } + + // un-cache this table. + sparkSession.catalog.uncacheTable(table.identifier.quotedString) + sparkSession.sessionState.catalog.refreshTable(table.identifier) + + CommandUtils.updateTableStats(sparkSession, table) + + // It would be nice to just return the childRdd unchanged so insert operations could be chained, + // however for now we return an empty list to simplify compatibility checks with hive, which + // does not return anything for insert operations. + // TODO: implement hive compatibility as rules. + Seq.empty[Row] + } + + private def processInsert( + sparkSession: SparkSession, + externalCatalog: ExternalCatalog, + hadoopConf: Configuration, + tableDesc: TableDesc, + tmpLocation: Path): Unit = { val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -231,21 +260,5 @@ case class InsertIntoHiveTable( overwrite, isSrcLocal = false) } - - // Attempt to delete the staging directory and the inclusive files. If failed, the files are - // expected to be dropped at the normal termination of VM since deleteOnExit is used. - deleteExternalTmpPath(hadoopConf) - - // un-cache this table. - sparkSession.catalog.uncacheTable(table.identifier.quotedString) - sparkSession.sessionState.catalog.refreshTable(table.identifier) - - CommandUtils.updateTableStats(sparkSession, table) - - // It would be nice to just return the childRdd unchanged so insert operations could be chained, - // however for now we return an empty list to simplify compatibility checks with hive, which - // does not return anything for insert operations. - // TODO: implement hive compatibility as rules. - Seq.empty[Row] } } From a04f2bea67c9abf4149d33ac6c319cd4f85344d5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 11 Dec 2017 13:08:42 -0800 Subject: [PATCH 390/712] Revert "[SPARK-22496][SQL] thrift server adds operation logs" This reverts commit 4289ac9d8dbbc45fc2ee6d0250a2113107bf08d0. 
--- .../cli/operation/ExecuteStatementOperation.java | 13 ------------- .../hive/service/cli/operation/SQLOperation.java | 12 ++++++++++++ .../SparkExecuteStatementOperation.scala | 1 - 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java index dc7de3c49554b..3f2de108f069a 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java @@ -23,7 +23,6 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessor; import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory; -import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hive.service.cli.HiveSQLException; import org.apache.hive.service.cli.OperationType; import org.apache.hive.service.cli.session.HiveSession; @@ -68,16 +67,4 @@ protected void setConfOverlay(Map confOverlay) { this.confOverlay = confOverlay; } } - - protected void registerCurrentOperationLog() { - if (isOperationLogEnabled) { - if (operationLog == null) { - LOG.warn("Failed to get current OperationLog object of Operation: " + - getHandle().getHandleIdentifier()); - isOperationLogEnabled = false; - return; - } - OperationLog.setCurrentOperationLog(operationLog); - } - } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index fd9108eb53ca9..5014cedd870b6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -274,6 +274,18 @@ private Hive getSessionHive() throws HiveSQLException { } } + private void registerCurrentOperationLog() { + if (isOperationLogEnabled) { + if (operationLog == null) { + LOG.warn("Failed to get current OperationLog object of Operation: " + + getHandle().getHandleIdentifier()); + isOperationLogEnabled = false; + return; + } + OperationLog.setCurrentOperationLog(operationLog); + } + } + private void cleanup(OperationState state) throws HiveSQLException { setState(state); if (shouldRunAsync()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 664bc20601eaa..f5191fa9132bd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -170,7 +170,6 @@ private[hive] class SparkExecuteStatementOperation( override def run(): Unit = { val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - registerCurrentOperationLog() try { execute() } catch { From c235b5f9772be57122ca5109b9295cf8c85926df Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 11 Dec 2017 13:15:45 -0800 Subject: [PATCH 391/712] [SPARK-22746][SQL] Avoid the generation of useless mutable states by SortMergeJoin ## What changes were proposed in this pull request? 
This PR reduce the number of global mutable variables in generated code of `SortMergeJoin`. Before this PR, global mutable variables are used to extend lifetime of variables in the nested loop. This can be achieved by declaring variable at the outer most loop level where the variables are used. In the following example, `smj_value8`, `smj_value8`, and `smj_value9` are declared as local variable at lines 145-147 in `With this PR`. This PR fixes potential assertion error by #19865. Without this PR, a global mutable variable is potentially passed to arguments in generated code of split function. Without this PR ``` /* 010 */ int smj_value8; /* 011 */ boolean smj_value8; /* 012 */ int smj_value9; .. /* 143 */ protected void processNext() throws java.io.IOException { /* 144 */ while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) { /* 145 */ boolean smj_loaded = false; /* 146 */ smj_isNull6 = smj_leftRow.isNullAt(1); /* 147 */ smj_value9 = smj_isNull6 ? -1 : (smj_leftRow.getInt(1)); /* 148 */ scala.collection.Iterator smj_iterator = smj_matches.generateIterator(); /* 149 */ while (smj_iterator.hasNext()) { /* 150 */ InternalRow smj_rightRow1 = (InternalRow) smj_iterator.next(); /* 151 */ boolean smj_isNull8 = smj_rightRow1.isNullAt(1); /* 152 */ int smj_value11 = smj_isNull8 ? -1 : (smj_rightRow1.getInt(1)); /* 153 */ /* 154 */ boolean smj_value12 = (smj_isNull6 && smj_isNull8) || /* 155 */ (!smj_isNull6 && !smj_isNull8 && smj_value9 == smj_value11); /* 156 */ if (false || !smj_value12) continue; /* 157 */ if (!smj_loaded) { /* 158 */ smj_loaded = true; /* 159 */ smj_value8 = smj_leftRow.getInt(0); /* 160 */ } /* 161 */ int smj_value10 = smj_rightRow1.getInt(0); /* 162 */ smj_numOutputRows.add(1); /* 163 */ /* 164 */ smj_rowWriter.zeroOutNullBytes(); /* 165 */ /* 166 */ smj_rowWriter.write(0, smj_value8); /* 167 */ /* 168 */ if (smj_isNull6) { /* 169 */ smj_rowWriter.setNullAt(1); /* 170 */ } else { /* 171 */ smj_rowWriter.write(1, smj_value9); /* 172 */ } /* 173 */ /* 174 */ smj_rowWriter.write(2, smj_value10); /* 175 */ /* 176 */ if (smj_isNull8) { /* 177 */ smj_rowWriter.setNullAt(3); /* 178 */ } else { /* 179 */ smj_rowWriter.write(3, smj_value11); /* 180 */ } /* 181 */ append(smj_result.copy()); /* 182 */ /* 183 */ } /* 184 */ if (shouldStop()) return; /* 185 */ } /* 186 */ } ``` With this PR ``` /* 143 */ protected void processNext() throws java.io.IOException { /* 144 */ while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) { /* 145 */ int smj_value8 = -1; /* 146 */ boolean smj_isNull6 = false; /* 147 */ int smj_value9 = -1; /* 148 */ boolean smj_loaded = false; /* 149 */ smj_isNull6 = smj_leftRow.isNullAt(1); /* 150 */ smj_value9 = smj_isNull6 ? -1 : (smj_leftRow.getInt(1)); /* 151 */ scala.collection.Iterator smj_iterator = smj_matches.generateIterator(); /* 152 */ while (smj_iterator.hasNext()) { /* 153 */ InternalRow smj_rightRow1 = (InternalRow) smj_iterator.next(); /* 154 */ boolean smj_isNull8 = smj_rightRow1.isNullAt(1); /* 155 */ int smj_value11 = smj_isNull8 ? 
-1 : (smj_rightRow1.getInt(1)); /* 156 */ /* 157 */ boolean smj_value12 = (smj_isNull6 && smj_isNull8) || /* 158 */ (!smj_isNull6 && !smj_isNull8 && smj_value9 == smj_value11); /* 159 */ if (false || !smj_value12) continue; /* 160 */ if (!smj_loaded) { /* 161 */ smj_loaded = true; /* 162 */ smj_value8 = smj_leftRow.getInt(0); /* 163 */ } /* 164 */ int smj_value10 = smj_rightRow1.getInt(0); /* 165 */ smj_numOutputRows.add(1); /* 166 */ /* 167 */ smj_rowWriter.zeroOutNullBytes(); /* 168 */ /* 169 */ smj_rowWriter.write(0, smj_value8); /* 170 */ /* 171 */ if (smj_isNull6) { /* 172 */ smj_rowWriter.setNullAt(1); /* 173 */ } else { /* 174 */ smj_rowWriter.write(1, smj_value9); /* 175 */ } /* 176 */ /* 177 */ smj_rowWriter.write(2, smj_value10); /* 178 */ /* 179 */ if (smj_isNull8) { /* 180 */ smj_rowWriter.setNullAt(3); /* 181 */ } else { /* 182 */ smj_rowWriter.write(3, smj_value11); /* 183 */ } /* 184 */ append(smj_result.copy()); /* 185 */ /* 186 */ } /* 187 */ if (shouldStop()) return; /* 188 */ } /* 189 */ } ``` ## How was this patch tested? Existing test cases Author: Kazuaki Ishizaki Closes #19937 from kiszk/SPARK-22746. --- .../execution/joins/SortMergeJoinExec.scala | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 9c08ec71c1fde..554b73181116c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -507,32 +507,38 @@ case class SortMergeJoinExec( } /** - * Creates variables for left part of result row. + * Creates variables and declarations for left part of result row. * * In order to defer the access after condition and also only access once in the loop, * the variables should be declared separately from accessing the columns, we can't use the * codegen of BoundReference here. */ - private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = { ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - // declare it as class member, so we can access the column before or in the loop. - ctx.addMutableState(ctx.javaType(a.dataType), value) + val javaType = ctx.javaType(a.dataType) + val defaultValue = ctx.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) val code = s""" |$isNull = $leftRow.isNullAt($i); - |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin - ExprCode(code, isNull, value) + val leftVarsDecl = + s""" + |boolean $isNull = false; + |$javaType $value = $defaultValue; + """.stripMargin + (ExprCode(code, isNull, value), leftVarsDecl) } else { - ExprCode(s"$value = $valueCode;", "false", value) + val code = s"$value = $valueCode;" + val leftVarsDecl = s"""$javaType $value = $defaultValue;""" + (ExprCode(code, "false", value), leftVarsDecl) } - } + }.unzip } /** @@ -580,7 +586,7 @@ case class SortMergeJoinExec( val (leftRow, matches) = genScanner(ctx) // Create variables for row from both sides. 
- val leftVars = createLeftVars(ctx, leftRow) + val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow) val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) @@ -617,6 +623,7 @@ case class SortMergeJoinExec( s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { + | ${leftVarDecl.mkString("\n")} | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { From 3f4060c340d6bac412e8819c4388ccba226efcf3 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Mon, 11 Dec 2017 15:14:59 -0800 Subject: [PATCH 392/712] [SPARK-22646][K8S] Spark on Kubernetes - basic submission client This PR contains implementation of the basic submission client for the cluster mode of Spark on Kubernetes. It's step 2 from the step-wise plan documented [here](https://github.com/apache-spark-on-k8s/spark/issues/441#issuecomment-330802935). This addition is covered by the [SPIP](http://apache-spark-developers-list.1001551.n3.nabble.com/SPIP-Spark-on-Kubernetes-td22147.html) vote which passed on Aug 31. This PR and #19468 together form a MVP of Spark on Kubernetes that allows users to run Spark applications that use resources locally within the driver and executor containers on Kubernetes 1.6 and up. Some changes on pom and build/test setup are copied over from #19468 to make this PR self contained and testable. The submission client is mainly responsible for creating the Kubernetes pod that runs the Spark driver. It follows a step-based approach to construct the driver pod, as the code under the `submit.steps` package shows. The steps are orchestrated by `DriverConfigurationStepsOrchestrator`. `Client` creates the driver pod and waits for the application to complete if it's configured to do so, which is the case by default. This PR also contains Dockerfiles of the driver and executor images. They are included because some of the environment variables set in the code would not make sense without referring to the Dockerfiles. * The patch contains unit tests which are passing. * Manual testing: ./build/mvn -Pkubernetes clean package succeeded. * It is a subset of the entire changelist hosted at http://github.com/apache-spark-on-k8s/spark which is in active use in several organizations. * There is integration testing enabled in the fork currently hosted by PepperData which is being moved over to RiseLAB CI. * Detailed documentation on trying out the patch in its entirety is in: https://apache-spark-on-k8s.github.io/userdocs/running-on-kubernetes.html cc rxin felixcheung mateiz (shepherd) k8s-big-data SIG members & contributors: mccheah foxish ash211 ssuchter varunkatta kimoonkim erikerlandson tnachen ifilonenko liyinan926 Author: Yinan Li Closes #19717 from liyinan926/spark-kubernetes-4. 
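Before the diff itself, a quick orientation: spark-submit now accepts a k8s:// master URL, normalizes it with the new Utils.checkAndGetK8sMasterUrl helper, and hands the application off to KubernetesClientApplication in cluster mode; the namespace and the driver and executor images come from the spark.kubernetes.* settings added in Config.scala. A small sketch of the URL normalization, mirroring the UtilsSuite cases in this patch (Utils is private[spark], so like the test this runs inside the org.apache.spark package):

```scala
import org.apache.spark.util.Utils

// https is assumed when no scheme is given, http is allowed with a warning,
// and any other scheme is rejected.
Utils.checkAndGetK8sMasterUrl("k8s://https://host:port") // "k8s:https://host:port"
Utils.checkAndGetK8sMasterUrl("k8s://http://host:port")  // "k8s:http://host:port"
Utils.checkAndGetK8sMasterUrl("k8s://127.0.0.1:8443")    // "k8s:https://127.0.0.1:8443"
Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")   // throws IllegalArgumentException
```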
--- assembly/pom.xml | 10 + .../scala/org/apache/spark/SparkConf.scala | 6 +- .../scala/org/apache/spark/SparkContext.scala | 1 - .../org/apache/spark/deploy/SparkSubmit.scala | 48 +++- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../spark/internal/config/package.scala | 8 + .../scala/org/apache/spark/util/Utils.scala | 36 +++ .../org/apache/spark/SparkContextSuite.scala | 6 + .../spark/deploy/SparkSubmitSuite.scala | 27 ++ .../org/apache/spark/util/UtilsSuite.scala | 21 ++ docs/configuration.md | 20 ++ docs/running-on-yarn.md | 16 +- .../launcher/SparkSubmitOptionParser.java | 2 +- .../org/apache/spark/deploy/k8s/Config.scala | 85 ++++-- .../apache/spark/deploy/k8s/Constants.scala | 28 ++ ...DriverConfigurationStepsOrchestrator.scala | 125 +++++++++ .../submit/KubernetesClientApplication.scala | 240 +++++++++++++++++ .../k8s/submit/KubernetesDriverSpec.scala | 47 ++++ .../k8s/submit/KubernetesFileUtils.scala | 68 +++++ .../k8s/submit/LoggingPodStatusWatcher.scala | 180 +++++++++++++ .../deploy/k8s/submit/MainAppResource.scala | 21 ++ .../steps/BaseDriverConfigurationStep.scala | 165 ++++++++++++ .../steps/DependencyResolutionStep.scala | 62 +++++ .../steps/DriverConfigurationStep.scala | 30 +++ .../DriverKubernetesCredentialsStep.scala | 245 ++++++++++++++++++ .../steps/DriverServiceBootstrapStep.scala | 103 ++++++++ .../cluster/k8s/ExecutorPodFactory.scala | 13 +- .../k8s/KubernetesClusterManager.scala | 8 +- .../spark/deploy/k8s/submit/ClientSuite.scala | 234 +++++++++++++++++ ...rConfigurationStepsOrchestratorSuite.scala | 82 ++++++ .../BaseDriverConfigurationStepSuite.scala | 118 +++++++++ .../steps/DependencyResolutionStepSuite.scala | 81 ++++++ ...DriverKubernetesCredentialsStepSuite.scala | 153 +++++++++++ .../DriverServiceBootstrapStepSuite.scala | 180 +++++++++++++ .../src/main/dockerfiles/driver/Dockerfile | 34 +++ .../src/main/dockerfiles/executor/Dockerfile | 34 +++ .../main/dockerfiles/spark-base/Dockerfile | 47 ++++ .../main/dockerfiles/spark-base/entrypoint.sh | 37 +++ .../org/apache/spark/deploy/yarn/config.scala | 8 - 39 files changed, 2566 insertions(+), 67 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala create mode 100644 
resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh diff --git a/assembly/pom.xml b/assembly/pom.xml index 01fe354235e5b..b3b4239771bc3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -148,6 +148,16 @@ + + kubernetes + + + org.apache.spark + spark-kubernetes_${scala.binary.version} + ${project.version} + + + hive diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 0e08ff65e4784..4b1286d91e8f3 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -668,7 +668,11 @@ private[spark] object SparkConf extends Logging { MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key -> Seq( AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")), LISTENER_BUS_EVENT_QUEUE_CAPACITY.key -> Seq( - AlternateConfig("spark.scheduler.listenerbus.eventqueue.size", "2.3")) + AlternateConfig("spark.scheduler.listenerbus.eventqueue.size", "2.3")), + DRIVER_MEMORY_OVERHEAD.key -> Seq( + AlternateConfig("spark.yarn.driver.memoryOverhead", "2.3")), + EXECUTOR_MEMORY_OVERHEAD.key -> Seq( + AlternateConfig("spark.yarn.executor.memoryOverhead", "2.3")) ) /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 71f1e7c7321bc..92e13ce1ba042 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -18,7 +18,6 @@ package org.apache.spark import java.io._ -import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cfcdce648d330..ab834bb682041 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -76,7 +76,8 @@ object SparkSubmit extends CommandLineUtils with Logging { private val STANDALONE = 2 private val MESOS = 4 private val LOCAL = 8 - private val ALL_CLUSTER_MGRS = 
YARN | STANDALONE | MESOS | LOCAL + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES // Deploy modes private val CLIENT = 1 @@ -97,6 +98,8 @@ object SparkSubmit extends CommandLineUtils with Logging { "org.apache.spark.deploy.yarn.YarnClusterApplication" private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" // scalastyle:off println private[spark] def printVersionAndExit(): Unit = { @@ -257,9 +260,10 @@ object SparkSubmit extends CommandLineUtils with Logging { YARN case m if m.startsWith("spark") => STANDALONE case m if m.startsWith("mesos") => MESOS + case m if m.startsWith("k8s") => KUBERNETES case m if m.startsWith("local") => LOCAL case _ => - printErrorAndExit("Master must either be yarn or start with spark, mesos, local") + printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local") -1 } @@ -294,6 +298,16 @@ object SparkSubmit extends CommandLineUtils with Logging { } } + if (clusterManager == KUBERNETES) { + args.master = Utils.checkAndGetK8sMasterUrl(args.master) + // Make sure KUBERNETES is included in our build if we're trying to use it + if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { + printErrorAndExit( + "Could not load KUBERNETES classes. " + + "This copy of Spark may not have been compiled with KUBERNETES support.") + } + } + // Fail fast, the following modes are not supported or applicable (clusterManager, deployMode) match { case (STANDALONE, CLUSTER) if args.isPython => @@ -302,6 +316,12 @@ object SparkSubmit extends CommandLineUtils with Logging { case (STANDALONE, CLUSTER) if args.isR => printErrorAndExit("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") + case (KUBERNETES, _) if args.isPython => + printErrorAndExit("Python applications are currently not supported for Kubernetes.") + case (KUBERNETES, _) if args.isR => + printErrorAndExit("R applications are currently not supported for Kubernetes.") + case (KUBERNETES, CLIENT) => + printErrorAndExit("Client mode is currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => @@ -322,6 +342,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER + val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. 
Add them to py-files @@ -557,19 +578,19 @@ object SparkSubmit extends CommandLineUtils with Logging { OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"), // Other options - OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, + OptionAssigner(args.executorCores, STANDALONE | YARN | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.executor.cores"), - OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, + OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.executor.memory"), - OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), - OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER, + OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), - OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER, + OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, confKey = "spark.driver.supervise"), @@ -703,6 +724,19 @@ object SparkSubmit extends CommandLineUtils with Logging { } } + if (isKubernetesCluster) { + childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS + if (args.primaryResource != SparkLauncher.NO_RESOURCE) { + childArgs ++= Array("--primary-java-resource", args.primaryResource) + } + childArgs ++= Array("--main-class", args.mainClass) + if (args.childArgs != null) { + args.childArgs.foreach { arg => + childArgs += ("--arg", arg) + } + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sparkConf.setIfMissing(k, v) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index a7722e4f86023..9db7a1fe3106d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -515,8 +515,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println( s""" |Options: - | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local - | (Default: local[*]). + | --master MASTER_URL spark://host:port, mesos://host:port, yarn, + | k8s://https://host:port, or local (Default: local[*]). | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or | on one of the worker machines inside the cluster ("cluster") | (Default: client). 
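The hunks that follow move the memory overhead settings out of the YARN namespace into generic spark.driver.memoryOverhead and spark.executor.memoryOverhead keys, documented as defaulting to 10% of the container memory with a 384 MiB floor (see MEMORY_OVERHEAD_FACTOR and MEMORY_OVERHEAD_MIN_MIB in Constants.scala further down). A hedged sketch of that fallback, purely for illustration; the actual computation happens in the Kubernetes pod-construction code, not in a helper like this one:

```scala
// Illustrative only: the documented fallback when spark.{driver,executor}.memoryOverhead
// is unset, i.e. max(0.10 * containerMemory, 384 MiB).
def defaultMemoryOverheadMiB(containerMemoryMiB: Long): Long =
  math.max((0.10 * containerMemoryMiB).toLong, 384L)

defaultMemoryOverheadMiB(4096L) // 409 MiB for a 4g container
defaultMemoryOverheadMiB(1024L) // 384 MiB: the floor applies to small containers
```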
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8fa25c0281493..172ba85359da7 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -41,6 +41,10 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.driver.memoryOverhead") + .bytesConf(ByteUnit.MiB) + .createOptional + private[spark] val EVENT_LOG_COMPRESS = ConfigBuilder("spark.eventLog.compress") .booleanConf @@ -80,6 +84,10 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.executor.memoryOverhead") + .bytesConf(ByteUnit.MiB) + .createOptional + private[spark] val MEMORY_OFFHEAP_ENABLED = ConfigBuilder("spark.memory.offHeap.enabled") .doc("If true, Spark will attempt to use off-heap memory for certain operations. " + "If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive.") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 51bf91614c866..1ed09dc489440 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2744,6 +2744,42 @@ private[spark] object Utils extends Logging { } } + /** + * Check the validity of the given Kubernetes master URL and return the resolved URL. Prefix + * "k8s:" is appended to the resolved URL as the prefix is used by KubernetesClusterManager + * in canCreate to determine if the KubernetesClusterManager should be used. + */ + def checkAndGetK8sMasterUrl(rawMasterURL: String): String = { + require(rawMasterURL.startsWith("k8s://"), + "Kubernetes master URL must start with k8s://.") + val masterWithoutK8sPrefix = rawMasterURL.substring("k8s://".length) + + // To handle master URLs, e.g., k8s://host:port. + if (!masterWithoutK8sPrefix.contains("://")) { + val resolvedURL = s"https://$masterWithoutK8sPrefix" + logInfo("No scheme specified for kubernetes master URL, so defaulting to https. Resolved " + + s"URL is $resolvedURL.") + return s"k8s:$resolvedURL" + } + + val masterScheme = new URI(masterWithoutK8sPrefix).getScheme + val resolvedURL = masterScheme.toLowerCase match { + case "https" => + masterWithoutK8sPrefix + case "http" => + logWarning("Kubernetes master URL uses HTTP instead of HTTPS.") + masterWithoutK8sPrefix + case null => + val resolvedURL = s"https://$masterWithoutK8sPrefix" + logInfo("No scheme specified for kubernetes master URL, so defaulting to https. 
Resolved " + + s"URL is $resolvedURL.") + resolvedURL + case _ => + throw new IllegalArgumentException("Invalid Kubernetes master scheme: " + masterScheme) + } + + return s"k8s:$resolvedURL" + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 2bde8757dae5d..37fcc93c62fa8 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -550,6 +550,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + test("client mode with a k8s master url") { + intercept[SparkException] { + sc = new SparkContext("k8s://https://host:port", "test", new SparkConf()) + } + } + testCancellingTasks("that raise interrupted exception on cancel") { Thread.sleep(9999999) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index e200755e639e1..35594ec47c941 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -388,6 +388,33 @@ class SparkSubmitSuite conf.get("spark.ui.enabled") should be ("false") } + test("handles k8s cluster mode") { + val clArgs = Seq( + "--deploy-mode", "cluster", + "--master", "k8s://host:port", + "--executor-memory", "5g", + "--class", "org.SomeClass", + "--driver-memory", "4g", + "--conf", "spark.kubernetes.namespace=spark", + "--conf", "spark.kubernetes.driver.docker.image=bar", + "/home/thejar.jar", + "arg1") + val appArgs = new SparkSubmitArguments(clArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + + val childArgsMap = childArgs.grouped(2).map(a => a(0) -> a(1)).toMap + childArgsMap.get("--primary-java-resource") should be (Some("file:/home/thejar.jar")) + childArgsMap.get("--main-class") should be (Some("org.SomeClass")) + childArgsMap.get("--arg") should be (Some("arg1")) + mainClass should be (KUBERNETES_CLUSTER_SUBMIT_CLASS) + classpath should have length (0) + conf.get("spark.master") should be ("k8s:https://host:port") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.driver.memory") should be ("4g") + conf.get("spark.kubernetes.namespace") should be ("spark") + conf.get("spark.kubernetes.driver.docker.image") should be ("bar") + } + test("handles confs with flag equivalents") { val clArgs = Seq( "--deploy-mode", "cluster", diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4d3adeb968e84..5c4e4ca0cded6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1146,6 +1146,27 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } } + test("check Kubernetes master URL") { + val k8sMasterURLHttps = Utils.checkAndGetK8sMasterUrl("k8s://https://host:port") + assert(k8sMasterURLHttps === "k8s:https://host:port") + + val k8sMasterURLHttp = Utils.checkAndGetK8sMasterUrl("k8s://http://host:port") + assert(k8sMasterURLHttp === "k8s:http://host:port") + + val k8sMasterURLWithoutScheme = Utils.checkAndGetK8sMasterUrl("k8s://127.0.0.1:8443") + assert(k8sMasterURLWithoutScheme === "k8s:https://127.0.0.1:8443") + + val k8sMasterURLWithoutScheme2 = 
Utils.checkAndGetK8sMasterUrl("k8s://127.0.0.1") + assert(k8sMasterURLWithoutScheme2 === "k8s:https://127.0.0.1") + + intercept[IllegalArgumentException] { + Utils.checkAndGetK8sMasterUrl("k8s:https://host:port") + } + + intercept[IllegalArgumentException] { + Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port") + } + } } private class SimpleExtension diff --git a/docs/configuration.md b/docs/configuration.md index ef061dd39dcba..d70bac134808f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -157,6 +157,16 @@ of the most common options to set are: or in your default properties file. + + spark.driver.memoryOverhead + driverMemory * 0.10, with minimum of 384 + + The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is + memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + This tends to grow with the container size (typically 6-10%). This option is currently supported + on YARN and Kubernetes. + + spark.executor.memory 1g @@ -164,6 +174,16 @@ of the most common options to set are: Amount of memory to use per executor process (e.g. 2g, 8g). + + spark.executor.memoryOverhead + executorMemory * 0.10, with minimum of 384 + + The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that + accounts for things like VM overheads, interned strings, other native overheads, etc. This tends + to grow with the executor size (typically 6-10%). This option is currently supported on YARN and + Kubernetes. + + spark.extraListeners (none) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 9599d40c545b2..7e2386f33b583 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -227,25 +227,11 @@ To use a custom metrics.properties for the application master and executors, upd The number of executors for static allocation. With spark.dynamicAllocation.enabled, the initial set of executors will be at least this large. - - spark.yarn.executor.memoryOverhead - executorMemory * 0.10, with minimum of 384 - - The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). - - - - spark.yarn.driver.memoryOverhead - driverMemory * 0.10, with minimum of 384 - - The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). - - spark.yarn.am.memoryOverhead AM memory * 0.10, with minimum of 384 - Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. + Same as spark.driver.memoryOverhead, but for the YARN Application Master in client mode. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 6767cc5079649..c57af92029460 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -82,7 +82,7 @@ class SparkSubmitOptionParser { * name of the option, passed to {@link #handle(String, String)}. *

    * Options not listed here nor in the "switch" list below will result in a call to - * {@link $#handleUnknown(String)}. + * {@link #handleUnknown(String)}. *

    * These two arrays are visible for tests. */ diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index f0742b91987b6..f35fb38798218 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.deploy.k8s +import java.util.concurrent.TimeUnit + import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.network.util.ByteUnit @@ -24,12 +26,16 @@ private[spark] object Config extends Logging { val KUBERNETES_NAMESPACE = ConfigBuilder("spark.kubernetes.namespace") - .doc("The namespace that will be used for running the driver and executor pods. When using " + - "spark-submit in cluster mode, this can also be passed to spark-submit via the " + - "--kubernetes-namespace command line argument.") + .doc("The namespace that will be used for running the driver and executor pods.") .stringConf .createWithDefault("default") + val DRIVER_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.driver.docker.image") + .doc("Docker image to use for the driver. Specify this using the standard Docker tag format.") + .stringConf + .createOptional + val EXECUTOR_DOCKER_IMAGE = ConfigBuilder("spark.kubernetes.executor.docker.image") .doc("Docker image to use for the executors. Specify this using the standard Docker tag " + @@ -44,9 +50,9 @@ private[spark] object Config extends Logging { .checkValues(Set("Always", "Never", "IfNotPresent")) .createWithDefault("IfNotPresent") - val APISERVER_AUTH_DRIVER_CONF_PREFIX = + val KUBERNETES_AUTH_DRIVER_CONF_PREFIX = "spark.kubernetes.authenticate.driver" - val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX = + val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = "spark.kubernetes.authenticate.driver.mounted" val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" @@ -55,7 +61,7 @@ private[spark] object Config extends Logging { val CA_CERT_FILE_CONF_SUFFIX = "caCertFile" val KUBERNETES_SERVICE_ACCOUNT_NAME = - ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") + ConfigBuilder(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") .doc("Service account that is used when running the driver pod. The driver pod uses " + "this service account when requesting executor pods from the API server. If specific " + "credentials are given for the driver pod to use, the driver will favor " + @@ -63,19 +69,17 @@ private[spark] object Config extends Logging { .stringConf .createOptional - // Note that while we set a default for this when we start up the - // scheduler, the specific default value is dynamically determined - // based on the executor memory. - val KUBERNETES_EXECUTOR_MEMORY_OVERHEAD = - ConfigBuilder("spark.kubernetes.executor.memoryOverhead") - .doc("The amount of off-heap memory (in megabytes) to be allocated per executor. This " + - "is memory that accounts for things like VM overheads, interned strings, other native " + - "overheads, etc. This tends to grow with the executor size. 
(typically 6-10%).") - .bytesConf(ByteUnit.MiB) + val KUBERNETES_DRIVER_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.driver.limit.cores") + .doc("Specify the hard cpu limit for the driver pod") + .stringConf .createOptional - val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." - val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + val KUBERNETES_EXECUTOR_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.executor.limit.cores") + .doc("Specify the hard cpu limit for each executor pod") + .stringConf + .createOptional val KUBERNETES_DRIVER_POD_NAME = ConfigBuilder("spark.kubernetes.driver.pod.name") @@ -104,12 +108,6 @@ private[spark] object Config extends Logging { .checkValue(value => value > 0, "Allocation batch delay should be a positive integer") .createWithDefault(1) - val KUBERNETES_EXECUTOR_LIMIT_CORES = - ConfigBuilder("spark.kubernetes.executor.limit.cores") - .doc("Specify the hard cpu limit for a single executor pod") - .stringConf - .createOptional - val KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS = ConfigBuilder("spark.kubernetes.executor.lostCheck.maxAttempts") .doc("Maximum number of attempts allowed for checking the reason of an executor loss " + @@ -119,5 +117,46 @@ private[spark] object Config extends Logging { "must be a positive integer") .createWithDefault(10) + val WAIT_FOR_APP_COMPLETION = + ConfigBuilder("spark.kubernetes.submission.waitAppCompletion") + .doc("In cluster mode, whether to wait for the application to finish before exiting the " + + "launcher process.") + .booleanConf + .createWithDefault(true) + + val REPORT_INTERVAL = + ConfigBuilder("spark.kubernetes.report.interval") + .doc("Interval between reports of the current app status in cluster mode.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") + .createWithDefaultString("1s") + + val JARS_DOWNLOAD_LOCATION = + ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") + .doc("Location to download jars to in the driver and executors. When using" + + " spark-submit, this directory must be empty and will be mounted as an empty directory" + + " volume on the driver and executor pod.") + .stringConf + .createWithDefault("/var/spark-data/spark-jars") + + val FILES_DOWNLOAD_LOCATION = + ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") + .doc("Location to download files to in the driver and executors. When using" + + " spark-submit, this directory must be empty and will be mounted as an empty directory" + + " volume on the driver and executor pods.") + .stringConf + .createWithDefault("/var/spark-data/spark-files") + + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = + "spark.kubernetes.authenticate.submission" + val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector." + + val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." + val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." + + val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." + val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + + val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." 
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 4ddeefb15a89d..0b91145405d3a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -25,9 +25,30 @@ private[spark] object Constants { val SPARK_POD_DRIVER_ROLE = "driver" val SPARK_POD_EXECUTOR_ROLE = "executor" + // Annotations + val SPARK_APP_NAME_ANNOTATION = "spark-app-name" + + // Credentials secrets + val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = + "/mnt/secrets/spark-kubernetes-credentials" + val DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME = "ca-cert" + val DRIVER_CREDENTIALS_CA_CERT_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME" + val DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME = "client-key" + val DRIVER_CREDENTIALS_CLIENT_KEY_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME" + val DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME = "client-cert" + val DRIVER_CREDENTIALS_CLIENT_CERT_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME" + val DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME = "oauth-token" + val DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME" + val DRIVER_CREDENTIALS_SECRET_VOLUME_NAME = "kubernetes-credentials" + // Default and fixed ports val DEFAULT_DRIVER_PORT = 7078 val DEFAULT_BLOCKMANAGER_PORT = 7079 + val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" val EXECUTOR_PORT_NAME = "executor" @@ -42,9 +63,16 @@ private[spark] object Constants { val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" + val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" + val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" + val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" + val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" + val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" + val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" + val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN_MIB = 384L } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala new file mode 100644 index 0000000000000..c563fc5bfbadf --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import java.util.UUID + +import com.google.common.primitives.Longs + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.ConfigurationUtils +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.steps._ +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.util.SystemClock + +/** + * Constructs the complete list of driver configuration steps to run to deploy the Spark driver. + */ +private[spark] class DriverConfigurationStepsOrchestrator( + namespace: String, + kubernetesAppId: String, + launchTime: Long, + mainAppResource: Option[MainAppResource], + appName: String, + mainClass: String, + appArgs: Array[String], + submissionSparkConf: SparkConf) { + + // The resource name prefix is derived from the Spark application name, making it easy to connect + // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the + // application the user submitted. + private val kubernetesResourceNamePrefix = { + val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "") + s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") + } + + private val dockerImagePullPolicy = submissionSparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val jarsDownloadPath = submissionSparkConf.get(JARS_DOWNLOAD_LOCATION) + private val filesDownloadPath = submissionSparkConf.get(FILES_DOWNLOAD_LOCATION) + + def getAllConfigurationSteps(): Seq[DriverConfigurationStep] = { + val driverCustomLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( + submissionSparkConf, + KUBERNETES_DRIVER_LABEL_PREFIX) + require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + + s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + + val allDriverLabels = driverCustomLabels ++ Map( + SPARK_APP_ID_LABEL -> kubernetesAppId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + + val initialSubmissionStep = new BaseDriverConfigurationStep( + kubernetesAppId, + kubernetesResourceNamePrefix, + allDriverLabels, + dockerImagePullPolicy, + appName, + mainClass, + appArgs, + submissionSparkConf) + + val driverAddressStep = new DriverServiceBootstrapStep( + kubernetesResourceNamePrefix, + allDriverLabels, + submissionSparkConf, + new SystemClock) + + val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( + submissionSparkConf, kubernetesResourceNamePrefix) + + val additionalMainAppJar = if (mainAppResource.nonEmpty) { + val mayBeResource = mainAppResource.get match { + case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE => + Some(resource) + case _ => None + } + mayBeResource + } else { + None + } + + val sparkJars = submissionSparkConf.getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty[String]) ++ + additionalMainAppJar.toSeq 
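As an aside on kubernetesResourceNamePrefix defined at the top of this class: the prefix is derived deterministically from the launch time, so every Kubernetes resource created for one submission shares it. A sketch of the same derivation with assumed inputs (the app name and timestamp are hypothetical):

```scala
import java.util.UUID
import com.google.common.primitives.Longs

val appName = "spark.pi"             // hypothetical spark.app.name
val launchTime = 1512000000000L      // hypothetical submission timestamp
val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "")
val prefix = s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-")
// => "spark-pi-<32 hex chars>"; later steps default the driver pod name to s"$prefix-driver"
//    and the headless driver service to s"$prefix-driver-svc".
```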
+ val sparkFiles = submissionSparkConf.getOption("spark.files") + .map(_.split(",")) + .getOrElse(Array.empty[String]) + + val maybeDependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { + Some(new DependencyResolutionStep( + sparkJars, + sparkFiles, + jarsDownloadPath, + filesDownloadPath)) + } else { + None + } + + Seq( + initialSubmissionStep, + driverAddressStep, + kubernetesCredentialsStep) ++ + maybeDependencyResolutionStep.toSeq + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala new file mode 100644 index 0000000000000..4d17608c602d8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import java.util.{Collections, UUID} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.control.NonFatal + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.KubernetesClient + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkApplication +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Encapsulates arguments to the submission client. 
+ * + * @param mainAppResource the main application resource if any + * @param mainClass the main class of the application to run + * @param driverArgs arguments to the driver + */ +private[spark] case class ClientArguments( + mainAppResource: Option[MainAppResource], + mainClass: String, + driverArgs: Array[String]) + +private[spark] object ClientArguments { + + def fromCommandLineArgs(args: Array[String]): ClientArguments = { + var mainAppResource: Option[MainAppResource] = None + var mainClass: Option[String] = None + val driverArgs = mutable.ArrayBuffer.empty[String] + + args.sliding(2, 2).toList.foreach { + case Array("--primary-java-resource", primaryJavaResource: String) => + mainAppResource = Some(JavaMainAppResource(primaryJavaResource)) + case Array("--main-class", clazz: String) => + mainClass = Some(clazz) + case Array("--arg", arg: String) => + driverArgs += arg + case other => + val invalid = other.mkString(" ") + throw new RuntimeException(s"Unknown arguments: $invalid") + } + + require(mainClass.isDefined, "Main class must be specified via --main-class") + + ClientArguments( + mainAppResource, + mainClass.get, + driverArgs.toArray) + } +} + +/** + * Submits a Spark application to run on Kubernetes by creating the driver pod and starting a + * watcher that monitors and logs the application status. Waits for the application to terminate if + * spark.kubernetes.submission.waitAppCompletion is true. + * + * @param submissionSteps steps that collectively configure the driver + * @param submissionSparkConf the submission client Spark configuration + * @param kubernetesClient the client to talk to the Kubernetes API server + * @param waitForAppCompletion a flag indicating whether the client should wait for the application + * to complete + * @param appName the application name + * @param loggingPodStatusWatcher a watcher that monitors and logs the application status + */ +private[spark] class Client( + submissionSteps: Seq[DriverConfigurationStep], + submissionSparkConf: SparkConf, + kubernetesClient: KubernetesClient, + waitForAppCompletion: Boolean, + appName: String, + loggingPodStatusWatcher: LoggingPodStatusWatcher) extends Logging { + + private val driverJavaOptions = submissionSparkConf.get( + org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) + + /** + * Run command that initializes a DriverSpec that will be updated after each + * DriverConfigurationStep in the sequence that is passed in. The final KubernetesDriverSpec + * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources + */ + def run(): Unit = { + var currentDriverSpec = KubernetesDriverSpec.initialSpec(submissionSparkConf) + // submissionSteps contain steps necessary to take, to resolve varying + // client arguments that are passed in, created by orchestrator + for (nextStep <- submissionSteps) { + currentDriverSpec = nextStep.configureDriver(currentDriverSpec) + } + + val resolvedDriverJavaOpts = currentDriverSpec + .driverSparkConf + // Remove this as the options are instead extracted and set individually below using + // environment variables with prefix SPARK_JAVA_OPT_. 
+ .remove(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) + .getAll + .map { + case (confKey, confValue) => s"-D$confKey=$confValue" + } ++ driverJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val driverJavaOptsEnvs: Seq[EnvVar] = resolvedDriverJavaOpts.zipWithIndex.map { + case (option, index) => + new EnvVarBuilder() + .withName(s"$ENV_JAVA_OPT_PREFIX$index") + .withValue(option) + .build() + } + + val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) + .addAllToEnv(driverJavaOptsEnvs.asJava) + .build() + val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) + .editSpec() + .addToContainers(resolvedDriverContainer) + .endSpec() + .build() + + Utils.tryWithResource( + kubernetesClient + .pods() + .withName(resolvedDriverPod.getMetadata.getName) + .watch(loggingPodStatusWatcher)) { _ => + val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) + try { + if (currentDriverSpec.otherKubernetesResources.nonEmpty) { + val otherKubernetesResources = currentDriverSpec.otherKubernetesResources + addDriverOwnerReference(createdDriverPod, otherKubernetesResources) + kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() + } + } catch { + case NonFatal(e) => + kubernetesClient.pods().delete(createdDriverPod) + throw e + } + + if (waitForAppCompletion) { + logInfo(s"Waiting for application $appName to finish...") + loggingPodStatusWatcher.awaitCompletion() + logInfo(s"Application $appName finished.") + } else { + logInfo(s"Deployed Spark application $appName into Kubernetes.") + } + } + } + + // Add a OwnerReference to the given resources making the driver pod an owner of them so when + // the driver pod is deleted, the resources are garbage collected. + private def addDriverOwnerReference(driverPod: Pod, resources: Seq[HasMetadata]): Unit = { + val driverPodOwnerReference = new OwnerReferenceBuilder() + .withName(driverPod.getMetadata.getName) + .withApiVersion(driverPod.getApiVersion) + .withUid(driverPod.getMetadata.getUid) + .withKind(driverPod.getKind) + .withController(true) + .build() + resources.foreach { resource => + val originalMetadata = resource.getMetadata + originalMetadata.setOwnerReferences(Collections.singletonList(driverPodOwnerReference)) + } + } +} + +/** + * Main class and entry point of application submission in KUBERNETES mode. + */ +private[spark] class KubernetesClientApplication extends SparkApplication { + + override def start(args: Array[String], conf: SparkConf): Unit = { + val parsedArguments = ClientArguments.fromCommandLineArgs(args) + run(parsedArguments, conf) + } + + private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = { + val namespace = sparkConf.get(KUBERNETES_NAMESPACE) + // For constructing the app ID, we can't use the Spark application name, as the app ID is going + // to be added as a label to group resources belonging to the same application. Label values are + // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate + // a unique app ID (captured by spark.app.id) in the format below. + val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" + val launchTime = System.currentTimeMillis() + val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) + val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") + // The master URL has been checked for validity already in SparkSubmit. + // We just need to get rid of the "k8s:" prefix here. 
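For reference, ClientArguments.fromCommandLineArgs above consumes the command line as flag/value pairs via sliding(2, 2). A small sketch of an invocation, where the resource path, class name, and argument are made up:

```scala
import org.apache.spark.deploy.k8s.submit.{ClientArguments, JavaMainAppResource}

val parsed = ClientArguments.fromCommandLineArgs(Array(
  "--primary-java-resource", "local:///opt/spark/examples/jars/spark-examples.jar",
  "--main-class", "org.apache.spark.examples.SparkPi",
  "--arg", "10"))

// parsed.mainAppResource == Some(JavaMainAppResource("local:///opt/spark/examples/jars/spark-examples.jar"))
// parsed.mainClass       == "org.apache.spark.examples.SparkPi"
// parsed.driverArgs      == Array("10")
// Any unrecognized flag/value pair raises a RuntimeException listing the offending arguments.
```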
+ val master = sparkConf.get("spark.master").substring("k8s:".length) + val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None + + val loggingPodStatusWatcher = new LoggingPodStatusWatcherImpl( + kubernetesAppId, loggingInterval) + + val configurationStepsOrchestrator = new DriverConfigurationStepsOrchestrator( + namespace, + kubernetesAppId, + launchTime, + clientArguments.mainAppResource, + appName, + clientArguments.mainClass, + clientArguments.driverArgs, + sparkConf) + + Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( + master, + Some(namespace), + KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX, + sparkConf, + None, + None)) { kubernetesClient => + val client = new Client( + configurationStepsOrchestrator.getAllConfigurationSteps(), + sparkConf, + kubernetesClient, + waitForAppCompletion, + appName, + loggingPodStatusWatcher) + client.run() + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala new file mode 100644 index 0000000000000..db13f09387ef9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder} + +import org.apache.spark.SparkConf + +/** + * Represents the components and characteristics of a Spark driver. The driver can be considered + * as being comprised of the driver pod itself, any other Kubernetes resources that the driver + * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver + * container should be operated on via the specific field of this case class as opposed to trying + * to edit the container directly on the pod. The driver container should be attached at the + * end of executing all submission steps. + */ +private[spark] case class KubernetesDriverSpec( + driverPod: Pod, + driverContainer: Container, + otherKubernetesResources: Seq[HasMetadata], + driverSparkConf: SparkConf) + +private[spark] object KubernetesDriverSpec { + def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = { + KubernetesDriverSpec( + // Set new metadata and a new spec so that submission steps can use + // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely. 
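Because every step consumes and returns an immutable KubernetesDriverSpec, the var-and-loop in Client.run above is equivalent to a left fold over the steps. A sketch of that equivalence, with the conf and step sequence as placeholders:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep

def resolveDriverSpec(
    conf: SparkConf,
    steps: Seq[DriverConfigurationStep]): KubernetesDriverSpec = {
  // Thread the spec through every configuration step, starting from the empty initial spec.
  steps.foldLeft(KubernetesDriverSpec.initialSpec(conf)) { (spec, step) =>
    step.configureDriver(spec)
  }
}
```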
+ new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), + new ContainerBuilder().build(), + Seq.empty[HasMetadata], + initialSparkConf.clone()) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala new file mode 100644 index 0000000000000..a38cf55fc3d58 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import java.io.File + +import org.apache.spark.util.Utils + +private[spark] object KubernetesFileUtils { + + /** + * For the given collection of file URIs, resolves them as follows: + * - File URIs with scheme file:// are resolved to the given download path. + * - File URIs with scheme local:// resolve to just the path of the URI. + * - Otherwise, the URIs are returned as-is. + */ + def resolveFileUris( + fileUris: Iterable[String], + fileDownloadPath: String): Iterable[String] = { + fileUris.map { uri => + resolveFileUri(uri, fileDownloadPath, false) + } + } + + /** + * If any file uri has any scheme other than local:// it is mapped as if the file + * was downloaded to the file download path. Otherwise, it is mapped to the path + * part of the URI. + */ + def resolveFilePaths(fileUris: Iterable[String], fileDownloadPath: String): Iterable[String] = { + fileUris.map { uri => + resolveFileUri(uri, fileDownloadPath, true) + } + } + + private def resolveFileUri( + uri: String, + fileDownloadPath: String, + assumesDownloaded: Boolean): String = { + val fileUri = Utils.resolveURI(uri) + val fileScheme = Option(fileUri.getScheme).getOrElse("file") + fileScheme match { + case "local" => + fileUri.getPath + case _ => + if (assumesDownloaded || fileScheme == "file") { + val fileName = new File(fileUri.getPath).getName + s"$fileDownloadPath/$fileName" + } else { + uri + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala new file mode 100644 index 0000000000000..173ac541626a7 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod, Time} +import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[k8s] trait LoggingPodStatusWatcher extends Watcher[Pod] { + def awaitCompletion(): Unit +} + +/** + * A monitor for the running Kubernetes pod of a Spark application. Status logging occurs on + * every state change and also at an interval for liveness. + * + * @param appId application ID. + * @param maybeLoggingInterval ms between each state request. If provided, must be a positive + * number. + */ +private[k8s] class LoggingPodStatusWatcherImpl( + appId: String, + maybeLoggingInterval: Option[Long]) + extends LoggingPodStatusWatcher with Logging { + + private val podCompletedFuture = new CountDownLatch(1) + // start timer for periodic logging + private val scheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("logging-pod-status-watcher") + private val logRunnable: Runnable = new Runnable { + override def run() = logShortStatus() + } + + private var pod = Option.empty[Pod] + + private def phase: String = pod.map(_.getStatus.getPhase).getOrElse("unknown") + + def start(): Unit = { + maybeLoggingInterval.foreach { interval => + scheduler.scheduleAtFixedRate(logRunnable, 0, interval, TimeUnit.MILLISECONDS) + } + } + + override def eventReceived(action: Action, pod: Pod): Unit = { + this.pod = Option(pod) + action match { + case Action.DELETED | Action.ERROR => + closeWatch() + + case _ => + logLongStatus() + if (hasCompleted()) { + closeWatch() + } + } + } + + override def onClose(e: KubernetesClientException): Unit = { + logDebug(s"Stopping watching application $appId with last-observed phase $phase") + closeWatch() + } + + private def logShortStatus() = { + logInfo(s"Application status for $appId (phase: $phase)") + } + + private def logLongStatus() = { + logInfo("State changed, new state: " + pod.map(formatPodState).getOrElse("unknown")) + } + + private def hasCompleted(): Boolean = { + phase == "Succeeded" || phase == "Failed" + } + + private def closeWatch(): Unit = { + podCompletedFuture.countDown() + scheduler.shutdown() + } + + private def formatPodState(pod: Pod): String = { + val details = Seq[(String, String)]( + // pod metadata + ("pod name", pod.getMetadata.getName), + ("namespace", pod.getMetadata.getNamespace), + ("labels", pod.getMetadata.getLabels.asScala.mkString(", ")), + ("pod uid", pod.getMetadata.getUid), + ("creation time", formatTime(pod.getMetadata.getCreationTimestamp)), + + // spec details + ("service account name", 
pod.getSpec.getServiceAccountName), + ("volumes", pod.getSpec.getVolumes.asScala.map(_.getName).mkString(", ")), + ("node name", pod.getSpec.getNodeName), + + // status + ("start time", formatTime(pod.getStatus.getStartTime)), + ("container images", + pod.getStatus.getContainerStatuses + .asScala + .map(_.getImage) + .mkString(", ")), + ("phase", pod.getStatus.getPhase), + ("status", pod.getStatus.getContainerStatuses.toString) + ) + + formatPairsBundle(details) + } + + private def formatPairsBundle(pairs: Seq[(String, String)]) = { + // Use more loggable format if value is null or empty + pairs.map { + case (k, v) => s"\n\t $k: ${Option(v).filter(_.nonEmpty).getOrElse("N/A")}" + }.mkString("") + } + + override def awaitCompletion(): Unit = { + podCompletedFuture.await() + logInfo(pod.map { p => + s"Container final statuses:\n\n${containersDescription(p)}" + }.getOrElse("No containers were found in the driver pod.")) + } + + private def containersDescription(p: Pod): String = { + p.getStatus.getContainerStatuses.asScala.map { status => + Seq( + ("Container name", status.getName), + ("Container image", status.getImage)) ++ + containerStatusDescription(status) + }.map(formatPairsBundle).mkString("\n\n") + } + + private def containerStatusDescription( + containerStatus: ContainerStatus): Seq[(String, String)] = { + val state = containerStatus.getState + Option(state.getRunning) + .orElse(Option(state.getTerminated)) + .orElse(Option(state.getWaiting)) + .map { + case running: ContainerStateRunning => + Seq( + ("Container state", "Running"), + ("Container started at", formatTime(running.getStartedAt))) + case waiting: ContainerStateWaiting => + Seq( + ("Container state", "Waiting"), + ("Pending reason", waiting.getReason)) + case terminated: ContainerStateTerminated => + Seq( + ("Container state", "Terminated"), + ("Exit code", terminated.getExitCode.toString)) + case unknown => + throw new SparkException(s"Unexpected container status type ${unknown.getClass}.") + }.getOrElse(Seq(("Container state", "N/A"))) + } + + private def formatTime(time: Time): String = { + if (time != null) time.getTime else "N/A" + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala new file mode 100644 index 0000000000000..cca9f4627a1f6 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit + +private[spark] sealed trait MainAppResource + +private[spark] case class JavaMainAppResource(primaryResource: String) extends MainAppResource diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala new file mode 100644 index 0000000000000..ba2a11b9e6689 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, PodBuilder, QuantityBuilder} + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.ConfigurationUtils +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} + +/** + * Represents the initial setup required for the driver. 
+ */ +private[spark] class BaseDriverConfigurationStep( + kubernetesAppId: String, + kubernetesResourceNamePrefix: String, + driverLabels: Map[String, String], + dockerImagePullPolicy: String, + appName: String, + mainClass: String, + appArgs: Array[String], + submissionSparkConf: SparkConf) extends DriverConfigurationStep { + + private val kubernetesDriverPodName = submissionSparkConf.get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"$kubernetesResourceNamePrefix-driver") + + private val driverExtraClasspath = submissionSparkConf.get( + DRIVER_CLASS_PATH) + + private val driverDockerImage = submissionSparkConf + .get(DRIVER_DOCKER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver Docker image")) + + // CPU settings + private val driverCpuCores = submissionSparkConf.getOption("spark.driver.cores").getOrElse("1") + private val driverLimitCores = submissionSparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) + + // Memory settings + private val driverMemoryMiB = submissionSparkConf.get( + DRIVER_MEMORY) + private val driverMemoryString = submissionSparkConf.get( + DRIVER_MEMORY.key, + DRIVER_MEMORY.defaultValueString) + private val memoryOverheadMiB = submissionSparkConf + .get(DRIVER_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val driverContainerMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => + new EnvVarBuilder() + .withName(ENV_SUBMIT_EXTRA_CLASSPATH) + .withValue(classPath) + .build() + } + + val driverCustomAnnotations = ConfigurationUtils + .parsePrefixedKeyValuePairs( + submissionSparkConf, + KUBERNETES_DRIVER_ANNOTATION_PREFIX) + require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), + s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + + " Spark bookkeeping operations.") + + val driverCustomEnvs = submissionSparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq + .map { env => + new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + } + + val allDriverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) + + val nodeSelector = ConfigurationUtils.parsePrefixedKeyValuePairs( + submissionSparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + + val driverCpuQuantity = new QuantityBuilder(false) + .withAmount(driverCpuCores) + .build() + val driverMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${driverMemoryMiB}Mi") + .build() + val driverMemoryLimitQuantity = new QuantityBuilder(false) + .withAmount(s"${driverContainerMemoryWithOverheadMiB}Mi") + .build() + val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => + ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) + } + + val driverContainer = new ContainerBuilder(driverSpec.driverContainer) + .withName(DRIVER_CONTAINER_NAME) + .withImage(driverDockerImage) + .withImagePullPolicy(dockerImagePullPolicy) + .addAllToEnv(driverCustomEnvs.asJava) + .addToEnv(driverExtraClasspathEnv.toSeq: _*) + .addNewEnv() + .withName(ENV_DRIVER_MEMORY) + .withValue(driverMemoryString) + .endEnv() + .addNewEnv() + .withName(ENV_DRIVER_MAIN_CLASS) + .withValue(mainClass) + .endEnv() + .addNewEnv() + .withName(ENV_DRIVER_ARGS) + .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .endEnv() + .addNewEnv() + .withName(ENV_DRIVER_BIND_ADDRESS) + 
.withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .endEnv() + .withNewResources() + .addToRequests("cpu", driverCpuQuantity) + .addToRequests("memory", driverMemoryQuantity) + .addToLimits("memory", driverMemoryLimitQuantity) + .addToLimits(maybeCpuLimitQuantity.toMap.asJava) + .endResources() + .build() + + val baseDriverPod = new PodBuilder(driverSpec.driverPod) + .editOrNewMetadata() + .withName(kubernetesDriverPodName) + .addToLabels(driverLabels.asJava) + .addToAnnotations(allDriverAnnotations.asJava) + .endMetadata() + .withNewSpec() + .withRestartPolicy("Never") + .withNodeSelector(nodeSelector.asJava) + .endSpec() + .build() + + val resolvedSparkConf = driverSpec.driverSparkConf.clone() + .setIfMissing(KUBERNETES_DRIVER_POD_NAME, kubernetesDriverPodName) + .set("spark.app.id", kubernetesAppId) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, kubernetesResourceNamePrefix) + + driverSpec.copy( + driverPod = baseDriverPod, + driverSparkConf = resolvedSparkConf, + driverContainer = driverContainer) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala new file mode 100644 index 0000000000000..44e0ecffc0e93 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.File + +import io.fabric8.kubernetes.api.model.ContainerBuilder + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, KubernetesFileUtils} + +/** + * Step that configures the classpath, spark.jars, and spark.files for the driver given that the + * user may provide remote files or files with local:// schemes. 
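To make the resolution rules concrete, here is a sketch of what this step produces through KubernetesFileUtils for a mixed spark.jars list, assuming the default download directory; the URIs are made up for illustration:

```scala
import org.apache.spark.deploy.k8s.submit.KubernetesFileUtils

val jars = Seq(
  "local:///opt/deps/preinstalled.jar",       // already on the image: keep only the path
  "file:///home/user/app.jar",                // shipped from the client: rewritten to the download dir
  "https://repo.example.com/remote-dep.jar")  // remote: left untouched in spark.jars

KubernetesFileUtils.resolveFileUris(jars, "/var/spark-data/spark-jars")
// => Seq("/opt/deps/preinstalled.jar",
//        "/var/spark-data/spark-jars/app.jar",
//        "https://repo.example.com/remote-dep.jar")

KubernetesFileUtils.resolveFilePaths(jars, "/var/spark-data/spark-jars")
// => every non-local URI is treated as downloaded, so ENV_MOUNTED_CLASSPATH becomes
//    "/opt/deps/preinstalled.jar:/var/spark-data/spark-jars/app.jar:/var/spark-data/spark-jars/remote-dep.jar"
```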
+ */ +private[spark] class DependencyResolutionStep( + sparkJars: Seq[String], + sparkFiles: Seq[String], + jarsDownloadPath: String, + localFilesDownloadPath: String) extends DriverConfigurationStep { + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val resolvedSparkJars = KubernetesFileUtils.resolveFileUris(sparkJars, jarsDownloadPath) + val resolvedSparkFiles = KubernetesFileUtils.resolveFileUris( + sparkFiles, localFilesDownloadPath) + val sparkConfResolvedSparkDependencies = driverSpec.driverSparkConf.clone() + if (resolvedSparkJars.nonEmpty) { + sparkConfResolvedSparkDependencies.set("spark.jars", resolvedSparkJars.mkString(",")) + } + if (resolvedSparkFiles.nonEmpty) { + sparkConfResolvedSparkDependencies.set("spark.files", resolvedSparkFiles.mkString(",")) + } + val resolvedClasspath = KubernetesFileUtils.resolveFilePaths(sparkJars, jarsDownloadPath) + val driverContainerWithResolvedClasspath = if (resolvedClasspath.nonEmpty) { + new ContainerBuilder(driverSpec.driverContainer) + .addNewEnv() + .withName(ENV_MOUNTED_CLASSPATH) + .withValue(resolvedClasspath.mkString(File.pathSeparator)) + .endEnv() + .build() + } else { + driverSpec.driverContainer + } + driverSpec.copy( + driverContainer = driverContainerWithResolvedClasspath, + driverSparkConf = sparkConfResolvedSparkDependencies) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala new file mode 100644 index 0000000000000..c99c0436cf25f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +/** + * Represents a step in preparing the Kubernetes driver. + */ +private[spark] trait DriverConfigurationStep { + + /** + * Apply some transformation to the previous state of the driver to add a new feature to it. 
+ */ + def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala new file mode 100644 index 0000000000000..ccc18908658f1 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions + +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +/** + * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials + * to request executors. 
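In short, credential values given under spark.kubernetes.authenticate.driver.* are base64-encoded into one secret named &lt;resource-name-prefix&gt;-kubernetes-credentials, and the driver-side configuration is rewritten to point at the mounted file paths. A sketch of that mapping for an OAuth token, where the token value is made up:

```scala
import java.nio.charset.StandardCharsets
import com.google.common.io.BaseEncoding

// Submission side: spark.kubernetes.authenticate.driver.oauthToken=<token>
val token = "my-oauth-token"   // hypothetical value
val tokenBase64 = BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8))

// The secret gains the entry "oauth-token" -> tokenBase64 and is mounted at
// /mnt/secrets/spark-kubernetes-credentials, so the driver pod sees
//   spark.kubernetes.authenticate.driver.mounted.oauthTokenFile =
//     /mnt/secrets/spark-kubernetes-credentials/oauth-token
// while the plain oauthToken value is blanked out in the driver's SparkConf.
```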
+ */ +private[spark] class DriverKubernetesCredentialsStep( + submissionSparkConf: SparkConf, + kubernetesResourceNamePrefix: String) extends DriverConfigurationStep { + + private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") + private val maybeMountedClientKeyFile = submissionSparkConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") + private val maybeMountedClientCertFile = submissionSparkConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") + private val maybeMountedCaCertFile = submissionSparkConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") + private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val driverSparkConf = driverSpec.driverSparkConf.clone() + + val oauthTokenBase64 = submissionSparkConf + .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") + .map { token => + BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) + } + val caCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "Driver CA cert file") + val clientKeyDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "Driver client key file") + val clientCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "Driver client cert file") + + val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations( + driverSparkConf, + oauthTokenBase64, + caCertDataBase64, + clientKeyDataBase64, + clientCertDataBase64) + + val kubernetesCredentialsSecret = createCredentialsSecret( + oauthTokenBase64, + caCertDataBase64, + clientKeyDataBase64, + clientCertDataBase64) + + val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret => + new PodBuilder(driverSpec.driverPod) + .editOrNewSpec() + .addNewVolume() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret() + .endVolume() + .endSpec() + .build() + }.getOrElse( + driverServiceAccount.map { account => + new PodBuilder(driverSpec.driverPod) + .editOrNewSpec() + .withServiceAccount(account) + .withServiceAccountName(account) + .endSpec() + .build() + }.getOrElse(driverSpec.driverPod) + ) + + val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { secret => + new ContainerBuilder(driverSpec.driverContainer) + .addNewVolumeMount() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + .endVolumeMount() + .build() + }.getOrElse(driverSpec.driverContainer) + + driverSpec.copy( + driverPod = driverPodWithMountedKubernetesCredentials, + otherKubernetesResources = + driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq, + driverSparkConf = driverSparkConfWithCredentialsLocations, + driverContainer = driverContainerWithMountedSecretVolume) + } + + private def createCredentialsSecret( + driverOAuthTokenBase64: Option[String], + driverCaCertDataBase64: Option[String], + driverClientKeyDataBase64: Option[String], + driverClientCertDataBase64: Option[String]): Option[Secret] = { + val allSecretData = + resolveSecretData( + 
driverClientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ + resolveSecretData( + driverClientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ + resolveSecretData( + driverCaCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ + resolveSecretData( + driverOAuthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) + + if (allSecretData.isEmpty) { + None + } else { + Some(new SecretBuilder() + .withNewMetadata() + .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials") + .endMetadata() + .withData(allSecretData.asJava) + .build()) + } + } + + private def setDriverPodKubernetesCredentialLocations( + driverSparkConf: SparkConf, + driverOauthTokenBase64: Option[String], + driverCaCertDataBase64: Option[String], + driverClientKeyDataBase64: Option[String], + driverClientCertDataBase64: Option[String]): SparkConf = { + val resolvedMountedOAuthTokenFile = resolveSecretLocation( + maybeMountedOAuthTokenFile, + driverOauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) + val resolvedMountedClientKeyFile = resolveSecretLocation( + maybeMountedClientKeyFile, + driverClientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_PATH) + val resolvedMountedClientCertFile = resolveSecretLocation( + maybeMountedClientCertFile, + driverClientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_PATH) + val resolvedMountedCaCertFile = resolveSecretLocation( + maybeMountedCaCertFile, + driverCaCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_PATH) + + val sparkConfWithCredentialLocations = driverSparkConf + .setOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + resolvedMountedCaCertFile) + .setOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + resolvedMountedClientKeyFile) + .setOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + resolvedMountedClientCertFile) + .setOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", + resolvedMountedOAuthTokenFile) + + // Redact all OAuth token values + sparkConfWithCredentialLocations + .getAll + .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1) + .foreach { + sparkConfWithCredentialLocations.set(_, "") + } + sparkConfWithCredentialLocations + } + + private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { + submissionSparkConf.getOption(conf) + .map(new File(_)) + .map { file => + require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", + fileType, file.getAbsolutePath)) + BaseEncoding.base64().encode(Files.toByteArray(file)) + } + } + + private def resolveSecretLocation( + mountedUserSpecified: Option[String], + valueMountedFromSubmitter: Option[String], + mountedCanonicalLocation: String): Option[String] = { + mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => + mountedCanonicalLocation + }) + } + + /** + * Resolve a Kubernetes secret data entry from an optional client credential used by the + * driver to talk to the Kubernetes API server. + * + * @param userSpecifiedCredential the optional user-specified client credential. + * @param secretName name of the Kubernetes secret storing the client credential. + * @return a secret data entry in the form of a map from the secret name to the secret data, + * which may be empty if the user-specified credential is empty. 
+ */ + private def resolveSecretData( + userSpecifiedCredential: Option[String], + secretName: String): Map[String, String] = { + userSpecifiedCredential.map { valueBase64 => + Map(secretName -> valueBase64) + }.getOrElse(Map.empty[String, String]) + } + + private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = { + new OptionSettableSparkConf(sparkConf) + } +} + +private class OptionSettableSparkConf(sparkConf: SparkConf) { + def setOption(configEntry: String, option: Option[String]): SparkConf = { + option.foreach { opt => + sparkConf.set(configEntry, opt) + } + sparkConf + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala new file mode 100644 index 0000000000000..696d11f15ed95 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.ServiceBuilder + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.internal.Logging +import org.apache.spark.util.Clock + +/** + * Allows the driver to be reachable by executor pods through a headless service. The service's + * ports should correspond to the ports that the executor will reach the pod at for RPC. 
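Concretely, for a resource name prefix of spark-pi-&lt;uuid&gt; in namespace spark-jobs, this step creates the headless service spark-pi-&lt;uuid&gt;-driver-svc and points spark.driver.host at its cluster DNS name; if the preferred name would exceed Kubernetes' 63-character limit on service names, a shorter timestamp-based name is used instead. A sketch of the name resolution with assumed inputs:

```scala
// Assumed inputs for illustration.
val kubernetesResourceNamePrefix = "spark-pi-0123456789abcdef0123456789abcdef"
val namespace = "spark-jobs"

val preferred = s"$kubernetesResourceNamePrefix-driver-svc"
val serviceName =
  if (preferred.length <= 63) preferred
  else s"spark-${System.currentTimeMillis()}-driver-svc"  // fallback when the preferred name is too long

val driverHost = s"$serviceName.$namespace.svc.cluster.local"
// spark.driver.host is set to driverHost; the driver's bind address is managed separately
// and set to the pod IP via ENV_DRIVER_BIND_ADDRESS.
```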
+ */ +private[spark] class DriverServiceBootstrapStep( + kubernetesResourceNamePrefix: String, + driverLabels: Map[String, String], + submissionSparkConf: SparkConf, + clock: Clock) extends DriverConfigurationStep with Logging { + import DriverServiceBootstrapStep._ + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + require(submissionSparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + + "address is managed and set to the driver pod's IP address.") + require(submissionSparkConf.getOption(DRIVER_HOST_KEY).isEmpty, + s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + + "managed via a Kubernetes service.") + + val preferredServiceName = s"$kubernetesResourceNamePrefix$DRIVER_SVC_POSTFIX" + val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { + preferredServiceName + } else { + val randomServiceId = clock.getTimeMillis() + val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" + logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + + s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + + s"$shorterServiceName as the driver service's name.") + shorterServiceName + } + + val driverPort = submissionSparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverBlockManagerPort = submissionSparkConf.getInt( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + val driverService = new ServiceBuilder() + .withNewMetadata() + .withName(resolvedServiceName) + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(driverLabels.asJava) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withPort(driverBlockManagerPort) + .withNewTargetPort(driverBlockManagerPort) + .endPort() + .endSpec() + .build() + + val namespace = submissionSparkConf.get(KUBERNETES_NAMESPACE) + val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" + val resolvedSparkConf = driverSpec.driverSparkConf.clone() + .set(DRIVER_HOST_KEY, driverHostname) + .set("spark.driver.port", driverPort.toString) + .set( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort) + + driverSpec.copy( + driverSparkConf = resolvedSparkConf, + otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService)) + } +} + +private[spark] object DriverServiceBootstrapStep { + val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_SVC_POSTFIX = "-driver-svc" + val MAX_SERVICE_NAME_LENGTH = 63 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index f79155b117b67..9d8f3b912c33d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkException} import 
org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} import org.apache.spark.util.Utils /** @@ -46,8 +47,7 @@ private[spark] trait ExecutorPodFactory { private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) extends ExecutorPodFactory { - private val executorExtraClasspath = - sparkConf.get(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH) + private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( sparkConf, @@ -81,13 +81,12 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - private val executorMemoryMiB = sparkConf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY) + private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY) private val executorMemoryString = sparkConf.get( - org.apache.spark.internal.config.EXECUTOR_MEMORY.key, - org.apache.spark.internal.config.EXECUTOR_MEMORY.defaultValueString) + EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) private val memoryOverheadMiB = sparkConf - .get(KUBERNETES_EXECUTOR_MEMORY_OVERHEAD) + .get(EXECUTOR_MEMORY_OVERHEAD) .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB @@ -129,7 +128,7 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) .build() } val executorExtraJavaOptionsEnv = sparkConf - .get(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS) + .get(EXECUTOR_JAVA_OPTIONS) .map { opts => val delimitedOpts = Utils.splitCommandString(opts) delimitedOpts.zipWithIndex.map { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 68ca6a7622171..b8bb152d17910 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -20,7 +20,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory @@ -33,6 +33,10 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { + if (masterURL.startsWith("k8s") && sc.deployMode == "client") { + throw new SparkException("Client mode is currently not supported for Kubernetes.") + } + new TaskSchedulerImpl(sc) } @@ -45,7 +49,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sparkConf.get(KUBERNETES_NAMESPACE)), - 
APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, sparkConf, Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala new file mode 100644 index 0000000000000..bf4ec04893204 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import scala.collection.JavaConverters._ + +import com.google.common.collect.Iterables +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.{KubernetesClient, Watch} +import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{doReturn, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import org.scalatest.mockito.MockitoSugar._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep + +class ClientSuite extends SparkFunSuite with BeforeAndAfter { + + private val DRIVER_POD_UID = "pod-id" + private val DRIVER_POD_API_VERSION = "v1" + private val DRIVER_POD_KIND = "pod" + + private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] + private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: Pods = _ + + @Mock + private var namedPods: PodResource[Pod, DoneablePod] = _ + + @Mock + private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _ + + @Mock + private var resourceList: ResourceList = _ + + private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) + private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ + private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + + createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod]) + createdResourcesArgumentCaptor = 
ArgumentCaptor.forClass(classOf[HasMetadata]) + when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] { + override def answer(invocation: InvocationOnMock): Pod = { + new PodBuilder(invocation.getArgumentAt(0, classOf[Pod])) + .editMetadata() + .withUid(DRIVER_POD_UID) + .endMetadata() + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .build() + } + }) + when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch]) + doReturn(resourceList) + .when(kubernetesClient) + .resourceList(createdResourcesArgumentCaptor.capture()) + } + + test("The client should configure the pod with the submission steps.") { + val submissionClient = new Client( + submissionSteps, + new SparkConf(false), + kubernetesClient, + false, + "spark", + loggingPodStatusWatcher) + submissionClient.run() + val createdPod = createdPodArgumentCaptor.getValue + assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) + assert(createdPod.getMetadata.getLabels.asScala === + Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue)) + assert(createdPod.getMetadata.getAnnotations.asScala === + Map(SecondTestConfigurationStep.annotationKey -> + SecondTestConfigurationStep.annotationValue)) + assert(createdPod.getSpec.getContainers.size() === 1) + assert(createdPod.getSpec.getContainers.get(0).getName === + SecondTestConfigurationStep.containerName) + } + + test("The client should create the secondary Kubernetes resources.") { + val submissionClient = new Client( + submissionSteps, + new SparkConf(false), + kubernetesClient, + false, + "spark", + loggingPodStatusWatcher) + submissionClient.run() + val createdPod = createdPodArgumentCaptor.getValue + val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues + assert(otherCreatedResources.size === 1) + val createdResource = Iterables.getOnlyElement(otherCreatedResources).asInstanceOf[Secret] + assert(createdResource.getMetadata.getName === FirstTestConfigurationStep.secretName) + assert(createdResource.getData.asScala === + Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) + val ownerReference = Iterables.getOnlyElement(createdResource.getMetadata.getOwnerReferences) + assert(ownerReference.getName === createdPod.getMetadata.getName) + assert(ownerReference.getKind === DRIVER_POD_KIND) + assert(ownerReference.getUid === DRIVER_POD_UID) + assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) + } + + test("The client should attach the driver container with the appropriate JVM options.") { + val sparkConf = new SparkConf(false) + .set("spark.logConf", "true") + .set( + org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, + "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails") + val submissionClient = new Client( + submissionSteps, + sparkConf, + kubernetesClient, + false, + "spark", + loggingPodStatusWatcher) + submissionClient.run() + val createdPod = createdPodArgumentCaptor.getValue + val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) + assert(driverContainer.getName === SecondTestConfigurationStep.containerName) + val driverJvmOptsEnvs = driverContainer.getEnv.asScala.filter { env => + env.getName.startsWith(ENV_JAVA_OPT_PREFIX) + }.sortBy(_.getName) + assert(driverJvmOptsEnvs.size === 4) + + val expectedJvmOptsValues = Seq( + "-Dspark.logConf=true", + 
s"-D${SecondTestConfigurationStep.sparkConfKey}=" + + s"${SecondTestConfigurationStep.sparkConfValue}", + "-XX:+HeapDumpOnOutOfMemoryError", + "-XX:+PrintGCDetails") + driverJvmOptsEnvs.zip(expectedJvmOptsValues).zipWithIndex.foreach { + case ((resolvedEnv, expectedJvmOpt), index) => + assert(resolvedEnv.getName === s"$ENV_JAVA_OPT_PREFIX$index") + assert(resolvedEnv.getValue === expectedJvmOpt) + } + } + + test("Waiting for app completion should stall on the watcher") { + val submissionClient = new Client( + submissionSteps, + new SparkConf(false), + kubernetesClient, + true, + "spark", + loggingPodStatusWatcher) + submissionClient.run() + verify(loggingPodStatusWatcher).awaitCompletion() + } + +} + +private object FirstTestConfigurationStep extends DriverConfigurationStep { + + val podName = "test-pod" + val secretName = "test-secret" + val labelKey = "first-submit" + val labelValue = "true" + val secretKey = "secretKey" + val secretData = "secretData" + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val modifiedPod = new PodBuilder(driverSpec.driverPod) + .editMetadata() + .withName(podName) + .addToLabels(labelKey, labelValue) + .endMetadata() + .build() + val additionalResource = new SecretBuilder() + .withNewMetadata() + .withName(secretName) + .endMetadata() + .addToData(secretKey, secretData) + .build() + driverSpec.copy( + driverPod = modifiedPod, + otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource)) + } +} + +private object SecondTestConfigurationStep extends DriverConfigurationStep { + + val annotationKey = "second-submit" + val annotationValue = "submitted" + val sparkConfKey = "spark.custom-conf" + val sparkConfValue = "custom-conf-value" + val containerName = "driverContainer" + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val modifiedPod = new PodBuilder(driverSpec.driverPod) + .editMetadata() + .addToAnnotations(annotationKey, annotationValue) + .endMetadata() + .build() + val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue) + val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer) + .withName(containerName) + .build() + driverSpec.copy( + driverPod = modifiedPod, + driverSparkConf = resolvedSparkConf, + driverContainer = modifiedContainer) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala new file mode 100644 index 0000000000000..c7291d49b465e --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config.DRIVER_DOCKER_IMAGE +import org.apache.spark.deploy.k8s.submit.steps._ + +class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { + + private val NAMESPACE = "default" + private val DRIVER_IMAGE = "driver-image" + private val APP_ID = "spark-app-id" + private val LAUNCH_TIME = 975256L + private val APP_NAME = "spark" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2") + + test("Base submission steps with a main app resource.") { + val sparkConf = new SparkConf(false) + .set(DRIVER_DOCKER_IMAGE, DRIVER_IMAGE) + val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") + val orchestrator = new DriverConfigurationStepsOrchestrator( + NAMESPACE, + APP_ID, + LAUNCH_TIME, + Some(mainAppResource), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BaseDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep], + classOf[DependencyResolutionStep] + ) + } + + test("Base submission steps without a main app resource.") { + val sparkConf = new SparkConf(false) + .set(DRIVER_DOCKER_IMAGE, DRIVER_IMAGE) + val orchestrator = new DriverConfigurationStepsOrchestrator( + NAMESPACE, + APP_ID, + LAUNCH_TIME, + Option.empty, + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BaseDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep] + ) + } + + private def validateStepTypes( + orchestrator: DriverConfigurationStepsOrchestrator, + types: Class[_ <: DriverConfigurationStep]*): Unit = { + val steps = orchestrator.getAllConfigurationSteps() + assert(steps.size === types.size) + assert(steps.map(_.getClass) === types) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala new file mode 100644 index 0000000000000..83c5f98254829 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +class BaseDriverConfigurationStepSuite extends SparkFunSuite { + + private val APP_ID = "spark-app-id" + private val RESOURCE_NAME_PREFIX = "spark" + private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val DOCKER_IMAGE_PULL_POLICY = "IfNotPresent" + private val APP_NAME = "spark-test" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val CUSTOM_ANNOTATION_KEY = "customAnnotation" + private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" + private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" + private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2" + + test("Set all possible configurations from the user.") { + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar") + .set("spark.driver.cores", "2") + .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") + .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(DRIVER_DOCKER_IMAGE, "spark-driver:latest") + .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) + .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") + .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") + + val submissionStep = new BaseDriverConfigurationStep( + APP_ID, + RESOURCE_NAME_PREFIX, + DRIVER_LABELS, + DOCKER_IMAGE_PULL_POLICY, + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build() + val baseDriverSpec = KubernetesDriverSpec( + driverPod = basePod, + driverContainer = new ContainerBuilder().build(), + driverSparkConf = new SparkConf(false), + otherKubernetesResources = Seq.empty[HasMetadata]) + val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec) + + assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) + assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") + assert(preparedDriverSpec.driverContainer.getImagePullPolicy === DOCKER_IMAGE_PULL_POLICY) + + assert(preparedDriverSpec.driverContainer.getEnv.size === 7) + val envs = preparedDriverSpec.driverContainer + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") + assert(envs(ENV_DRIVER_MEMORY) === "256M") + assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) + assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 3\"") + assert(envs(DRIVER_CUSTOM_ENV_KEY1) === 
"customDriverEnv1") + assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") + + assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar => + envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && + envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && + envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) + + val resourceRequirements = preparedDriverSpec.driverContainer.getResources + val requests = resourceRequirements.getRequests.asScala + assert(requests("cpu").getAmount === "2") + assert(requests("memory").getAmount === "256Mi") + val limits = resourceRequirements.getLimits.asScala + assert(limits("memory").getAmount === "456Mi") + assert(limits("cpu").getAmount === "4") + + val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata + assert(driverPodMetadata.getName === "spark-driver-pod") + assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + val expectedAnnotations = Map( + CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, + SPARK_APP_NAME_ANNOTATION -> APP_NAME) + assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) + assert(preparedDriverSpec.driverPod.getSpec.getRestartPolicy === "Never") + + val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX) + assert(resolvedSparkConf === expectedSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala new file mode 100644 index 0000000000000..991b03cafb76c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.File + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +class DependencyResolutionStepSuite extends SparkFunSuite { + + private val SPARK_JARS = Seq( + "hdfs://localhost:9000/apps/jars/jar1.jar", + "file:///home/user/apps/jars/jar2.jar", + "local:///var/apps/jars/jar3.jar") + + private val SPARK_FILES = Seq( + "file:///home/user/apps/files/file1.txt", + "hdfs://localhost:9000/apps/files/file2.txt", + "local:///var/apps/files/file3.txt") + + private val JARS_DOWNLOAD_PATH = "/mnt/spark-data/jars" + private val FILES_DOWNLOAD_PATH = "/mnt/spark-data/files" + + test("Added dependencies should be resolved in Spark configuration and environment") { + val dependencyResolutionStep = new DependencyResolutionStep( + SPARK_JARS, + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH) + val driverPod = new PodBuilder().build() + val baseDriverSpec = KubernetesDriverSpec( + driverPod = driverPod, + driverContainer = new ContainerBuilder().build(), + driverSparkConf = new SparkConf(false), + otherKubernetesResources = Seq.empty[HasMetadata]) + val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec) + assert(preparedDriverSpec.driverPod === driverPod) + assert(preparedDriverSpec.otherKubernetesResources.isEmpty) + val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet + val expectedResolvedSparkJars = Set( + "hdfs://localhost:9000/apps/jars/jar1.jar", + s"$JARS_DOWNLOAD_PATH/jar2.jar", + "/var/apps/jars/jar3.jar") + assert(resolvedSparkJars === expectedResolvedSparkJars) + val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet + val expectedResolvedSparkFiles = Set( + s"$FILES_DOWNLOAD_PATH/file1.txt", + s"hdfs://localhost:9000/apps/files/file2.txt", + s"/var/apps/files/file3.txt") + assert(resolvedSparkFiles === expectedResolvedSparkFiles) + val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala + assert(driverEnv.size === 1) + assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) + val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet + val expectedResolvedDriverClasspath = Set( + s"$JARS_DOWNLOAD_PATH/jar1.jar", + s"$JARS_DOWNLOAD_PATH/jar2.jar", + "/var/apps/jars/jar3.jar") + assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala new file mode 100644 index 0000000000000..64553d25883bb --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.File + +import scala.collection.JavaConverters._ + +import com.google.common.base.Charsets +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.util.Utils + +class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" + private var credentialsTempDirectory: File = _ + private val BASE_DRIVER_SPEC = new KubernetesDriverSpec( + driverPod = new PodBuilder().build(), + driverContainer = new ContainerBuilder().build(), + driverSparkConf = new SparkConf(false), + otherKubernetesResources = Seq.empty[HasMetadata]) + + before { + credentialsTempDirectory = Utils.createTempDir() + } + + after { + credentialsTempDirectory.delete() + } + + test("Don't set any credentials") { + val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( + new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX) + val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) + assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) + assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) + assert(preparedDriverSpec.otherKubernetesResources.isEmpty) + assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty) + } + + test("Only set credentials that are manually mounted.") { + val submissionSparkConf = new SparkConf(false) + .set( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", + "/mnt/secrets/my-token.txt") + .set( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "/mnt/secrets/my-key.pem") + .set( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "/mnt/secrets/my-cert.pem") + .set( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "/mnt/secrets/my-ca.pem") + + val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( + submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) + val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) + assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) + assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) + assert(preparedDriverSpec.otherKubernetesResources.isEmpty) + assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap) + } + + test("Mount credentials from the submission client as a secret.") { + val caCertFile = writeCredentials("ca.pem", "ca-cert") + val clientKeyFile = writeCredentials("key.pem", "key") + val clientCertFile = writeCredentials("cert.pem", "cert") + val submissionSparkConf = new SparkConf(false) + .set( + 
s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX", + "token") + .set( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + clientKeyFile.getAbsolutePath) + .set( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + clientCertFile.getAbsolutePath) + .set( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + caCertFile.getAbsolutePath) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( + submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) + val preparedDriverSpec = kubernetesCredentialsStep.configureDriver( + BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf)) + val expectedSparkConf = Map( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "", + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH, + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + DRIVER_CREDENTIALS_CLIENT_KEY_PATH, + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + DRIVER_CREDENTIALS_CLIENT_CERT_PATH, + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + DRIVER_CREDENTIALS_CA_CERT_PATH, + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + clientKeyFile.getAbsolutePath, + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + clientCertFile.getAbsolutePath, + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + caCertFile.getAbsolutePath) + assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) + assert(preparedDriverSpec.otherKubernetesResources.size === 1) + val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret] + assert(credentialsSecret.getMetadata.getName === + s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") + val decodedSecretData = credentialsSecret.getData.asScala.map { data => + (data._1, new String(BaseEncoding.base64().decode(data._2), Charsets.UTF_8)) + } + val expectedSecretData = Map( + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME -> "ca-cert", + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME -> "token", + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key", + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert") + assert(decodedSecretData === expectedSecretData) + val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala + assert(driverPodVolumes.size === 1) + assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + assert(driverPodVolumes.head.getSecret != null) + assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName) + val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala + assert(driverContainerVolumeMount.size === 1) + assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + } + + private def writeCredentials(credentialsFileName: String, credentialsContents: String): File = { + val credentialsFile = new File(credentialsTempDirectory, credentialsFileName) + Files.write(credentialsContents, credentialsFile, Charsets.UTF_8) + credentialsFile + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala new file mode 100644 index 0000000000000..006ce2668f8a0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.Service +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.util.Clock + +class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SHORT_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length) + + private val LONG_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1) + private val DRIVER_LABELS = Map( + "label1key" -> "label1value", + "label2key" -> "label2value") + + @Mock + private var clock: Clock = _ + + private var sparkConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + } + + test("Headless service has a port for the driver RPC and the block manager.") { + val configurationStep = new DriverServiceBootstrapStep( + SHORT_RESOURCE_NAME_PREFIX, + DRIVER_LABELS, + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080), + clock) + val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) + val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) + assert(resolvedDriverSpec.otherKubernetesResources.size === 1) + assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service]) + val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] + verifyService( + 9000, + 8080, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", + driverService) + } + + test("Hostname and ports are set according to the service name.") { + val configurationStep = new DriverServiceBootstrapStep( + SHORT_RESOURCE_NAME_PREFIX, + DRIVER_LABELS, + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace"), + 
clock) + val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) + val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) + val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + + DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX + val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) + } + + test("Ports should resolve to defaults in SparkConf and in the service.") { + val configurationStep = new DriverServiceBootstrapStep( + SHORT_RESOURCE_NAME_PREFIX, + DRIVER_LABELS, + sparkConf, + clock) + val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) + val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) + verifyService( + DEFAULT_DRIVER_PORT, + DEFAULT_BLOCKMANAGER_PORT, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", + resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]) + assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") === + DEFAULT_DRIVER_PORT.toString) + assert(resolvedDriverSpec.driverSparkConf.get( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT) + } + + test("Long prefixes should switch to using a generated name.") { + val configurationStep = new DriverServiceBootstrapStep( + LONG_RESOURCE_NAME_PREFIX, + DRIVER_LABELS, + sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), + clock) + when(clock.getTimeMillis()).thenReturn(10000) + val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) + val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) + val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] + val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" + assert(driverService.getMetadata.getName === expectedServiceName) + val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) + } + + test("Disallow bind address and driver host to be set explicitly.") { + val configurationStep = new DriverServiceBootstrapStep( + LONG_RESOURCE_NAME_PREFIX, + DRIVER_LABELS, + sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), + clock) + try { + configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) + fail("The driver bind address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + } + sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) + sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") + try { + configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) + fail("The driver host address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") + } + } + + private def verifyService( + driverPort: Int, + blockManagerPort: Int, + expectedServiceName: String, + service: Service): Unit = { + assert(service.getMetadata.getName === 
expectedServiceName) + assert(service.getSpec.getClusterIP === "None") + assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + assert(service.getSpec.getPorts.size() === 2) + val driverServicePorts = service.getSpec.getPorts.asScala + assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) + assert(driverServicePorts.head.getPort.intValue() === driverPort) + assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) + assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) + assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) + assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) + } + + private def verifySparkConfHostNames( + driverSparkConf: SparkConf, expectedHostName: String): Unit = { + assert(driverSparkConf.get( + org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName) + } +} diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile new file mode 100644 index 0000000000000..d16349559466d --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.  See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.  You may obtain a copy of the License at +# +#    http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM spark-base + +# Before building the docker image, first build and make a Spark distribution following +# the instructions in http://spark.apache.org/docs/latest/building-spark.html. +# If this docker file is being used in the context of building your images from a Spark +# distribution, the docker build command should be invoked from the top level directory +# of the Spark distribution. E.g.: +# docker build -t spark-driver:latest -f dockerfiles/driver/Dockerfile . + +COPY examples /opt/spark/examples + +CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ +    env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ +    readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ +    if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ +    if ! 
[ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ +    ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile new file mode 100644 index 0000000000000..0e38169b8efdc --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.  See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.  You may obtain a copy of the License at +# +#    http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM spark-base + +# Before building the docker image, first build and make a Spark distribution following +# the instructions in http://spark.apache.org/docs/latest/building-spark.html. +# If this docker file is being used in the context of building your images from a Spark +# distribution, the docker build command should be invoked from the top level directory +# of the Spark distribution. E.g.: +# docker build -t spark-executor:latest -f dockerfiles/executor/Dockerfile . + +COPY examples /opt/spark/examples + +CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ +    env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ +    readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ +    if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ +    if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ +    ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile new file mode 100644 index 0000000000000..20316c9c5098a --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.  See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +#    http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM openjdk:8-alpine + +# Before building the docker image, first build and make a Spark distribution following +# the instructions in http://spark.apache.org/docs/latest/building-spark.html. +# If this docker file is being used in the context of building your images from a Spark +# distribution, the docker build command should be invoked from the top level directory +# of the Spark distribution. E.g.: +# docker build -t spark-base:latest -f dockerfiles/spark-base/Dockerfile . + +RUN set -ex && \ +    apk upgrade --no-cache && \ +    apk add --no-cache bash tini libc6-compat && \ +    mkdir -p /opt/spark && \ +    mkdir -p /opt/spark/work-dir && \ +    touch /opt/spark/RELEASE && \ +    rm /bin/sh && \ +    ln -sv /bin/bash /bin/sh && \ +    chgrp root /etc/passwd && chmod ug+rw /etc/passwd + +COPY jars /opt/spark/jars +COPY bin /opt/spark/bin +COPY sbin /opt/spark/sbin +COPY conf /opt/spark/conf +COPY dockerfiles/spark-base/entrypoint.sh /opt/ + +ENV SPARK_HOME /opt/spark + +WORKDIR /opt/spark/work-dir + +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh new file mode 100644 index 0000000000000..82559889f4beb --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.  See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.  You may obtain a copy of the License at +# +#    http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# echo commands to the terminal output +set -ex + +# Check whether there is a passwd entry for the container UID +myuid=$(id -u) +mygid=$(id -g) +uidentry=$(getent passwd $myuid) + +# If there is no passwd entry for the container UID, attempt to create one +if [ -z "$uidentry" ] ; then +    if [ -w /etc/passwd ] ; then +        echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd +    else +        echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" +    fi +fi + +# Execute the container CMD under tini for better hygiene +/sbin/tini -s -- "$@" diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index e1af8ba087d6e..3ba3ae5ab4401 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -217,20 +217,12 @@ package object config {     .intConf     .createWithDefault(1) -  private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead") -    .bytesConf(ByteUnit.MiB) -    .createOptional -   /* Executor configuration. */    private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores")     .intConf     .createWithDefault(1) -  private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead") -    .bytesConf(ByteUnit.MiB) -    .createOptional -   private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION =     ConfigBuilder("spark.yarn.executor.nodeLabelExpression")       .doc("Node label expression for executors.") From 3d82f6eb782315b05453b3a0334d3bc05ab4298a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 11 Dec 2017 15:55:23 -0800 Subject: [PATCH 393/712] [SPARK-22726][TEST] Basic tests for Binary Comparison and ImplicitTypeCasts ## What changes were proposed in this pull request? Before we deliver the Hive compatibility mode, we plan to write a set of test cases that can be easily run on both the Spark and Hive sides, so that we can easily compare whether the results are the same. When new typeCoercion rules are added, we can also easily track the changes. These test cases can also be backported to previous Spark versions to determine what behavior has changed. This PR is the first attempt at improving the test coverage for type coercion compatibility. We generate these test cases for our binary comparison and ImplicitTypeCasts rules based on the Apache Derby test cases in https://github.com/apache/derby/blob/10.14/java/testing/org/apache/derbyTesting/functionTests/tests/lang/implicitConversions.sql; a representative slice of the generated cases is shown below. ## How was this patch tested? N/A Author: gatorsmile Closes #19918 from gatorsmile/typeCoercionTests. 
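For illustration, here is a minimal slice of the generated binaryComparison.sql cases (quoted from the new test file added below in this patch):

```sql
-- each type is compared against a string literal and against a NULL string, on both sides
CREATE TEMPORARY VIEW t AS SELECT 1;

SELECT cast(1 as tinyint) = '1' FROM t;
SELECT cast(1 as tinyint) <> cast(null as string) FROM t;
SELECT '2' > cast(1 as tinyint) FROM t;
```

The same pattern is repeated for binary, smallint, int, bigint, decimal, double, float, and date operands, and for every comparison operator.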
--- .../typeCoercion/native/binaryComparison.sql | 287 +++ .../typeCoercion/native/implicitTypeCasts.sql | 72 + .../sql-tests/results/datetime.sql.out | 2 + .../results/predicate-functions.sql.out | 152 +- .../native/binaryComparison.sql.out | 2146 +++++++++++++++++ .../native/implicitTypeCasts.sql.out | 354 +++ .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../apache/spark/sql/TPCDSQuerySuite.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 2 +- 9 files changed, 2940 insertions(+), 79 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/binaryComparison.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/implicitTypeCasts.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/binaryComparison.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/binaryComparison.sql new file mode 100644 index 0000000000000..522322ac480be --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/binaryComparison.sql @@ -0,0 +1,287 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +-- Binary Comparison + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT cast(1 as binary) = '1' FROM t; +SELECT cast(1 as binary) > '2' FROM t; +SELECT cast(1 as binary) >= '2' FROM t; +SELECT cast(1 as binary) < '2' FROM t; +SELECT cast(1 as binary) <= '2' FROM t; +SELECT cast(1 as binary) <> '2' FROM t; +SELECT cast(1 as binary) = cast(null as string) FROM t; +SELECT cast(1 as binary) > cast(null as string) FROM t; +SELECT cast(1 as binary) >= cast(null as string) FROM t; +SELECT cast(1 as binary) < cast(null as string) FROM t; +SELECT cast(1 as binary) <= cast(null as string) FROM t; +SELECT cast(1 as binary) <> cast(null as string) FROM t; +SELECT '1' = cast(1 as binary) FROM t; +SELECT '2' > cast(1 as binary) FROM t; +SELECT '2' >= cast(1 as binary) FROM t; +SELECT '2' < cast(1 as binary) FROM t; +SELECT '2' <= cast(1 as binary) FROM t; +SELECT '2' <> cast(1 as binary) FROM t; +SELECT cast(null as string) = cast(1 as binary) FROM t; +SELECT cast(null as string) > cast(1 as binary) FROM t; +SELECT cast(null as string) >= cast(1 as binary) FROM t; +SELECT cast(null as string) < cast(1 as binary) FROM t; +SELECT cast(null as string) <= cast(1 as binary) FROM t; +SELECT cast(null as string) <> cast(1 as binary) FROM t; +SELECT cast(1 as tinyint) = '1' FROM t; +SELECT cast(1 as tinyint) > '2' FROM t; +SELECT cast(1 as tinyint) >= '2' FROM t; +SELECT cast(1 as tinyint) < '2' FROM t; +SELECT cast(1 as tinyint) <= '2' FROM t; +SELECT cast(1 as tinyint) <> '2' FROM t; +SELECT cast(1 as tinyint) = cast(null as string) FROM t; +SELECT cast(1 as tinyint) > cast(null as string) FROM t; +SELECT cast(1 as tinyint) >= cast(null as string) FROM t; +SELECT cast(1 as tinyint) < cast(null as string) FROM t; +SELECT cast(1 as tinyint) <= cast(null as string) FROM t; +SELECT cast(1 as tinyint) <> cast(null as string) FROM t; +SELECT '1' = cast(1 as tinyint) FROM t; +SELECT '2' > cast(1 as tinyint) FROM t; +SELECT '2' >= cast(1 as tinyint) FROM t; +SELECT '2' < cast(1 as tinyint) FROM t; +SELECT '2' <= cast(1 as tinyint) FROM t; +SELECT '2' <> cast(1 as tinyint) FROM t; +SELECT cast(null as string) = cast(1 as tinyint) FROM t; +SELECT cast(null as string) > cast(1 as tinyint) FROM t; +SELECT cast(null as string) >= cast(1 as tinyint) FROM t; +SELECT cast(null as string) < cast(1 as tinyint) FROM t; +SELECT cast(null as string) <= cast(1 as tinyint) FROM t; +SELECT cast(null as string) <> cast(1 as tinyint) FROM t; +SELECT cast(1 as smallint) = '1' FROM t; +SELECT cast(1 as smallint) > '2' FROM t; +SELECT cast(1 as smallint) >= '2' FROM t; +SELECT cast(1 as smallint) < '2' FROM t; +SELECT cast(1 as smallint) <= '2' FROM t; +SELECT cast(1 as smallint) <> '2' FROM t; +SELECT cast(1 as smallint) = cast(null as string) FROM t; +SELECT cast(1 as smallint) > cast(null as string) FROM t; +SELECT cast(1 as smallint) >= cast(null as string) FROM t; +SELECT cast(1 as smallint) < cast(null as string) FROM t; +SELECT cast(1 as smallint) <= cast(null as string) FROM t; +SELECT cast(1 as smallint) <> cast(null as string) FROM t; +SELECT '1' = cast(1 as smallint) FROM t; +SELECT '2' > cast(1 as smallint) FROM t; +SELECT '2' >= cast(1 as smallint) FROM t; +SELECT '2' < cast(1 as smallint) FROM t; +SELECT '2' <= cast(1 as smallint) FROM t; +SELECT '2' <> cast(1 as smallint) FROM t; +SELECT cast(null as string) = cast(1 as smallint) FROM t; +SELECT cast(null as string) > cast(1 as smallint) FROM t; +SELECT cast(null as string) >= cast(1 as smallint) FROM t; +SELECT cast(null as string) < cast(1 as smallint) FROM t; 
+SELECT cast(null as string) <= cast(1 as smallint) FROM t; +SELECT cast(null as string) <> cast(1 as smallint) FROM t; +SELECT cast(1 as int) = '1' FROM t; +SELECT cast(1 as int) > '2' FROM t; +SELECT cast(1 as int) >= '2' FROM t; +SELECT cast(1 as int) < '2' FROM t; +SELECT cast(1 as int) <= '2' FROM t; +SELECT cast(1 as int) <> '2' FROM t; +SELECT cast(1 as int) = cast(null as string) FROM t; +SELECT cast(1 as int) > cast(null as string) FROM t; +SELECT cast(1 as int) >= cast(null as string) FROM t; +SELECT cast(1 as int) < cast(null as string) FROM t; +SELECT cast(1 as int) <= cast(null as string) FROM t; +SELECT cast(1 as int) <> cast(null as string) FROM t; +SELECT '1' = cast(1 as int) FROM t; +SELECT '2' > cast(1 as int) FROM t; +SELECT '2' >= cast(1 as int) FROM t; +SELECT '2' < cast(1 as int) FROM t; +SELECT '2' <> cast(1 as int) FROM t; +SELECT '2' <= cast(1 as int) FROM t; +SELECT cast(null as string) = cast(1 as int) FROM t; +SELECT cast(null as string) > cast(1 as int) FROM t; +SELECT cast(null as string) >= cast(1 as int) FROM t; +SELECT cast(null as string) < cast(1 as int) FROM t; +SELECT cast(null as string) <> cast(1 as int) FROM t; +SELECT cast(null as string) <= cast(1 as int) FROM t; +SELECT cast(1 as bigint) = '1' FROM t; +SELECT cast(1 as bigint) > '2' FROM t; +SELECT cast(1 as bigint) >= '2' FROM t; +SELECT cast(1 as bigint) < '2' FROM t; +SELECT cast(1 as bigint) <= '2' FROM t; +SELECT cast(1 as bigint) <> '2' FROM t; +SELECT cast(1 as bigint) = cast(null as string) FROM t; +SELECT cast(1 as bigint) > cast(null as string) FROM t; +SELECT cast(1 as bigint) >= cast(null as string) FROM t; +SELECT cast(1 as bigint) < cast(null as string) FROM t; +SELECT cast(1 as bigint) <= cast(null as string) FROM t; +SELECT cast(1 as bigint) <> cast(null as string) FROM t; +SELECT '1' = cast(1 as bigint) FROM t; +SELECT '2' > cast(1 as bigint) FROM t; +SELECT '2' >= cast(1 as bigint) FROM t; +SELECT '2' < cast(1 as bigint) FROM t; +SELECT '2' <= cast(1 as bigint) FROM t; +SELECT '2' <> cast(1 as bigint) FROM t; +SELECT cast(null as string) = cast(1 as bigint) FROM t; +SELECT cast(null as string) > cast(1 as bigint) FROM t; +SELECT cast(null as string) >= cast(1 as bigint) FROM t; +SELECT cast(null as string) < cast(1 as bigint) FROM t; +SELECT cast(null as string) <= cast(1 as bigint) FROM t; +SELECT cast(null as string) <> cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) = '1' FROM t; +SELECT cast(1 as decimal(10, 0)) > '2' FROM t; +SELECT cast(1 as decimal(10, 0)) >= '2' FROM t; +SELECT cast(1 as decimal(10, 0)) < '2' FROM t; +SELECT cast(1 as decimal(10, 0)) <> '2' FROM t; +SELECT cast(1 as decimal(10, 0)) <= '2' FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(null as string) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(null as string) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(null as string) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(null as string) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(null as string) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(null as string) FROM t; +SELECT '1' = cast(1 as decimal(10, 0)) FROM t; +SELECT '2' > cast(1 as decimal(10, 0)) FROM t; +SELECT '2' >= cast(1 as decimal(10, 0)) FROM t; +SELECT '2' < cast(1 as decimal(10, 0)) FROM t; +SELECT '2' <= cast(1 as decimal(10, 0)) FROM t; +SELECT '2' <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(null as string) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(null as string) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(null as string) >= cast(1 as 
decimal(10, 0)) FROM t; +SELECT cast(null as string) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(null as string) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(null as string) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) = '1' FROM t; +SELECT cast(1 as double) > '2' FROM t; +SELECT cast(1 as double) >= '2' FROM t; +SELECT cast(1 as double) < '2' FROM t; +SELECT cast(1 as double) <= '2' FROM t; +SELECT cast(1 as double) <> '2' FROM t; +SELECT cast(1 as double) = cast(null as string) FROM t; +SELECT cast(1 as double) > cast(null as string) FROM t; +SELECT cast(1 as double) >= cast(null as string) FROM t; +SELECT cast(1 as double) < cast(null as string) FROM t; +SELECT cast(1 as double) <= cast(null as string) FROM t; +SELECT cast(1 as double) <> cast(null as string) FROM t; +SELECT '1' = cast(1 as double) FROM t; +SELECT '2' > cast(1 as double) FROM t; +SELECT '2' >= cast(1 as double) FROM t; +SELECT '2' < cast(1 as double) FROM t; +SELECT '2' <= cast(1 as double) FROM t; +SELECT '2' <> cast(1 as double) FROM t; +SELECT cast(null as string) = cast(1 as double) FROM t; +SELECT cast(null as string) > cast(1 as double) FROM t; +SELECT cast(null as string) >= cast(1 as double) FROM t; +SELECT cast(null as string) < cast(1 as double) FROM t; +SELECT cast(null as string) <= cast(1 as double) FROM t; +SELECT cast(null as string) <> cast(1 as double) FROM t; +SELECT cast(1 as float) = '1' FROM t; +SELECT cast(1 as float) > '2' FROM t; +SELECT cast(1 as float) >= '2' FROM t; +SELECT cast(1 as float) < '2' FROM t; +SELECT cast(1 as float) <= '2' FROM t; +SELECT cast(1 as float) <> '2' FROM t; +SELECT cast(1 as float) = cast(null as string) FROM t; +SELECT cast(1 as float) > cast(null as string) FROM t; +SELECT cast(1 as float) >= cast(null as string) FROM t; +SELECT cast(1 as float) < cast(null as string) FROM t; +SELECT cast(1 as float) <= cast(null as string) FROM t; +SELECT cast(1 as float) <> cast(null as string) FROM t; +SELECT '1' = cast(1 as float) FROM t; +SELECT '2' > cast(1 as float) FROM t; +SELECT '2' >= cast(1 as float) FROM t; +SELECT '2' < cast(1 as float) FROM t; +SELECT '2' <= cast(1 as float) FROM t; +SELECT '2' <> cast(1 as float) FROM t; +SELECT cast(null as string) = cast(1 as float) FROM t; +SELECT cast(null as string) > cast(1 as float) FROM t; +SELECT cast(null as string) >= cast(1 as float) FROM t; +SELECT cast(null as string) < cast(1 as float) FROM t; +SELECT cast(null as string) <= cast(1 as float) FROM t; +SELECT cast(null as string) <> cast(1 as float) FROM t; +-- the following queries return 1 if the search condition is satisfied +-- and returns nothing if the search condition is not satisfied +SELECT '1996-09-09' = date('1996-09-09') FROM t; +SELECT '1996-9-10' > date('1996-09-09') FROM t; +SELECT '1996-9-10' >= date('1996-09-09') FROM t; +SELECT '1996-9-10' < date('1996-09-09') FROM t; +SELECT '1996-9-10' <= date('1996-09-09') FROM t; +SELECT '1996-9-10' <> date('1996-09-09') FROM t; +SELECT cast(null as string) = date('1996-09-09') FROM t; +SELECT cast(null as string)> date('1996-09-09') FROM t; +SELECT cast(null as string)>= date('1996-09-09') FROM t; +SELECT cast(null as string)< date('1996-09-09') FROM t; +SELECT cast(null as string)<= date('1996-09-09') FROM t; +SELECT cast(null as string)<> date('1996-09-09') FROM t; +SELECT date('1996-09-09') = '1996-09-09' FROM t; +SELECT date('1996-9-10') > '1996-09-09' FROM t; +SELECT date('1996-9-10') >= '1996-09-09' FROM t; +SELECT date('1996-9-10') < '1996-09-09' FROM t; +SELECT date('1996-9-10') 
<= '1996-09-09' FROM t; +SELECT date('1996-9-10') <> '1996-09-09' FROM t; +SELECT date('1996-09-09') = cast(null as string) FROM t; +SELECT date('1996-9-10') > cast(null as string) FROM t; +SELECT date('1996-9-10') >= cast(null as string) FROM t; +SELECT date('1996-9-10') < cast(null as string) FROM t; +SELECT date('1996-9-10') <= cast(null as string) FROM t; +SELECT date('1996-9-10') <> cast(null as string) FROM t; +SELECT '1996-09-09 12:12:12.4' = timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT '1996-09-09 12:12:12.5' > timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT '1996-09-09 12:12:12.5' >= timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT '1996-09-09 12:12:12.5' < timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT '1996-09-09 12:12:12.5' <= timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT '1996-09-09 12:12:12.5' <> timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT cast(null as string) = timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT cast(null as string) > timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT cast(null as string) >= timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT cast(null as string) < timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT cast(null as string) <= timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT cast(null as string) <> timestamp('1996-09-09 12:12:12.4') FROM t; +SELECT timestamp('1996-09-09 12:12:12.4' )= '1996-09-09 12:12:12.4' FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )> '1996-09-09 12:12:12.4' FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )>= '1996-09-09 12:12:12.4' FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )< '1996-09-09 12:12:12.4' FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )<= '1996-09-09 12:12:12.4' FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )<> '1996-09-09 12:12:12.4' FROM t; +SELECT timestamp('1996-09-09 12:12:12.4' )= cast(null as string) FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )> cast(null as string) FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )>= cast(null as string) FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )< cast(null as string) FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )<= cast(null as string) FROM t; +SELECT timestamp('1996-09-09 12:12:12.5' )<> cast(null as string) FROM t; +SELECT ' ' = X'0020' FROM t; +SELECT ' ' > X'001F' FROM t; +SELECT ' ' >= X'001F' FROM t; +SELECT ' ' < X'001F' FROM t; +SELECT ' ' <= X'001F' FROM t; +SELECT ' ' <> X'001F' FROM t; +SELECT cast(null as string) = X'0020' FROM t; +SELECT cast(null as string) > X'001F' FROM t; +SELECT cast(null as string) >= X'001F' FROM t; +SELECT cast(null as string) < X'001F' FROM t; +SELECT cast(null as string) <= X'001F' FROM t; +SELECT cast(null as string) <> X'001F' FROM t; +SELECT X'0020' = ' ' FROM t; +SELECT X'001F' > ' ' FROM t; +SELECT X'001F' >= ' ' FROM t; +SELECT X'001F' < ' ' FROM t; +SELECT X'001F' <= ' ' FROM t; +SELECT X'001F' <> ' ' FROM t; +SELECT X'0020' = cast(null as string) FROM t; +SELECT X'001F' > cast(null as string) FROM t; +SELECT X'001F' >= cast(null as string) FROM t; +SELECT X'001F' < cast(null as string) FROM t; +SELECT X'001F' <= cast(null as string) FROM t; +SELECT X'001F' <> cast(null as string) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql new file mode 100644 index 0000000000000..58866f4b18112 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql @@ -0,0 +1,72 @@ +-- 
+-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +-- ImplicitTypeCasts + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT 1 + '2' FROM t; +SELECT 1 - '2' FROM t; +SELECT 1 * '2' FROM t; +SELECT 4 / '2' FROM t; +SELECT 1.1 + '2' FROM t; +SELECT 1.1 - '2' FROM t; +SELECT 1.1 * '2' FROM t; +SELECT 4.4 / '2' FROM t; +SELECT 1.1 + '2.2' FROM t; +SELECT 1.1 - '2.2' FROM t; +SELECT 1.1 * '2.2' FROM t; +SELECT 4.4 / '2.2' FROM t; + +-- concatenation +SELECT '$' || cast(1 as smallint) || '$' FROM t; +SELECT '$' || 1 || '$' FROM t; +SELECT '$' || cast(1 as bigint) || '$' FROM t; +SELECT '$' || cast(1.1 as float) || '$' FROM t; +SELECT '$' || cast(1.1 as double) || '$' FROM t; +SELECT '$' || 1.1 || '$' FROM t; +SELECT '$' || cast(1.1 as decimal(8,3)) || '$' FROM t; +SELECT '$' || 'abcd' || '$' FROM t; +SELECT '$' || date('1996-09-09') || '$' FROM t; +SELECT '$' || timestamp('1996-09-09 10:11:12.4' )|| '$' FROM t; + +-- length functions +SELECT length(cast(1 as smallint)) FROM t; +SELECT length(cast(1 as int)) FROM t; +SELECT length(cast(1 as bigint)) FROM t; +SELECT length(cast(1.1 as float)) FROM t; +SELECT length(cast(1.1 as double)) FROM t; +SELECT length(1.1) FROM t; +SELECT length(cast(1.1 as decimal(8,3))) FROM t; +SELECT length('four') FROM t; +SELECT length(date('1996-09-10')) FROM t; +SELECT length(timestamp('1996-09-10 10:11:12.4')) FROM t; + +-- extract +SELECT year( '1996-01-10') FROM t; +SELECT month( '1996-01-10') FROM t; +SELECT day( '1996-01-10') FROM t; +SELECT hour( '10:11:12') FROM t; +SELECT minute( '10:11:12') FROM t; +SELECT second( '10:11:12') FROM t; + +-- like +select 1 like '%' FROM t; +select date('1996-09-10') like '19%' FROM t; +select '1' like 1 FROM t; +select '1 ' like 1 FROM t; +select '1996-09-10' like date('1996-09-10') FROM t; diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 7b2f46f6c2a66..bbb6851e69c7e 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -44,6 +44,7 @@ struct<> -- !query 4 output + -- !query 5 select current_date, current_timestamp from ttf1 -- !query 5 schema @@ -63,6 +64,7 @@ struct<> -- !query 6 output + -- !query 7 select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2 -- !query 7 schema diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index 8cd0d51da64f5..d51f6d37e4b41 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- 
Automatically generated by SQLQueryTestSuite --- Number of queries: 31 +-- Number of queries: 32 -- !query 0 @@ -34,225 +34,225 @@ struct<(CAST(1.5 AS DOUBLE) = CAST(1.51 AS DOUBLE)):boolean> false --- !query 3 -select 1 > '1' --- !query 3 schema -struct<(1 > CAST(1 AS INT)):boolean> --- !query 3 output -false - - -- !query 4 -select 2 > '1.0' +select 1 > '1' -- !query 4 schema -struct<(2 > CAST(1.0 AS INT)):boolean> +struct<(1 > CAST(1 AS INT)):boolean> -- !query 4 output -true +false -- !query 5 -select 2 > '2.0' +select 2 > '1.0' -- !query 5 schema -struct<(2 > CAST(2.0 AS INT)):boolean> +struct<(2 > CAST(1.0 AS INT)):boolean> -- !query 5 output -false +true -- !query 6 -select 2 > '2.2' +select 2 > '2.0' -- !query 6 schema -struct<(2 > CAST(2.2 AS INT)):boolean> +struct<(2 > CAST(2.0 AS INT)):boolean> -- !query 6 output false -- !query 7 -select '1.5' > 0.5 +select 2 > '2.2' -- !query 7 schema -struct<(CAST(1.5 AS DOUBLE) > CAST(0.5 AS DOUBLE)):boolean> +struct<(2 > CAST(2.2 AS INT)):boolean> -- !query 7 output -true +false -- !query 8 -select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') +select '1.5' > 0.5 -- !query 8 schema -struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(1.5 AS DOUBLE) > CAST(0.5 AS DOUBLE)):boolean> -- !query 8 output -false +true -- !query 9 -select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' +select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') -- !query 9 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> +struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> -- !query 9 output false -- !query 10 -select 1 >= '1' +select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' -- !query 10 schema -struct<(1 >= CAST(1 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> -- !query 10 output -true +false -- !query 11 -select 2 >= '1.0' +select 1 >= '1' -- !query 11 schema -struct<(2 >= CAST(1.0 AS INT)):boolean> +struct<(1 >= CAST(1 AS INT)):boolean> -- !query 11 output true -- !query 12 -select 2 >= '2.0' +select 2 >= '1.0' -- !query 12 schema -struct<(2 >= CAST(2.0 AS INT)):boolean> +struct<(2 >= CAST(1.0 AS INT)):boolean> -- !query 12 output true -- !query 13 -select 2.0 >= '2.2' +select 2 >= '2.0' -- !query 13 schema -struct<(CAST(2.0 AS DOUBLE) >= CAST(2.2 AS DOUBLE)):boolean> +struct<(2 >= CAST(2.0 AS INT)):boolean> -- !query 13 output -false +true -- !query 14 -select '1.5' >= 0.5 +select 2.0 >= '2.2' -- !query 14 schema -struct<(CAST(1.5 AS DOUBLE) >= CAST(0.5 AS DOUBLE)):boolean> +struct<(CAST(2.0 AS DOUBLE) >= CAST(2.2 AS DOUBLE)):boolean> -- !query 14 output -true +false -- !query 15 -select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') +select '1.5' >= 0.5 -- !query 15 schema -struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(1.5 AS DOUBLE) >= CAST(0.5 AS DOUBLE)):boolean> -- !query 15 output true -- !query 16 -select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' +select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') -- !query 16 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> +struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> -- !query 16 output -false +true -- !query 17 -select 1 < '1' +select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' -- !query 17 
schema -struct<(1 < CAST(1 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> -- !query 17 output false -- !query 18 -select 2 < '1.0' +select 1 < '1' -- !query 18 schema -struct<(2 < CAST(1.0 AS INT)):boolean> +struct<(1 < CAST(1 AS INT)):boolean> -- !query 18 output false -- !query 19 -select 2 < '2.0' +select 2 < '1.0' -- !query 19 schema -struct<(2 < CAST(2.0 AS INT)):boolean> +struct<(2 < CAST(1.0 AS INT)):boolean> -- !query 19 output false -- !query 20 -select 2.0 < '2.2' +select 2 < '2.0' -- !query 20 schema -struct<(CAST(2.0 AS DOUBLE) < CAST(2.2 AS DOUBLE)):boolean> +struct<(2 < CAST(2.0 AS INT)):boolean> -- !query 20 output -true +false -- !query 21 -select 0.5 < '1.5' +select 2.0 < '2.2' -- !query 21 schema -struct<(CAST(0.5 AS DOUBLE) < CAST(1.5 AS DOUBLE)):boolean> +struct<(CAST(2.0 AS DOUBLE) < CAST(2.2 AS DOUBLE)):boolean> -- !query 21 output true -- !query 22 -select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') +select 0.5 < '1.5' -- !query 22 schema -struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(0.5 AS DOUBLE) < CAST(1.5 AS DOUBLE)):boolean> -- !query 22 output -false +true -- !query 23 -select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' +select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') -- !query 23 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> +struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> -- !query 23 output -true +false -- !query 24 -select 1 <= '1' +select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' -- !query 24 schema -struct<(1 <= CAST(1 AS INT)):boolean> +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> -- !query 24 output true -- !query 25 -select 2 <= '1.0' +select 1 <= '1' -- !query 25 schema -struct<(2 <= CAST(1.0 AS INT)):boolean> +struct<(1 <= CAST(1 AS INT)):boolean> -- !query 25 output -false +true -- !query 26 -select 2 <= '2.0' +select 2 <= '1.0' -- !query 26 schema -struct<(2 <= CAST(2.0 AS INT)):boolean> +struct<(2 <= CAST(1.0 AS INT)):boolean> -- !query 26 output -true +false -- !query 27 -select 2.0 <= '2.2' +select 2 <= '2.0' -- !query 27 schema -struct<(CAST(2.0 AS DOUBLE) <= CAST(2.2 AS DOUBLE)):boolean> +struct<(2 <= CAST(2.0 AS INT)):boolean> -- !query 27 output true -- !query 28 -select 0.5 <= '1.5' +select 2.0 <= '2.2' -- !query 28 schema -struct<(CAST(0.5 AS DOUBLE) <= CAST(1.5 AS DOUBLE)):boolean> +struct<(CAST(2.0 AS DOUBLE) <= CAST(2.2 AS DOUBLE)):boolean> -- !query 28 output true -- !query 29 -select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') +select 0.5 <= '1.5' -- !query 29 schema -struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +struct<(CAST(0.5 AS DOUBLE) <= CAST(1.5 AS DOUBLE)):boolean> -- !query 29 output true -- !query 30 -select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') -- !query 30 schema -struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> -- !query 30 output true + + +-- !query 31 +select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +-- !query 31 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +-- !query 31 output +true diff 
--git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out new file mode 100644 index 0000000000000..fe7bde040707c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out @@ -0,0 +1,2146 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 265 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT cast(1 as binary) = '1' FROM t +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 2 +SELECT cast(1 as binary) > '2' FROM t +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 3 +SELECT cast(1 as binary) >= '2' FROM t +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 4 +SELECT cast(1 as binary) < '2' FROM t +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 5 +SELECT cast(1 as binary) <= '2' FROM t +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 6 +SELECT cast(1 as binary) <> '2' FROM t +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 7 +SELECT cast(1 as binary) = cast(null as string) FROM t +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 8 +SELECT cast(1 as binary) > cast(null as string) FROM t +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 9 +SELECT cast(1 as binary) >= cast(null as string) FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 10 +SELECT cast(1 as binary) < cast(null as string) FROM t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 11 +SELECT cast(1 as binary) <= cast(null as string) FROM t +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 12 +SELECT 
cast(1 as binary) <> cast(null as string) FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 + + +-- !query 13 +SELECT '1' = cast(1 as binary) FROM t +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 + + +-- !query 14 +SELECT '2' > cast(1 as binary) FROM t +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 + + +-- !query 15 +SELECT '2' >= cast(1 as binary) FROM t +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 + + +-- !query 16 +SELECT '2' < cast(1 as binary) FROM t +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 + + +-- !query 17 +SELECT '2' <= cast(1 as binary) FROM t +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 + + +-- !query 18 +SELECT '2' <> cast(1 as binary) FROM t +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 + + +-- !query 19 +SELECT cast(null as string) = cast(1 as binary) FROM t +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 + + +-- !query 20 +SELECT cast(null as string) > cast(1 as binary) FROM t +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 + + +-- !query 21 +SELECT cast(null as string) >= cast(1 as binary) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 + + +-- !query 22 +SELECT cast(null as string) < cast(1 as binary) FROM t +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 + + +-- !query 23 +SELECT cast(null as string) <= cast(1 as binary) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 + + +-- !query 24 +SELECT cast(null as string) <> cast(1 as binary) FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 + + +-- 
!query 25 +SELECT cast(1 as tinyint) = '1' FROM t +-- !query 25 schema +struct<(CAST(1 AS TINYINT) = CAST(1 AS TINYINT)):boolean> +-- !query 25 output +true + + +-- !query 26 +SELECT cast(1 as tinyint) > '2' FROM t +-- !query 26 schema +struct<(CAST(1 AS TINYINT) > CAST(2 AS TINYINT)):boolean> +-- !query 26 output +false + + +-- !query 27 +SELECT cast(1 as tinyint) >= '2' FROM t +-- !query 27 schema +struct<(CAST(1 AS TINYINT) >= CAST(2 AS TINYINT)):boolean> +-- !query 27 output +false + + +-- !query 28 +SELECT cast(1 as tinyint) < '2' FROM t +-- !query 28 schema +struct<(CAST(1 AS TINYINT) < CAST(2 AS TINYINT)):boolean> +-- !query 28 output +true + + +-- !query 29 +SELECT cast(1 as tinyint) <= '2' FROM t +-- !query 29 schema +struct<(CAST(1 AS TINYINT) <= CAST(2 AS TINYINT)):boolean> +-- !query 29 output +true + + +-- !query 30 +SELECT cast(1 as tinyint) <> '2' FROM t +-- !query 30 schema +struct<(NOT (CAST(1 AS TINYINT) = CAST(2 AS TINYINT))):boolean> +-- !query 30 output +true + + +-- !query 31 +SELECT cast(1 as tinyint) = cast(null as string) FROM t +-- !query 31 schema +struct<(CAST(1 AS TINYINT) = CAST(CAST(NULL AS STRING) AS TINYINT)):boolean> +-- !query 31 output +NULL + + +-- !query 32 +SELECT cast(1 as tinyint) > cast(null as string) FROM t +-- !query 32 schema +struct<(CAST(1 AS TINYINT) > CAST(CAST(NULL AS STRING) AS TINYINT)):boolean> +-- !query 32 output +NULL + + +-- !query 33 +SELECT cast(1 as tinyint) >= cast(null as string) FROM t +-- !query 33 schema +struct<(CAST(1 AS TINYINT) >= CAST(CAST(NULL AS STRING) AS TINYINT)):boolean> +-- !query 33 output +NULL + + +-- !query 34 +SELECT cast(1 as tinyint) < cast(null as string) FROM t +-- !query 34 schema +struct<(CAST(1 AS TINYINT) < CAST(CAST(NULL AS STRING) AS TINYINT)):boolean> +-- !query 34 output +NULL + + +-- !query 35 +SELECT cast(1 as tinyint) <= cast(null as string) FROM t +-- !query 35 schema +struct<(CAST(1 AS TINYINT) <= CAST(CAST(NULL AS STRING) AS TINYINT)):boolean> +-- !query 35 output +NULL + + +-- !query 36 +SELECT cast(1 as tinyint) <> cast(null as string) FROM t +-- !query 36 schema +struct<(NOT (CAST(1 AS TINYINT) = CAST(CAST(NULL AS STRING) AS TINYINT))):boolean> +-- !query 36 output +NULL + + +-- !query 37 +SELECT '1' = cast(1 as tinyint) FROM t +-- !query 37 schema +struct<(CAST(1 AS TINYINT) = CAST(1 AS TINYINT)):boolean> +-- !query 37 output +true + + +-- !query 38 +SELECT '2' > cast(1 as tinyint) FROM t +-- !query 38 schema +struct<(CAST(2 AS TINYINT) > CAST(1 AS TINYINT)):boolean> +-- !query 38 output +true + + +-- !query 39 +SELECT '2' >= cast(1 as tinyint) FROM t +-- !query 39 schema +struct<(CAST(2 AS TINYINT) >= CAST(1 AS TINYINT)):boolean> +-- !query 39 output +true + + +-- !query 40 +SELECT '2' < cast(1 as tinyint) FROM t +-- !query 40 schema +struct<(CAST(2 AS TINYINT) < CAST(1 AS TINYINT)):boolean> +-- !query 40 output +false + + +-- !query 41 +SELECT '2' <= cast(1 as tinyint) FROM t +-- !query 41 schema +struct<(CAST(2 AS TINYINT) <= CAST(1 AS TINYINT)):boolean> +-- !query 41 output +false + + +-- !query 42 +SELECT '2' <> cast(1 as tinyint) FROM t +-- !query 42 schema +struct<(NOT (CAST(2 AS TINYINT) = CAST(1 AS TINYINT))):boolean> +-- !query 42 output +true + + +-- !query 43 +SELECT cast(null as string) = cast(1 as tinyint) FROM t +-- !query 43 schema +struct<(CAST(CAST(NULL AS STRING) AS TINYINT) = CAST(1 AS TINYINT)):boolean> +-- !query 43 output +NULL + + +-- !query 44 +SELECT cast(null as string) > cast(1 as tinyint) FROM t +-- !query 44 schema +struct<(CAST(CAST(NULL AS STRING) AS 
TINYINT) > CAST(1 AS TINYINT)):boolean> +-- !query 44 output +NULL + + +-- !query 45 +SELECT cast(null as string) >= cast(1 as tinyint) FROM t +-- !query 45 schema +struct<(CAST(CAST(NULL AS STRING) AS TINYINT) >= CAST(1 AS TINYINT)):boolean> +-- !query 45 output +NULL + + +-- !query 46 +SELECT cast(null as string) < cast(1 as tinyint) FROM t +-- !query 46 schema +struct<(CAST(CAST(NULL AS STRING) AS TINYINT) < CAST(1 AS TINYINT)):boolean> +-- !query 46 output +NULL + + +-- !query 47 +SELECT cast(null as string) <= cast(1 as tinyint) FROM t +-- !query 47 schema +struct<(CAST(CAST(NULL AS STRING) AS TINYINT) <= CAST(1 AS TINYINT)):boolean> +-- !query 47 output +NULL + + +-- !query 48 +SELECT cast(null as string) <> cast(1 as tinyint) FROM t +-- !query 48 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS TINYINT) = CAST(1 AS TINYINT))):boolean> +-- !query 48 output +NULL + + +-- !query 49 +SELECT cast(1 as smallint) = '1' FROM t +-- !query 49 schema +struct<(CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT)):boolean> +-- !query 49 output +true + + +-- !query 50 +SELECT cast(1 as smallint) > '2' FROM t +-- !query 50 schema +struct<(CAST(1 AS SMALLINT) > CAST(2 AS SMALLINT)):boolean> +-- !query 50 output +false + + +-- !query 51 +SELECT cast(1 as smallint) >= '2' FROM t +-- !query 51 schema +struct<(CAST(1 AS SMALLINT) >= CAST(2 AS SMALLINT)):boolean> +-- !query 51 output +false + + +-- !query 52 +SELECT cast(1 as smallint) < '2' FROM t +-- !query 52 schema +struct<(CAST(1 AS SMALLINT) < CAST(2 AS SMALLINT)):boolean> +-- !query 52 output +true + + +-- !query 53 +SELECT cast(1 as smallint) <= '2' FROM t +-- !query 53 schema +struct<(CAST(1 AS SMALLINT) <= CAST(2 AS SMALLINT)):boolean> +-- !query 53 output +true + + +-- !query 54 +SELECT cast(1 as smallint) <> '2' FROM t +-- !query 54 schema +struct<(NOT (CAST(1 AS SMALLINT) = CAST(2 AS SMALLINT))):boolean> +-- !query 54 output +true + + +-- !query 55 +SELECT cast(1 as smallint) = cast(null as string) FROM t +-- !query 55 schema +struct<(CAST(1 AS SMALLINT) = CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean> +-- !query 55 output +NULL + + +-- !query 56 +SELECT cast(1 as smallint) > cast(null as string) FROM t +-- !query 56 schema +struct<(CAST(1 AS SMALLINT) > CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean> +-- !query 56 output +NULL + + +-- !query 57 +SELECT cast(1 as smallint) >= cast(null as string) FROM t +-- !query 57 schema +struct<(CAST(1 AS SMALLINT) >= CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean> +-- !query 57 output +NULL + + +-- !query 58 +SELECT cast(1 as smallint) < cast(null as string) FROM t +-- !query 58 schema +struct<(CAST(1 AS SMALLINT) < CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean> +-- !query 58 output +NULL + + +-- !query 59 +SELECT cast(1 as smallint) <= cast(null as string) FROM t +-- !query 59 schema +struct<(CAST(1 AS SMALLINT) <= CAST(CAST(NULL AS STRING) AS SMALLINT)):boolean> +-- !query 59 output +NULL + + +-- !query 60 +SELECT cast(1 as smallint) <> cast(null as string) FROM t +-- !query 60 schema +struct<(NOT (CAST(1 AS SMALLINT) = CAST(CAST(NULL AS STRING) AS SMALLINT))):boolean> +-- !query 60 output +NULL + + +-- !query 61 +SELECT '1' = cast(1 as smallint) FROM t +-- !query 61 schema +struct<(CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT)):boolean> +-- !query 61 output +true + + +-- !query 62 +SELECT '2' > cast(1 as smallint) FROM t +-- !query 62 schema +struct<(CAST(2 AS SMALLINT) > CAST(1 AS SMALLINT)):boolean> +-- !query 62 output +true + + +-- !query 63 +SELECT '2' >= cast(1 as smallint) FROM t +-- !query 63 schema 
+struct<(CAST(2 AS SMALLINT) >= CAST(1 AS SMALLINT)):boolean> +-- !query 63 output +true + + +-- !query 64 +SELECT '2' < cast(1 as smallint) FROM t +-- !query 64 schema +struct<(CAST(2 AS SMALLINT) < CAST(1 AS SMALLINT)):boolean> +-- !query 64 output +false + + +-- !query 65 +SELECT '2' <= cast(1 as smallint) FROM t +-- !query 65 schema +struct<(CAST(2 AS SMALLINT) <= CAST(1 AS SMALLINT)):boolean> +-- !query 65 output +false + + +-- !query 66 +SELECT '2' <> cast(1 as smallint) FROM t +-- !query 66 schema +struct<(NOT (CAST(2 AS SMALLINT) = CAST(1 AS SMALLINT))):boolean> +-- !query 66 output +true + + +-- !query 67 +SELECT cast(null as string) = cast(1 as smallint) FROM t +-- !query 67 schema +struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) = CAST(1 AS SMALLINT)):boolean> +-- !query 67 output +NULL + + +-- !query 68 +SELECT cast(null as string) > cast(1 as smallint) FROM t +-- !query 68 schema +struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) > CAST(1 AS SMALLINT)):boolean> +-- !query 68 output +NULL + + +-- !query 69 +SELECT cast(null as string) >= cast(1 as smallint) FROM t +-- !query 69 schema +struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) >= CAST(1 AS SMALLINT)):boolean> +-- !query 69 output +NULL + + +-- !query 70 +SELECT cast(null as string) < cast(1 as smallint) FROM t +-- !query 70 schema +struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) < CAST(1 AS SMALLINT)):boolean> +-- !query 70 output +NULL + + +-- !query 71 +SELECT cast(null as string) <= cast(1 as smallint) FROM t +-- !query 71 schema +struct<(CAST(CAST(NULL AS STRING) AS SMALLINT) <= CAST(1 AS SMALLINT)):boolean> +-- !query 71 output +NULL + + +-- !query 72 +SELECT cast(null as string) <> cast(1 as smallint) FROM t +-- !query 72 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS SMALLINT) = CAST(1 AS SMALLINT))):boolean> +-- !query 72 output +NULL + + +-- !query 73 +SELECT cast(1 as int) = '1' FROM t +-- !query 73 schema +struct<(CAST(1 AS INT) = CAST(1 AS INT)):boolean> +-- !query 73 output +true + + +-- !query 74 +SELECT cast(1 as int) > '2' FROM t +-- !query 74 schema +struct<(CAST(1 AS INT) > CAST(2 AS INT)):boolean> +-- !query 74 output +false + + +-- !query 75 +SELECT cast(1 as int) >= '2' FROM t +-- !query 75 schema +struct<(CAST(1 AS INT) >= CAST(2 AS INT)):boolean> +-- !query 75 output +false + + +-- !query 76 +SELECT cast(1 as int) < '2' FROM t +-- !query 76 schema +struct<(CAST(1 AS INT) < CAST(2 AS INT)):boolean> +-- !query 76 output +true + + +-- !query 77 +SELECT cast(1 as int) <= '2' FROM t +-- !query 77 schema +struct<(CAST(1 AS INT) <= CAST(2 AS INT)):boolean> +-- !query 77 output +true + + +-- !query 78 +SELECT cast(1 as int) <> '2' FROM t +-- !query 78 schema +struct<(NOT (CAST(1 AS INT) = CAST(2 AS INT))):boolean> +-- !query 78 output +true + + +-- !query 79 +SELECT cast(1 as int) = cast(null as string) FROM t +-- !query 79 schema +struct<(CAST(1 AS INT) = CAST(CAST(NULL AS STRING) AS INT)):boolean> +-- !query 79 output +NULL + + +-- !query 80 +SELECT cast(1 as int) > cast(null as string) FROM t +-- !query 80 schema +struct<(CAST(1 AS INT) > CAST(CAST(NULL AS STRING) AS INT)):boolean> +-- !query 80 output +NULL + + +-- !query 81 +SELECT cast(1 as int) >= cast(null as string) FROM t +-- !query 81 schema +struct<(CAST(1 AS INT) >= CAST(CAST(NULL AS STRING) AS INT)):boolean> +-- !query 81 output +NULL + + +-- !query 82 +SELECT cast(1 as int) < cast(null as string) FROM t +-- !query 82 schema +struct<(CAST(1 AS INT) < CAST(CAST(NULL AS STRING) AS INT)):boolean> +-- !query 82 output +NULL + + +-- !query 83 +SELECT 
cast(1 as int) <= cast(null as string) FROM t +-- !query 83 schema +struct<(CAST(1 AS INT) <= CAST(CAST(NULL AS STRING) AS INT)):boolean> +-- !query 83 output +NULL + + +-- !query 84 +SELECT cast(1 as int) <> cast(null as string) FROM t +-- !query 84 schema +struct<(NOT (CAST(1 AS INT) = CAST(CAST(NULL AS STRING) AS INT))):boolean> +-- !query 84 output +NULL + + +-- !query 85 +SELECT '1' = cast(1 as int) FROM t +-- !query 85 schema +struct<(CAST(1 AS INT) = CAST(1 AS INT)):boolean> +-- !query 85 output +true + + +-- !query 86 +SELECT '2' > cast(1 as int) FROM t +-- !query 86 schema +struct<(CAST(2 AS INT) > CAST(1 AS INT)):boolean> +-- !query 86 output +true + + +-- !query 87 +SELECT '2' >= cast(1 as int) FROM t +-- !query 87 schema +struct<(CAST(2 AS INT) >= CAST(1 AS INT)):boolean> +-- !query 87 output +true + + +-- !query 88 +SELECT '2' < cast(1 as int) FROM t +-- !query 88 schema +struct<(CAST(2 AS INT) < CAST(1 AS INT)):boolean> +-- !query 88 output +false + + +-- !query 89 +SELECT '2' <> cast(1 as int) FROM t +-- !query 89 schema +struct<(NOT (CAST(2 AS INT) = CAST(1 AS INT))):boolean> +-- !query 89 output +true + + +-- !query 90 +SELECT '2' <= cast(1 as int) FROM t +-- !query 90 schema +struct<(CAST(2 AS INT) <= CAST(1 AS INT)):boolean> +-- !query 90 output +false + + +-- !query 91 +SELECT cast(null as string) = cast(1 as int) FROM t +-- !query 91 schema +struct<(CAST(CAST(NULL AS STRING) AS INT) = CAST(1 AS INT)):boolean> +-- !query 91 output +NULL + + +-- !query 92 +SELECT cast(null as string) > cast(1 as int) FROM t +-- !query 92 schema +struct<(CAST(CAST(NULL AS STRING) AS INT) > CAST(1 AS INT)):boolean> +-- !query 92 output +NULL + + +-- !query 93 +SELECT cast(null as string) >= cast(1 as int) FROM t +-- !query 93 schema +struct<(CAST(CAST(NULL AS STRING) AS INT) >= CAST(1 AS INT)):boolean> +-- !query 93 output +NULL + + +-- !query 94 +SELECT cast(null as string) < cast(1 as int) FROM t +-- !query 94 schema +struct<(CAST(CAST(NULL AS STRING) AS INT) < CAST(1 AS INT)):boolean> +-- !query 94 output +NULL + + +-- !query 95 +SELECT cast(null as string) <> cast(1 as int) FROM t +-- !query 95 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS INT) = CAST(1 AS INT))):boolean> +-- !query 95 output +NULL + + +-- !query 96 +SELECT cast(null as string) <= cast(1 as int) FROM t +-- !query 96 schema +struct<(CAST(CAST(NULL AS STRING) AS INT) <= CAST(1 AS INT)):boolean> +-- !query 96 output +NULL + + +-- !query 97 +SELECT cast(1 as bigint) = '1' FROM t +-- !query 97 schema +struct<(CAST(1 AS BIGINT) = CAST(1 AS BIGINT)):boolean> +-- !query 97 output +true + + +-- !query 98 +SELECT cast(1 as bigint) > '2' FROM t +-- !query 98 schema +struct<(CAST(1 AS BIGINT) > CAST(2 AS BIGINT)):boolean> +-- !query 98 output +false + + +-- !query 99 +SELECT cast(1 as bigint) >= '2' FROM t +-- !query 99 schema +struct<(CAST(1 AS BIGINT) >= CAST(2 AS BIGINT)):boolean> +-- !query 99 output +false + + +-- !query 100 +SELECT cast(1 as bigint) < '2' FROM t +-- !query 100 schema +struct<(CAST(1 AS BIGINT) < CAST(2 AS BIGINT)):boolean> +-- !query 100 output +true + + +-- !query 101 +SELECT cast(1 as bigint) <= '2' FROM t +-- !query 101 schema +struct<(CAST(1 AS BIGINT) <= CAST(2 AS BIGINT)):boolean> +-- !query 101 output +true + + +-- !query 102 +SELECT cast(1 as bigint) <> '2' FROM t +-- !query 102 schema +struct<(NOT (CAST(1 AS BIGINT) = CAST(2 AS BIGINT))):boolean> +-- !query 102 output +true + + +-- !query 103 +SELECT cast(1 as bigint) = cast(null as string) FROM t +-- !query 103 schema +struct<(CAST(1 AS BIGINT) 
= CAST(CAST(NULL AS STRING) AS BIGINT)):boolean> +-- !query 103 output +NULL + + +-- !query 104 +SELECT cast(1 as bigint) > cast(null as string) FROM t +-- !query 104 schema +struct<(CAST(1 AS BIGINT) > CAST(CAST(NULL AS STRING) AS BIGINT)):boolean> +-- !query 104 output +NULL + + +-- !query 105 +SELECT cast(1 as bigint) >= cast(null as string) FROM t +-- !query 105 schema +struct<(CAST(1 AS BIGINT) >= CAST(CAST(NULL AS STRING) AS BIGINT)):boolean> +-- !query 105 output +NULL + + +-- !query 106 +SELECT cast(1 as bigint) < cast(null as string) FROM t +-- !query 106 schema +struct<(CAST(1 AS BIGINT) < CAST(CAST(NULL AS STRING) AS BIGINT)):boolean> +-- !query 106 output +NULL + + +-- !query 107 +SELECT cast(1 as bigint) <= cast(null as string) FROM t +-- !query 107 schema +struct<(CAST(1 AS BIGINT) <= CAST(CAST(NULL AS STRING) AS BIGINT)):boolean> +-- !query 107 output +NULL + + +-- !query 108 +SELECT cast(1 as bigint) <> cast(null as string) FROM t +-- !query 108 schema +struct<(NOT (CAST(1 AS BIGINT) = CAST(CAST(NULL AS STRING) AS BIGINT))):boolean> +-- !query 108 output +NULL + + +-- !query 109 +SELECT '1' = cast(1 as bigint) FROM t +-- !query 109 schema +struct<(CAST(1 AS BIGINT) = CAST(1 AS BIGINT)):boolean> +-- !query 109 output +true + + +-- !query 110 +SELECT '2' > cast(1 as bigint) FROM t +-- !query 110 schema +struct<(CAST(2 AS BIGINT) > CAST(1 AS BIGINT)):boolean> +-- !query 110 output +true + + +-- !query 111 +SELECT '2' >= cast(1 as bigint) FROM t +-- !query 111 schema +struct<(CAST(2 AS BIGINT) >= CAST(1 AS BIGINT)):boolean> +-- !query 111 output +true + + +-- !query 112 +SELECT '2' < cast(1 as bigint) FROM t +-- !query 112 schema +struct<(CAST(2 AS BIGINT) < CAST(1 AS BIGINT)):boolean> +-- !query 112 output +false + + +-- !query 113 +SELECT '2' <= cast(1 as bigint) FROM t +-- !query 113 schema +struct<(CAST(2 AS BIGINT) <= CAST(1 AS BIGINT)):boolean> +-- !query 113 output +false + + +-- !query 114 +SELECT '2' <> cast(1 as bigint) FROM t +-- !query 114 schema +struct<(NOT (CAST(2 AS BIGINT) = CAST(1 AS BIGINT))):boolean> +-- !query 114 output +true + + +-- !query 115 +SELECT cast(null as string) = cast(1 as bigint) FROM t +-- !query 115 schema +struct<(CAST(CAST(NULL AS STRING) AS BIGINT) = CAST(1 AS BIGINT)):boolean> +-- !query 115 output +NULL + + +-- !query 116 +SELECT cast(null as string) > cast(1 as bigint) FROM t +-- !query 116 schema +struct<(CAST(CAST(NULL AS STRING) AS BIGINT) > CAST(1 AS BIGINT)):boolean> +-- !query 116 output +NULL + + +-- !query 117 +SELECT cast(null as string) >= cast(1 as bigint) FROM t +-- !query 117 schema +struct<(CAST(CAST(NULL AS STRING) AS BIGINT) >= CAST(1 AS BIGINT)):boolean> +-- !query 117 output +NULL + + +-- !query 118 +SELECT cast(null as string) < cast(1 as bigint) FROM t +-- !query 118 schema +struct<(CAST(CAST(NULL AS STRING) AS BIGINT) < CAST(1 AS BIGINT)):boolean> +-- !query 118 output +NULL + + +-- !query 119 +SELECT cast(null as string) <= cast(1 as bigint) FROM t +-- !query 119 schema +struct<(CAST(CAST(NULL AS STRING) AS BIGINT) <= CAST(1 AS BIGINT)):boolean> +-- !query 119 output +NULL + + +-- !query 120 +SELECT cast(null as string) <> cast(1 as bigint) FROM t +-- !query 120 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS BIGINT) = CAST(1 AS BIGINT))):boolean> +-- !query 120 output +NULL + + +-- !query 121 +SELECT cast(1 as decimal(10, 0)) = '1' FROM t +-- !query 121 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 121 output +true + + +-- !query 122 +SELECT cast(1 as 
decimal(10, 0)) > '2' FROM t +-- !query 122 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(2 AS DOUBLE)):boolean> +-- !query 122 output +false + + +-- !query 123 +SELECT cast(1 as decimal(10, 0)) >= '2' FROM t +-- !query 123 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(2 AS DOUBLE)):boolean> +-- !query 123 output +false + + +-- !query 124 +SELECT cast(1 as decimal(10, 0)) < '2' FROM t +-- !query 124 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(2 AS DOUBLE)):boolean> +-- !query 124 output +true + + +-- !query 125 +SELECT cast(1 as decimal(10, 0)) <> '2' FROM t +-- !query 125 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(2 AS DOUBLE))):boolean> +-- !query 125 output +true + + +-- !query 126 +SELECT cast(1 as decimal(10, 0)) <= '2' FROM t +-- !query 126 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(2 AS DOUBLE)):boolean> +-- !query 126 output +true + + +-- !query 127 +SELECT cast(1 as decimal(10, 0)) = cast(null as string) FROM t +-- !query 127 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 127 output +NULL + + +-- !query 128 +SELECT cast(1 as decimal(10, 0)) > cast(null as string) FROM t +-- !query 128 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 128 output +NULL + + +-- !query 129 +SELECT cast(1 as decimal(10, 0)) >= cast(null as string) FROM t +-- !query 129 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 129 output +NULL + + +-- !query 130 +SELECT cast(1 as decimal(10, 0)) < cast(null as string) FROM t +-- !query 130 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 130 output +NULL + + +-- !query 131 +SELECT cast(1 as decimal(10, 0)) <> cast(null as string) FROM t +-- !query 131 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(NULL AS STRING) AS DOUBLE))):boolean> +-- !query 131 output +NULL + + +-- !query 132 +SELECT cast(1 as decimal(10, 0)) <= cast(null as string) FROM t +-- !query 132 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 132 output +NULL + + +-- !query 133 +SELECT '1' = cast(1 as decimal(10, 0)) FROM t +-- !query 133 schema +struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 133 output +true + + +-- !query 134 +SELECT '2' > cast(1 as decimal(10, 0)) FROM t +-- !query 134 schema +struct<(CAST(2 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 134 output +true + + +-- !query 135 +SELECT '2' >= cast(1 as decimal(10, 0)) FROM t +-- !query 135 schema +struct<(CAST(2 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 135 output +true + + +-- !query 136 +SELECT '2' < cast(1 as decimal(10, 0)) FROM t +-- !query 136 schema +struct<(CAST(2 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 136 output +false + + +-- !query 137 +SELECT '2' <= cast(1 as decimal(10, 0)) FROM t +-- !query 137 schema +struct<(CAST(2 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 137 output +false + + +-- !query 138 +SELECT '2' <> cast(1 as decimal(10, 0)) FROM t +-- !query 138 schema +struct<(NOT (CAST(2 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 138 output +true + + 
+-- !query 139 +SELECT cast(null as string) = cast(1 as decimal(10, 0)) FROM t +-- !query 139 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 139 output +NULL + + +-- !query 140 +SELECT cast(null as string) > cast(1 as decimal(10, 0)) FROM t +-- !query 140 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 140 output +NULL + + +-- !query 141 +SELECT cast(null as string) >= cast(1 as decimal(10, 0)) FROM t +-- !query 141 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 141 output +NULL + + +-- !query 142 +SELECT cast(null as string) < cast(1 as decimal(10, 0)) FROM t +-- !query 142 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 142 output +NULL + + +-- !query 143 +SELECT cast(null as string) <= cast(1 as decimal(10, 0)) FROM t +-- !query 143 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 143 output +NULL + + +-- !query 144 +SELECT cast(null as string) <> cast(1 as decimal(10, 0)) FROM t +-- !query 144 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 144 output +NULL + + +-- !query 145 +SELECT cast(1 as double) = '1' FROM t +-- !query 145 schema +struct<(CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 145 output +true + + +-- !query 146 +SELECT cast(1 as double) > '2' FROM t +-- !query 146 schema +struct<(CAST(1 AS DOUBLE) > CAST(2 AS DOUBLE)):boolean> +-- !query 146 output +false + + +-- !query 147 +SELECT cast(1 as double) >= '2' FROM t +-- !query 147 schema +struct<(CAST(1 AS DOUBLE) >= CAST(2 AS DOUBLE)):boolean> +-- !query 147 output +false + + +-- !query 148 +SELECT cast(1 as double) < '2' FROM t +-- !query 148 schema +struct<(CAST(1 AS DOUBLE) < CAST(2 AS DOUBLE)):boolean> +-- !query 148 output +true + + +-- !query 149 +SELECT cast(1 as double) <= '2' FROM t +-- !query 149 schema +struct<(CAST(1 AS DOUBLE) <= CAST(2 AS DOUBLE)):boolean> +-- !query 149 output +true + + +-- !query 150 +SELECT cast(1 as double) <> '2' FROM t +-- !query 150 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(2 AS DOUBLE))):boolean> +-- !query 150 output +true + + +-- !query 151 +SELECT cast(1 as double) = cast(null as string) FROM t +-- !query 151 schema +struct<(CAST(1 AS DOUBLE) = CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 151 output +NULL + + +-- !query 152 +SELECT cast(1 as double) > cast(null as string) FROM t +-- !query 152 schema +struct<(CAST(1 AS DOUBLE) > CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 152 output +NULL + + +-- !query 153 +SELECT cast(1 as double) >= cast(null as string) FROM t +-- !query 153 schema +struct<(CAST(1 AS DOUBLE) >= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 153 output +NULL + + +-- !query 154 +SELECT cast(1 as double) < cast(null as string) FROM t +-- !query 154 schema +struct<(CAST(1 AS DOUBLE) < CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 154 output +NULL + + +-- !query 155 +SELECT cast(1 as double) <= cast(null as string) FROM t +-- !query 155 schema +struct<(CAST(1 AS DOUBLE) <= CAST(CAST(NULL AS STRING) AS DOUBLE)):boolean> +-- !query 155 output +NULL + + +-- !query 156 +SELECT cast(1 as double) <> cast(null as string) FROM t +-- !query 156 schema +struct<(NOT (CAST(1 AS DOUBLE) = 
CAST(CAST(NULL AS STRING) AS DOUBLE))):boolean> +-- !query 156 output +NULL + + +-- !query 157 +SELECT '1' = cast(1 as double) FROM t +-- !query 157 schema +struct<(CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 157 output +true + + +-- !query 158 +SELECT '2' > cast(1 as double) FROM t +-- !query 158 schema +struct<(CAST(2 AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 158 output +true + + +-- !query 159 +SELECT '2' >= cast(1 as double) FROM t +-- !query 159 schema +struct<(CAST(2 AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 159 output +true + + +-- !query 160 +SELECT '2' < cast(1 as double) FROM t +-- !query 160 schema +struct<(CAST(2 AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 160 output +false + + +-- !query 161 +SELECT '2' <= cast(1 as double) FROM t +-- !query 161 schema +struct<(CAST(2 AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 161 output +false + + +-- !query 162 +SELECT '2' <> cast(1 as double) FROM t +-- !query 162 schema +struct<(NOT (CAST(2 AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 162 output +true + + +-- !query 163 +SELECT cast(null as string) = cast(1 as double) FROM t +-- !query 163 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 163 output +NULL + + +-- !query 164 +SELECT cast(null as string) > cast(1 as double) FROM t +-- !query 164 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 164 output +NULL + + +-- !query 165 +SELECT cast(null as string) >= cast(1 as double) FROM t +-- !query 165 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 165 output +NULL + + +-- !query 166 +SELECT cast(null as string) < cast(1 as double) FROM t +-- !query 166 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 166 output +NULL + + +-- !query 167 +SELECT cast(null as string) <= cast(1 as double) FROM t +-- !query 167 schema +struct<(CAST(CAST(NULL AS STRING) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 167 output +NULL + + +-- !query 168 +SELECT cast(null as string) <> cast(1 as double) FROM t +-- !query 168 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 168 output +NULL + + +-- !query 169 +SELECT cast(1 as float) = '1' FROM t +-- !query 169 schema +struct<(CAST(1 AS FLOAT) = CAST(1 AS FLOAT)):boolean> +-- !query 169 output +true + + +-- !query 170 +SELECT cast(1 as float) > '2' FROM t +-- !query 170 schema +struct<(CAST(1 AS FLOAT) > CAST(2 AS FLOAT)):boolean> +-- !query 170 output +false + + +-- !query 171 +SELECT cast(1 as float) >= '2' FROM t +-- !query 171 schema +struct<(CAST(1 AS FLOAT) >= CAST(2 AS FLOAT)):boolean> +-- !query 171 output +false + + +-- !query 172 +SELECT cast(1 as float) < '2' FROM t +-- !query 172 schema +struct<(CAST(1 AS FLOAT) < CAST(2 AS FLOAT)):boolean> +-- !query 172 output +true + + +-- !query 173 +SELECT cast(1 as float) <= '2' FROM t +-- !query 173 schema +struct<(CAST(1 AS FLOAT) <= CAST(2 AS FLOAT)):boolean> +-- !query 173 output +true + + +-- !query 174 +SELECT cast(1 as float) <> '2' FROM t +-- !query 174 schema +struct<(NOT (CAST(1 AS FLOAT) = CAST(2 AS FLOAT))):boolean> +-- !query 174 output +true + + +-- !query 175 +SELECT cast(1 as float) = cast(null as string) FROM t +-- !query 175 schema +struct<(CAST(1 AS FLOAT) = CAST(CAST(NULL AS STRING) AS FLOAT)):boolean> +-- !query 175 output +NULL + + +-- !query 176 +SELECT cast(1 as float) > cast(null as string) FROM t 
+-- !query 176 schema +struct<(CAST(1 AS FLOAT) > CAST(CAST(NULL AS STRING) AS FLOAT)):boolean> +-- !query 176 output +NULL + + +-- !query 177 +SELECT cast(1 as float) >= cast(null as string) FROM t +-- !query 177 schema +struct<(CAST(1 AS FLOAT) >= CAST(CAST(NULL AS STRING) AS FLOAT)):boolean> +-- !query 177 output +NULL + + +-- !query 178 +SELECT cast(1 as float) < cast(null as string) FROM t +-- !query 178 schema +struct<(CAST(1 AS FLOAT) < CAST(CAST(NULL AS STRING) AS FLOAT)):boolean> +-- !query 178 output +NULL + + +-- !query 179 +SELECT cast(1 as float) <= cast(null as string) FROM t +-- !query 179 schema +struct<(CAST(1 AS FLOAT) <= CAST(CAST(NULL AS STRING) AS FLOAT)):boolean> +-- !query 179 output +NULL + + +-- !query 180 +SELECT cast(1 as float) <> cast(null as string) FROM t +-- !query 180 schema +struct<(NOT (CAST(1 AS FLOAT) = CAST(CAST(NULL AS STRING) AS FLOAT))):boolean> +-- !query 180 output +NULL + + +-- !query 181 +SELECT '1' = cast(1 as float) FROM t +-- !query 181 schema +struct<(CAST(1 AS FLOAT) = CAST(1 AS FLOAT)):boolean> +-- !query 181 output +true + + +-- !query 182 +SELECT '2' > cast(1 as float) FROM t +-- !query 182 schema +struct<(CAST(2 AS FLOAT) > CAST(1 AS FLOAT)):boolean> +-- !query 182 output +true + + +-- !query 183 +SELECT '2' >= cast(1 as float) FROM t +-- !query 183 schema +struct<(CAST(2 AS FLOAT) >= CAST(1 AS FLOAT)):boolean> +-- !query 183 output +true + + +-- !query 184 +SELECT '2' < cast(1 as float) FROM t +-- !query 184 schema +struct<(CAST(2 AS FLOAT) < CAST(1 AS FLOAT)):boolean> +-- !query 184 output +false + + +-- !query 185 +SELECT '2' <= cast(1 as float) FROM t +-- !query 185 schema +struct<(CAST(2 AS FLOAT) <= CAST(1 AS FLOAT)):boolean> +-- !query 185 output +false + + +-- !query 186 +SELECT '2' <> cast(1 as float) FROM t +-- !query 186 schema +struct<(NOT (CAST(2 AS FLOAT) = CAST(1 AS FLOAT))):boolean> +-- !query 186 output +true + + +-- !query 187 +SELECT cast(null as string) = cast(1 as float) FROM t +-- !query 187 schema +struct<(CAST(CAST(NULL AS STRING) AS FLOAT) = CAST(1 AS FLOAT)):boolean> +-- !query 187 output +NULL + + +-- !query 188 +SELECT cast(null as string) > cast(1 as float) FROM t +-- !query 188 schema +struct<(CAST(CAST(NULL AS STRING) AS FLOAT) > CAST(1 AS FLOAT)):boolean> +-- !query 188 output +NULL + + +-- !query 189 +SELECT cast(null as string) >= cast(1 as float) FROM t +-- !query 189 schema +struct<(CAST(CAST(NULL AS STRING) AS FLOAT) >= CAST(1 AS FLOAT)):boolean> +-- !query 189 output +NULL + + +-- !query 190 +SELECT cast(null as string) < cast(1 as float) FROM t +-- !query 190 schema +struct<(CAST(CAST(NULL AS STRING) AS FLOAT) < CAST(1 AS FLOAT)):boolean> +-- !query 190 output +NULL + + +-- !query 191 +SELECT cast(null as string) <= cast(1 as float) FROM t +-- !query 191 schema +struct<(CAST(CAST(NULL AS STRING) AS FLOAT) <= CAST(1 AS FLOAT)):boolean> +-- !query 191 output +NULL + + +-- !query 192 +SELECT cast(null as string) <> cast(1 as float) FROM t +-- !query 192 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS FLOAT) = CAST(1 AS FLOAT))):boolean> +-- !query 192 output +NULL + + +-- !query 193 +SELECT '1996-09-09' = date('1996-09-09') FROM t +-- !query 193 schema +struct<(1996-09-09 = CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 193 output +true + + +-- !query 194 +SELECT '1996-9-10' > date('1996-09-09') FROM t +-- !query 194 schema +struct<(1996-9-10 > CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 194 output +true + + +-- !query 195 +SELECT '1996-9-10' >= date('1996-09-09') 
FROM t +-- !query 195 schema +struct<(1996-9-10 >= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 195 output +true + + +-- !query 196 +SELECT '1996-9-10' < date('1996-09-09') FROM t +-- !query 196 schema +struct<(1996-9-10 < CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 196 output +false + + +-- !query 197 +SELECT '1996-9-10' <= date('1996-09-09') FROM t +-- !query 197 schema +struct<(1996-9-10 <= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 197 output +false + + +-- !query 198 +SELECT '1996-9-10' <> date('1996-09-09') FROM t +-- !query 198 schema +struct<(NOT (1996-9-10 = CAST(CAST(1996-09-09 AS DATE) AS STRING))):boolean> +-- !query 198 output +true + + +-- !query 199 +SELECT cast(null as string) = date('1996-09-09') FROM t +-- !query 199 schema +struct<(CAST(NULL AS STRING) = CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 199 output +NULL + + +-- !query 200 +SELECT cast(null as string)> date('1996-09-09') FROM t +-- !query 200 schema +struct<(CAST(NULL AS STRING) > CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 200 output +NULL + + +-- !query 201 +SELECT cast(null as string)>= date('1996-09-09') FROM t +-- !query 201 schema +struct<(CAST(NULL AS STRING) >= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 201 output +NULL + + +-- !query 202 +SELECT cast(null as string)< date('1996-09-09') FROM t +-- !query 202 schema +struct<(CAST(NULL AS STRING) < CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 202 output +NULL + + +-- !query 203 +SELECT cast(null as string)<= date('1996-09-09') FROM t +-- !query 203 schema +struct<(CAST(NULL AS STRING) <= CAST(CAST(1996-09-09 AS DATE) AS STRING)):boolean> +-- !query 203 output +NULL + + +-- !query 204 +SELECT cast(null as string)<> date('1996-09-09') FROM t +-- !query 204 schema +struct<(NOT (CAST(NULL AS STRING) = CAST(CAST(1996-09-09 AS DATE) AS STRING))):boolean> +-- !query 204 output +NULL + + +-- !query 205 +SELECT date('1996-09-09') = '1996-09-09' FROM t +-- !query 205 schema +struct<(CAST(CAST(1996-09-09 AS DATE) AS STRING) = 1996-09-09):boolean> +-- !query 205 output +true + + +-- !query 206 +SELECT date('1996-9-10') > '1996-09-09' FROM t +-- !query 206 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) > 1996-09-09):boolean> +-- !query 206 output +true + + +-- !query 207 +SELECT date('1996-9-10') >= '1996-09-09' FROM t +-- !query 207 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) >= 1996-09-09):boolean> +-- !query 207 output +true + + +-- !query 208 +SELECT date('1996-9-10') < '1996-09-09' FROM t +-- !query 208 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) < 1996-09-09):boolean> +-- !query 208 output +false + + +-- !query 209 +SELECT date('1996-9-10') <= '1996-09-09' FROM t +-- !query 209 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) <= 1996-09-09):boolean> +-- !query 209 output +false + + +-- !query 210 +SELECT date('1996-9-10') <> '1996-09-09' FROM t +-- !query 210 schema +struct<(NOT (CAST(CAST(1996-9-10 AS DATE) AS STRING) = 1996-09-09)):boolean> +-- !query 210 output +true + + +-- !query 211 +SELECT date('1996-09-09') = cast(null as string) FROM t +-- !query 211 schema +struct<(CAST(CAST(1996-09-09 AS DATE) AS STRING) = CAST(NULL AS STRING)):boolean> +-- !query 211 output +NULL + + +-- !query 212 +SELECT date('1996-9-10') > cast(null as string) FROM t +-- !query 212 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) > CAST(NULL AS STRING)):boolean> +-- !query 212 output +NULL + + +-- !query 213 
+SELECT date('1996-9-10') >= cast(null as string) FROM t +-- !query 213 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) >= CAST(NULL AS STRING)):boolean> +-- !query 213 output +NULL + + +-- !query 214 +SELECT date('1996-9-10') < cast(null as string) FROM t +-- !query 214 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) < CAST(NULL AS STRING)):boolean> +-- !query 214 output +NULL + + +-- !query 215 +SELECT date('1996-9-10') <= cast(null as string) FROM t +-- !query 215 schema +struct<(CAST(CAST(1996-9-10 AS DATE) AS STRING) <= CAST(NULL AS STRING)):boolean> +-- !query 215 output +NULL + + +-- !query 216 +SELECT date('1996-9-10') <> cast(null as string) FROM t +-- !query 216 schema +struct<(NOT (CAST(CAST(1996-9-10 AS DATE) AS STRING) = CAST(NULL AS STRING))):boolean> +-- !query 216 output +NULL + + +-- !query 217 +SELECT '1996-09-09 12:12:12.4' = timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 217 schema +struct<(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP)):boolean> +-- !query 217 output +true + + +-- !query 218 +SELECT '1996-09-09 12:12:12.5' > timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 218 schema +struct<(1996-09-09 12:12:12.5 > CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 218 output +true + + +-- !query 219 +SELECT '1996-09-09 12:12:12.5' >= timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 219 schema +struct<(1996-09-09 12:12:12.5 >= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 219 output +true + + +-- !query 220 +SELECT '1996-09-09 12:12:12.5' < timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 220 schema +struct<(1996-09-09 12:12:12.5 < CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 220 output +false + + +-- !query 221 +SELECT '1996-09-09 12:12:12.5' <= timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 221 schema +struct<(1996-09-09 12:12:12.5 <= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 221 output +false + + +-- !query 222 +SELECT '1996-09-09 12:12:12.5' <> timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 222 schema +struct<(NOT (CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP))):boolean> +-- !query 222 output +true + + +-- !query 223 +SELECT cast(null as string) = timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 223 schema +struct<(CAST(CAST(NULL AS STRING) AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP)):boolean> +-- !query 223 output +NULL + + +-- !query 224 +SELECT cast(null as string) > timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 224 schema +struct<(CAST(NULL AS STRING) > CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 224 output +NULL + + +-- !query 225 +SELECT cast(null as string) >= timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 225 schema +struct<(CAST(NULL AS STRING) >= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 225 output +NULL + + +-- !query 226 +SELECT cast(null as string) < timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 226 schema +struct<(CAST(NULL AS STRING) < CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 226 output +NULL + + +-- !query 227 +SELECT cast(null as string) <= timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 227 schema +struct<(CAST(NULL AS STRING) <= CAST(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) AS STRING)):boolean> +-- !query 227 output +NULL + + +-- !query 228 +SELECT 
cast(null as string) <> timestamp('1996-09-09 12:12:12.4') FROM t +-- !query 228 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP))):boolean> +-- !query 228 output +NULL + + +-- !query 229 +SELECT timestamp('1996-09-09 12:12:12.4' )= '1996-09-09 12:12:12.4' FROM t +-- !query 229 schema +struct<(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP)):boolean> +-- !query 229 output +true + + +-- !query 230 +SELECT timestamp('1996-09-09 12:12:12.5' )> '1996-09-09 12:12:12.4' FROM t +-- !query 230 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) > 1996-09-09 12:12:12.4):boolean> +-- !query 230 output +true + + +-- !query 231 +SELECT timestamp('1996-09-09 12:12:12.5' )>= '1996-09-09 12:12:12.4' FROM t +-- !query 231 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) >= 1996-09-09 12:12:12.4):boolean> +-- !query 231 output +true + + +-- !query 232 +SELECT timestamp('1996-09-09 12:12:12.5' )< '1996-09-09 12:12:12.4' FROM t +-- !query 232 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) < 1996-09-09 12:12:12.4):boolean> +-- !query 232 output +false + + +-- !query 233 +SELECT timestamp('1996-09-09 12:12:12.5' )<= '1996-09-09 12:12:12.4' FROM t +-- !query 233 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) <= 1996-09-09 12:12:12.4):boolean> +-- !query 233 output +false + + +-- !query 234 +SELECT timestamp('1996-09-09 12:12:12.5' )<> '1996-09-09 12:12:12.4' FROM t +-- !query 234 schema +struct<(NOT (CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) = CAST(1996-09-09 12:12:12.4 AS TIMESTAMP))):boolean> +-- !query 234 output +true + + +-- !query 235 +SELECT timestamp('1996-09-09 12:12:12.4' )= cast(null as string) FROM t +-- !query 235 schema +struct<(CAST(1996-09-09 12:12:12.4 AS TIMESTAMP) = CAST(CAST(NULL AS STRING) AS TIMESTAMP)):boolean> +-- !query 235 output +NULL + + +-- !query 236 +SELECT timestamp('1996-09-09 12:12:12.5' )> cast(null as string) FROM t +-- !query 236 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) > CAST(NULL AS STRING)):boolean> +-- !query 236 output +NULL + + +-- !query 237 +SELECT timestamp('1996-09-09 12:12:12.5' )>= cast(null as string) FROM t +-- !query 237 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) >= CAST(NULL AS STRING)):boolean> +-- !query 237 output +NULL + + +-- !query 238 +SELECT timestamp('1996-09-09 12:12:12.5' )< cast(null as string) FROM t +-- !query 238 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) < CAST(NULL AS STRING)):boolean> +-- !query 238 output +NULL + + +-- !query 239 +SELECT timestamp('1996-09-09 12:12:12.5' )<= cast(null as string) FROM t +-- !query 239 schema +struct<(CAST(CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) AS STRING) <= CAST(NULL AS STRING)):boolean> +-- !query 239 output +NULL + + +-- !query 240 +SELECT timestamp('1996-09-09 12:12:12.5' )<> cast(null as string) FROM t +-- !query 240 schema +struct<(NOT (CAST(1996-09-09 12:12:12.5 AS TIMESTAMP) = CAST(CAST(NULL AS STRING) AS TIMESTAMP))):boolean> +-- !query 240 output +NULL + + +-- !query 241 +SELECT ' ' = X'0020' FROM t +-- !query 241 schema +struct<(CAST( AS BINARY) = X'0020'):boolean> +-- !query 241 output +false + + +-- !query 242 +SELECT ' ' > X'001F' FROM t +-- !query 242 schema +struct<(CAST( AS BINARY) > X'001F'):boolean> +-- !query 242 output +true + + +-- !query 243 +SELECT ' ' >= X'001F' FROM t +-- !query 243 schema +struct<(CAST( AS 
BINARY) >= X'001F'):boolean> +-- !query 243 output +true + + +-- !query 244 +SELECT ' ' < X'001F' FROM t +-- !query 244 schema +struct<(CAST( AS BINARY) < X'001F'):boolean> +-- !query 244 output +false + + +-- !query 245 +SELECT ' ' <= X'001F' FROM t +-- !query 245 schema +struct<(CAST( AS BINARY) <= X'001F'):boolean> +-- !query 245 output +false + + +-- !query 246 +SELECT ' ' <> X'001F' FROM t +-- !query 246 schema +struct<(NOT (CAST( AS BINARY) = X'001F')):boolean> +-- !query 246 output +true + + +-- !query 247 +SELECT cast(null as string) = X'0020' FROM t +-- !query 247 schema +struct<(CAST(CAST(NULL AS STRING) AS BINARY) = X'0020'):boolean> +-- !query 247 output +NULL + + +-- !query 248 +SELECT cast(null as string) > X'001F' FROM t +-- !query 248 schema +struct<(CAST(CAST(NULL AS STRING) AS BINARY) > X'001F'):boolean> +-- !query 248 output +NULL + + +-- !query 249 +SELECT cast(null as string) >= X'001F' FROM t +-- !query 249 schema +struct<(CAST(CAST(NULL AS STRING) AS BINARY) >= X'001F'):boolean> +-- !query 249 output +NULL + + +-- !query 250 +SELECT cast(null as string) < X'001F' FROM t +-- !query 250 schema +struct<(CAST(CAST(NULL AS STRING) AS BINARY) < X'001F'):boolean> +-- !query 250 output +NULL + + +-- !query 251 +SELECT cast(null as string) <= X'001F' FROM t +-- !query 251 schema +struct<(CAST(CAST(NULL AS STRING) AS BINARY) <= X'001F'):boolean> +-- !query 251 output +NULL + + +-- !query 252 +SELECT cast(null as string) <> X'001F' FROM t +-- !query 252 schema +struct<(NOT (CAST(CAST(NULL AS STRING) AS BINARY) = X'001F')):boolean> +-- !query 252 output +NULL + + +-- !query 253 +SELECT X'0020' = ' ' FROM t +-- !query 253 schema +struct<(X'0020' = CAST( AS BINARY)):boolean> +-- !query 253 output +false + + +-- !query 254 +SELECT X'001F' > ' ' FROM t +-- !query 254 schema +struct<(X'001F' > CAST( AS BINARY)):boolean> +-- !query 254 output +false + + +-- !query 255 +SELECT X'001F' >= ' ' FROM t +-- !query 255 schema +struct<(X'001F' >= CAST( AS BINARY)):boolean> +-- !query 255 output +false + + +-- !query 256 +SELECT X'001F' < ' ' FROM t +-- !query 256 schema +struct<(X'001F' < CAST( AS BINARY)):boolean> +-- !query 256 output +true + + +-- !query 257 +SELECT X'001F' <= ' ' FROM t +-- !query 257 schema +struct<(X'001F' <= CAST( AS BINARY)):boolean> +-- !query 257 output +true + + +-- !query 258 +SELECT X'001F' <> ' ' FROM t +-- !query 258 schema +struct<(NOT (X'001F' = CAST( AS BINARY))):boolean> +-- !query 258 output +true + + +-- !query 259 +SELECT X'0020' = cast(null as string) FROM t +-- !query 259 schema +struct<(X'0020' = CAST(CAST(NULL AS STRING) AS BINARY)):boolean> +-- !query 259 output +NULL + + +-- !query 260 +SELECT X'001F' > cast(null as string) FROM t +-- !query 260 schema +struct<(X'001F' > CAST(CAST(NULL AS STRING) AS BINARY)):boolean> +-- !query 260 output +NULL + + +-- !query 261 +SELECT X'001F' >= cast(null as string) FROM t +-- !query 261 schema +struct<(X'001F' >= CAST(CAST(NULL AS STRING) AS BINARY)):boolean> +-- !query 261 output +NULL + + +-- !query 262 +SELECT X'001F' < cast(null as string) FROM t +-- !query 262 schema +struct<(X'001F' < CAST(CAST(NULL AS STRING) AS BINARY)):boolean> +-- !query 262 output +NULL + + +-- !query 263 +SELECT X'001F' <= cast(null as string) FROM t +-- !query 263 schema +struct<(X'001F' <= CAST(CAST(NULL AS STRING) AS BINARY)):boolean> +-- !query 263 output +NULL + + +-- !query 264 +SELECT X'001F' <> cast(null as string) FROM t +-- !query 264 schema +struct<(NOT (X'001F' = CAST(CAST(NULL AS STRING) AS BINARY))):boolean> +-- !query 
264 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/implicitTypeCasts.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/implicitTypeCasts.sql.out new file mode 100644 index 0000000000000..44fa48e2697b3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/implicitTypeCasts.sql.out @@ -0,0 +1,354 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 44 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT 1 + '2' FROM t +-- !query 1 schema +struct<(CAST(1 AS DOUBLE) + CAST(2 AS DOUBLE)):double> +-- !query 1 output +3.0 + + +-- !query 2 +SELECT 1 - '2' FROM t +-- !query 2 schema +struct<(CAST(1 AS DOUBLE) - CAST(2 AS DOUBLE)):double> +-- !query 2 output +-1.0 + + +-- !query 3 +SELECT 1 * '2' FROM t +-- !query 3 schema +struct<(CAST(1 AS DOUBLE) * CAST(2 AS DOUBLE)):double> +-- !query 3 output +2.0 + + +-- !query 4 +SELECT 4 / '2' FROM t +-- !query 4 schema +struct<(CAST(4 AS DOUBLE) / CAST(CAST(2 AS DOUBLE) AS DOUBLE)):double> +-- !query 4 output +2.0 + + +-- !query 5 +SELECT 1.1 + '2' FROM t +-- !query 5 schema +struct<(CAST(1.1 AS DOUBLE) + CAST(2 AS DOUBLE)):double> +-- !query 5 output +3.1 + + +-- !query 6 +SELECT 1.1 - '2' FROM t +-- !query 6 schema +struct<(CAST(1.1 AS DOUBLE) - CAST(2 AS DOUBLE)):double> +-- !query 6 output +-0.8999999999999999 + + +-- !query 7 +SELECT 1.1 * '2' FROM t +-- !query 7 schema +struct<(CAST(1.1 AS DOUBLE) * CAST(2 AS DOUBLE)):double> +-- !query 7 output +2.2 + + +-- !query 8 +SELECT 4.4 / '2' FROM t +-- !query 8 schema +struct<(CAST(4.4 AS DOUBLE) / CAST(2 AS DOUBLE)):double> +-- !query 8 output +2.2 + + +-- !query 9 +SELECT 1.1 + '2.2' FROM t +-- !query 9 schema +struct<(CAST(1.1 AS DOUBLE) + CAST(2.2 AS DOUBLE)):double> +-- !query 9 output +3.3000000000000003 + + +-- !query 10 +SELECT 1.1 - '2.2' FROM t +-- !query 10 schema +struct<(CAST(1.1 AS DOUBLE) - CAST(2.2 AS DOUBLE)):double> +-- !query 10 output +-1.1 + + +-- !query 11 +SELECT 1.1 * '2.2' FROM t +-- !query 11 schema +struct<(CAST(1.1 AS DOUBLE) * CAST(2.2 AS DOUBLE)):double> +-- !query 11 output +2.4200000000000004 + + +-- !query 12 +SELECT 4.4 / '2.2' FROM t +-- !query 12 schema +struct<(CAST(4.4 AS DOUBLE) / CAST(2.2 AS DOUBLE)):double> +-- !query 12 output +2.0 + + +-- !query 13 +SELECT '$' || cast(1 as smallint) || '$' FROM t +-- !query 13 schema +struct +-- !query 13 output +$1$ + + +-- !query 14 +SELECT '$' || 1 || '$' FROM t +-- !query 14 schema +struct +-- !query 14 output +$1$ + + +-- !query 15 +SELECT '$' || cast(1 as bigint) || '$' FROM t +-- !query 15 schema +struct +-- !query 15 output +$1$ + + +-- !query 16 +SELECT '$' || cast(1.1 as float) || '$' FROM t +-- !query 16 schema +struct +-- !query 16 output +$1.1$ + + +-- !query 17 +SELECT '$' || cast(1.1 as double) || '$' FROM t +-- !query 17 schema +struct +-- !query 17 output +$1.1$ + + +-- !query 18 +SELECT '$' || 1.1 || '$' FROM t +-- !query 18 schema +struct +-- !query 18 output +$1.1$ + + +-- !query 19 +SELECT '$' || cast(1.1 as decimal(8,3)) || '$' FROM t +-- !query 19 schema +struct +-- !query 19 output +$1.100$ + + +-- !query 20 +SELECT '$' || 'abcd' || '$' FROM t +-- !query 20 schema +struct +-- !query 20 output +$abcd$ + + +-- !query 21 +SELECT '$' || date('1996-09-09') || '$' FROM t +-- !query 21 schema +struct +-- !query 21 output +$1996-09-09$ + + +-- !query 22 +SELECT '$' || timestamp('1996-09-09 10:11:12.4' )|| '$' FROM t 
+-- !query 22 schema +struct +-- !query 22 output +$1996-09-09 10:11:12.4$ + + +-- !query 23 +SELECT length(cast(1 as smallint)) FROM t +-- !query 23 schema +struct +-- !query 23 output +1 + + +-- !query 24 +SELECT length(cast(1 as int)) FROM t +-- !query 24 schema +struct +-- !query 24 output +1 + + +-- !query 25 +SELECT length(cast(1 as bigint)) FROM t +-- !query 25 schema +struct +-- !query 25 output +1 + + +-- !query 26 +SELECT length(cast(1.1 as float)) FROM t +-- !query 26 schema +struct +-- !query 26 output +3 + + +-- !query 27 +SELECT length(cast(1.1 as double)) FROM t +-- !query 27 schema +struct +-- !query 27 output +3 + + +-- !query 28 +SELECT length(1.1) FROM t +-- !query 28 schema +struct +-- !query 28 output +3 + + +-- !query 29 +SELECT length(cast(1.1 as decimal(8,3))) FROM t +-- !query 29 schema +struct +-- !query 29 output +5 + + +-- !query 30 +SELECT length('four') FROM t +-- !query 30 schema +struct +-- !query 30 output +4 + + +-- !query 31 +SELECT length(date('1996-09-10')) FROM t +-- !query 31 schema +struct +-- !query 31 output +10 + + +-- !query 32 +SELECT length(timestamp('1996-09-10 10:11:12.4')) FROM t +-- !query 32 schema +struct +-- !query 32 output +21 + + +-- !query 33 +SELECT year( '1996-01-10') FROM t +-- !query 33 schema +struct +-- !query 33 output +1996 + + +-- !query 34 +SELECT month( '1996-01-10') FROM t +-- !query 34 schema +struct +-- !query 34 output +1 + + +-- !query 35 +SELECT day( '1996-01-10') FROM t +-- !query 35 schema +struct +-- !query 35 output +10 + + +-- !query 36 +SELECT hour( '10:11:12') FROM t +-- !query 36 schema +struct +-- !query 36 output +10 + + +-- !query 37 +SELECT minute( '10:11:12') FROM t +-- !query 37 schema +struct +-- !query 37 output +11 + + +-- !query 38 +SELECT second( '10:11:12') FROM t +-- !query 38 schema +struct +-- !query 38 output +12 + + +-- !query 39 +select 1 like '%' FROM t +-- !query 39 schema +struct +-- !query 39 output +true + + +-- !query 40 +select date('1996-09-10') like '19%' FROM t +-- !query 40 schema +struct +-- !query 40 output +true + + +-- !query 41 +select '1' like 1 FROM t +-- !query 41 schema +struct<1 LIKE CAST(1 AS STRING):boolean> +-- !query 41 output +true + + +-- !query 42 +select '1 ' like 1 FROM t +-- !query 42 schema +struct<1 LIKE CAST(1 AS STRING):boolean> +-- !query 42 output +false + + +-- !query 43 +select '1996-09-10' like date('1996-09-10') FROM t +-- !query 43 schema +struct<1996-09-10 LIKE CAST(CAST(1996-09-10 AS DATE) AS STRING):boolean> +-- !query 43 output +true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index cac3e12dde3e1..e3901af4b9988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -300,7 +300,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { Locale.setDefault(originalLocale) // For debugging dump some statistics about how much time was spent in various optimizer rules - logInfo(RuleExecutor.dumpTimeSpent()) + logWarning(RuleExecutor.dumpTimeSpent()) } finally { super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index dbea03689785c..a58000da1543d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -41,7 +41,7 @@ 
class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte protected override def afterAll(): Unit = { try { // For debugging dump some statistics about how much time was spent in various optimizer rules - logInfo(RuleExecutor.dumpTimeSpent()) + logWarning(RuleExecutor.dumpTimeSpent()) spark.sessionState.catalog.reset() } finally { super.afterAll() diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index def70a516a96b..45791c69b4cb7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -76,7 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules - logInfo(RuleExecutor.dumpTimeSpent()) + logWarning(RuleExecutor.dumpTimeSpent()) } finally { super.afterAll() } From a4002651a3ea673cf3eff7927531c1659663d194 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 11 Dec 2017 16:33:06 -0800 Subject: [PATCH 394/712] [SPARK-20557][SQL] Only support TIMESTAMP WITH TIME ZONE for Oracle Dialect ## What changes were proposed in this pull request? In the previous PRs, https://github.com/apache/spark/pull/17832 and https://github.com/apache/spark/pull/17835 , we convert `TIMESTAMP WITH TIME ZONE` and `TIME WITH TIME ZONE` to `TIMESTAMP` for all the JDBC sources. However, this conversion could be risky since it does not respect our SQL configuration `spark.sql.session.timeZone`. In addition, each vendor might have different semantics for these two types. For example, Postgres simply returns `TIMESTAMP` types for `TIMESTAMP WITH TIME ZONE`. For such supports, we should do it case by case. This PR reverts the general support of `TIMESTAMP WITH TIME ZONE` and `TIME WITH TIME ZONE` for JDBC sources, except ORACLE Dialect. When supporting the ORACLE's `TIMESTAMP WITH TIME ZONE`, we only support it when the JVM default timezone is the same as the user-specified configuration `spark.sql.session.timeZone` (whose default is the JVM default timezone). Now, we still treat `TIMESTAMP WITH TIME ZONE` as `TIMESTAMP` when fetching the values via the Oracle JDBC connector, whose client converts the timestamp values with time zone to the timestamp values using the local JVM default timezone (a test case is added to `OracleIntegrationSuite.scala` in this PR for showing the behavior). Thus, to avoid any future behavior change, we will not support it if JVM default timezone is different from `spark.sql.session.timeZone` No regression because the previous two PRs were just merged to be unreleased master branch. ## How was this patch tested? Added the test cases Author: gatorsmile Closes #19939 from gatorsmile/timezoneUpdate. 
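For illustration only (not part of the patch): a minimal usage sketch of how the new Oracle behavior surfaces when reading the `ts_with_timezone` table used in the added test. The JDBC URL is a placeholder and `spark` is assumed to be an existing `SparkSession` (for example in `spark-shell`); the real integration test derives the URL from a Docker container.

```scala
import java.util.{Properties, TimeZone}

// Placeholder connection URL; substitute the real Oracle endpoint.
val jdbcUrl = "jdbc:oracle:thin:@//host:1521/service"

// Keep the session time zone aligned with the JVM default so that Oracle's
// TIMESTAMP WITH TIME ZONE column is still mapped to Spark's TimestampType.
spark.conf.set("spark.sql.session.timeZone", TimeZone.getDefault.getID)
val ok = spark.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties)
ok.collect()  // succeeds; values are rendered in the JVM default time zone

// With a session time zone different from the JVM default (as in the added
// negative test), the TIMESTAMPTZ column is no longer recognized and reading
// fails with java.sql.SQLException: "Unrecognized SQL type -101".
spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")  // assume this is not the JVM default
// spark.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties).collect()  // would throw
```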
--- .../sql/jdbc/OracleIntegrationSuite.scala | 67 ++++++++++++++++++- .../sql/jdbc/PostgresIntegrationSuite.scala | 2 + .../datasources/jdbc/JdbcUtils.scala | 4 +- .../apache/spark/sql/jdbc/OracleDialect.scala | 13 +++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 +- 5 files changed, 82 insertions(+), 8 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 90343182712ed..8512496e5fe52 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Date, Timestamp} -import java.util.Properties +import java.util.{Properties, TimeZone} import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, Row, SaveMode} -import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec} +import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode} +import org.apache.spark.sql.execution.{RowDataSourceScanExec, WholeStageCodegenExec} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -77,6 +78,9 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo conn.prepareStatement( "INSERT INTO ts_with_timezone VALUES " + "(1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))").executeUpdate() + conn.prepareStatement( + "INSERT INTO ts_with_timezone VALUES " + + "(2, to_timestamp_tz('1999-12-01 12:00:00 PST','YYYY-MM-DD HH:MI:SS TZR'))").executeUpdate() conn.commit() conn.prepareStatement( @@ -235,6 +239,63 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(types(1).equals("class java.sql.Timestamp")) } + test("Column type TIMESTAMP with SESSION_LOCAL_TIMEZONE is different from default") { + val defaultJVMTimeZone = TimeZone.getDefault + // Pick the timezone different from the current default time zone of JVM + val sofiaTimeZone = TimeZone.getTimeZone("Europe/Sofia") + val shanghaiTimeZone = TimeZone.getTimeZone("Asia/Shanghai") + val localSessionTimeZone = + if (defaultJVMTimeZone == shanghaiTimeZone) sofiaTimeZone else shanghaiTimeZone + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> localSessionTimeZone.getID) { + val e = intercept[java.sql.SQLException] { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + dfRead.collect() + }.getMessage + assert(e.contains("Unrecognized SQL type -101")) + } + } + + /** + * Change the Time Zone `timeZoneId` of JVM before executing `f`, then switches back to the + * original after `f` returns. + * @param timeZoneId the ID for a TimeZone, either an abbreviation such as "PST", a full name such + * as "America/Los_Angeles", or a custom ID such as "GMT-8:00". 
+ */ + private def withTimeZone(timeZoneId: String)(f: => Unit): Unit = { + val originalLocale = TimeZone.getDefault + try { + // Add Locale setting + TimeZone.setDefault(TimeZone.getTimeZone(timeZoneId)) + f + } finally { + TimeZone.setDefault(originalLocale) + } + } + + test("Column TIMESTAMP with TIME ZONE(JVM timezone)") { + def checkRow(row: Row, ts: String): Unit = { + assert(row.getTimestamp(1).equals(Timestamp.valueOf(ts))) + } + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> TimeZone.getDefault.getID) { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + withTimeZone("PST") { + assert(dfRead.collect().toSet === + Set( + Row(BigDecimal.valueOf(1), java.sql.Timestamp.valueOf("1999-12-01 03:00:00")), + Row(BigDecimal.valueOf(2), java.sql.Timestamp.valueOf("1999-12-01 12:00:00")))) + } + + withTimeZone("UTC") { + assert(dfRead.collect().toSet === + Set( + Row(BigDecimal.valueOf(1), java.sql.Timestamp.valueOf("1999-12-01 11:00:00")), + Row(BigDecimal.valueOf(2), java.sql.Timestamp.valueOf("1999-12-01 20:00:00")))) + } + } + } + test("SPARK-18004: Make sure date or timestamp related predicate is pushed down correctly") { val props = new Properties() props.put("oracle.jdbc.mapDateToTimestamp", "false") diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 48aba90afc787..be32cb89f4886 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -151,6 +151,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE " + "should be recognized") { + // When using JDBC to read the columns of TIMESTAMP with TIME ZONE and TIME with TIME ZONE + // the actual types are java.sql.Types.TIMESTAMP and java.sql.Types.TIME val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) val rows = dfRead.collect() val types = rows(0).toSeq.map(x => x.getClass.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 75c94fc486493..bbc95df4d9dc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -226,10 +226,10 @@ object JdbcUtils extends Logging { case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType case java.sql.Types.TIME_WITH_TIMEZONE - => TimestampType + => null case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE - => TimestampType + => null case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index e3f106c41c7ff..6ef77f24460be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -18,7 +18,10 @@ package 
org.apache.spark.sql.jdbc import java.sql.{Date, Timestamp, Types} +import java.util.TimeZone +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -29,6 +32,13 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") + private def supportTimeZoneTypes: Boolean = { + val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) + // TODO: support timezone types when users are not using the JVM timezone, which + // is the default value of SESSION_LOCAL_TIMEZONE + timeZone == TimeZone.getDefault + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { sqlType match { @@ -49,7 +59,8 @@ private case object OracleDialect extends JdbcDialect { case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) case _ => None } - case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle + case TIMESTAMPTZ if supportTimeZoneTypes + => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE case _ => None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 61571bccdcb51..0767ca1573a7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1064,10 +1064,10 @@ class JDBCSuite extends SparkFunSuite } test("unsupported types") { - var e = intercept[SparkException] { + var e = intercept[SQLException] { spark.read.jdbc(urlWithUserAndPass, "TEST.TIMEZONE", new Properties()).collect() }.getMessage - assert(e.contains("java.lang.UnsupportedOperationException: unimplemented")) + assert(e.contains("Unsupported type TIMESTAMP_WITH_TIMEZONE")) e = intercept[SQLException] { spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect() }.getMessage From ecc179ecaa9b4381d2f6d3356d88a4dff3e19c0f Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Tue, 12 Dec 2017 15:04:49 +0800 Subject: [PATCH 395/712] [SPARK-21322][SQL] support histogram in filter cardinality estimation ## What changes were proposed in this pull request? Histogram is effective in dealing with skewed distribution. After we generate histogram information for column statistics, we need to adjust filter estimation based on histogram data structure. ## How was this patch tested? We revised all the unit test cases by including histogram data structure. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ron Hu Closes #19783 from ron8hu/supportHistogram. 
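Before the diff, a self-contained sketch of the estimation idea (simplified from the logic this patch adds to `FilterEstimation.scala`; the real code also tracks the column's current `[min, max]` range, handles values sitting on bin boundaries, and covers the range comparison operators). The bin layout mirrors the even-distribution histogram used in the new `FilterEstimationSuite` cases.

```scala
// Equi-height histogram: each bin holds roughly the same number of rows.
case class Bin(lo: Double, hi: Double, ndv: Long)

// Selectivity of `col = v`: 1/ndv of the bin holding v, prorated over the
// number of bins covering the column's value range (assumed here to be all
// bins, which are contiguous).
def equalitySelectivity(v: Double, bins: Array[Bin]): Double = {
  if (v < bins.head.lo || v > bins.last.hi) {
    0.0  // out-of-range literal can never match
  } else {
    val bin = bins.find(b => v >= b.lo && v <= b.hi).get
    (1.0 / bin.ndv) / bins.length  // assume a uniform distribution inside the bin
  }
}

// Ten rows with values 1..10, five bins, two distinct values per bin.
val bins = Array(Bin(1, 2, 2), Bin(2, 4, 2), Bin(4, 6, 2), Bin(6, 8, 2), Bin(8, 10, 2))
equalitySelectivity(5.0, bins)   // 0.1 -> about 1 of the 10 rows satisfies col = 5
equalitySelectivity(11.0, bins)  // 0.0 -> literal outside [min, max]
```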
--- .../statsEstimation/EstimationUtils.scala | 97 +++++++- .../statsEstimation/FilterEstimation.scala | 181 +++++++++++---- .../FilterEstimationSuite.scala | 217 +++++++++++++++++- 3 files changed, 448 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 9c34a9b7aa756..2f416f2af91f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} @@ -114,4 +114,99 @@ object EstimationUtils { } } + /** + * Returns the number of the first bin into which a column value falls for a specified + * numeric equi-height histogram. + * + * @param value a literal value of a column + * @param bins an array of bins for a given numeric equi-height histogram + * @return the id of the first bin into which a column value falls. + */ + def findFirstBinForValue(value: Double, bins: Array[HistogramBin]): Int = { + var i = 0 + while ((i < bins.length) && (value > bins(i).hi)) { + i += 1 + } + i + } + + /** + * Returns the number of the last bin into which a column value falls for a specified + * numeric equi-height histogram. + * + * @param value a literal value of a column + * @param bins an array of bins for a given numeric equi-height histogram + * @return the id of the last bin into which a column value falls. + */ + def findLastBinForValue(value: Double, bins: Array[HistogramBin]): Int = { + var i = bins.length - 1 + while ((i >= 0) && (value < bins(i).lo)) { + i -= 1 + } + i + } + + /** + * Returns a percentage of a bin holding values for column value in the range of + * [lowerValue, higherValue] + * + * @param higherValue a given upper bound value of a specified column value range + * @param lowerValue a given lower bound value of a specified column value range + * @param bin a single histogram bin + * @return the percentage of a single bin holding values in [lowerValue, higherValue]. + */ + private def getOccupation( + higherValue: Double, + lowerValue: Double, + bin: HistogramBin): Double = { + assert(bin.lo <= lowerValue && lowerValue <= higherValue && higherValue <= bin.hi) + if (bin.hi == bin.lo) { + // the entire bin is covered in the range + 1.0 + } else if (higherValue == lowerValue) { + // set percentage to 1/NDV + 1.0 / bin.ndv.toDouble + } else { + // Use proration since the range falls inside this bin. + math.min((higherValue - lowerValue) / (bin.hi - bin.lo), 1.0) + } + } + + /** + * Returns the number of bins for column values in [lowerValue, higherValue]. + * The column value distribution is saved in an equi-height histogram. The return values is a + * double value is because we may return a portion of a bin. For example, a predicate + * "column = 8" may return the number of bins 0.2 if the holding bin has 5 distinct values. 
+ * + * @param higherId id of the high end bin holding the high end value of a column range + * @param lowerId id of the low end bin holding the low end value of a column range + * @param higherEnd a given upper bound value of a specified column value range + * @param lowerEnd a given lower bound value of a specified column value range + * @param histogram a numeric equi-height histogram + * @return the number of bins for column values in [lowerEnd, higherEnd]. + */ + def getOccupationBins( + higherId: Int, + lowerId: Int, + higherEnd: Double, + lowerEnd: Double, + histogram: Histogram): Double = { + assert(lowerId <= higherId) + + if (lowerId == higherId) { + val curBin = histogram.bins(lowerId) + getOccupation(higherEnd, lowerEnd, curBin) + } else { + // compute how much lowerEnd/higherEnd occupies its bin + val lowerCurBin = histogram.bins(lowerId) + val lowerPart = getOccupation(lowerCurBin.hi, lowerEnd, lowerCurBin) + + val higherCurBin = histogram.bins(higherId) + val higherPart = getOccupation(higherEnd, higherCurBin.lo, higherCurBin) + + // the total length is lowerPart + higherPart + bins between them + lowerPart + higherPart + higherId - lowerId - 1 + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 74820eb97d081..f52a15edbfe4a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types._ @@ -265,7 +265,7 @@ case class FilterEstimation(plan: Filter) extends Logging { * @param update a boolean flag to specify if we need to update ColumnStat of a given column * for subsequent conditions * @return an optional double value to show the percentage of rows meeting a given condition - * It returns None if no statistics exists for a given column or wrong value. + * It returns None if no statistics exists for a given column or wrong value. */ def evaluateBinary( op: BinaryComparison, @@ -332,8 +332,44 @@ case class FilterEstimation(plan: Filter) extends Logging { colStatsMap.update(attr, newStats) } - Some(1.0 / BigDecimal(ndv)) - } else { + if (colStat.histogram.isEmpty) { + // returns 1/ndv if there is no histogram + Some(1.0 / BigDecimal(ndv)) + } else { + // We compute filter selectivity using Histogram information. + val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val histogram = colStat.histogram.get + val hgmBins = histogram.bins + + // find bins where column's current min and max locate. Note that a column's [min, max] + // range may change due to another condition applied earlier. 
+ val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble + val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble + val minBinId = EstimationUtils.findFirstBinForValue(min, hgmBins) + val maxBinId = EstimationUtils.findLastBinForValue(max, hgmBins) + + // compute how many bins the column's current valid range [min, max] occupies. + // Note that a column's [min, max] range may vary after we apply some filter conditions. + val validRangeBins = EstimationUtils.getOccupationBins(maxBinId, minBinId, max, + min, histogram) + + val lowerBinId = EstimationUtils.findFirstBinForValue(datum, hgmBins) + val higherBinId = EstimationUtils.findLastBinForValue(datum, hgmBins) + assert(lowerBinId <= higherBinId) + val lowerBinNdv = hgmBins(lowerBinId).ndv + val higherBinNdv = hgmBins(higherBinId).ndv + // assume uniform distribution in each bin + val occupiedBins = if (lowerBinId == higherBinId) { + 1.0 / lowerBinNdv + } else { + (1.0 / lowerBinNdv) + // lowest bin + (higherBinId - lowerBinId - 1) + // middle bins + (1.0 / higherBinNdv) // highest bin + } + Some(occupiedBins / validRangeBins) + } + + } else { // not in interval Some(0.0) } @@ -471,37 +507,46 @@ case class FilterEstimation(plan: Filter) extends Logging { percent = 1.0 } else { // This is the partial overlap case: - // Without advanced statistics like histogram, we assume uniform data distribution. - // We just prorate the adjusted range over the initial range to compute filter selectivity. - assert(max > min) - percent = op match { - case _: LessThan => - if (numericLiteral == max) { - // If the literal value is right on the boundary, we can minus the part of the - // boundary value (1/ndv). - 1.0 - 1.0 / ndv - } else { - (numericLiteral - min) / (max - min) - } - case _: LessThanOrEqual => - if (numericLiteral == min) { - // The boundary value is the only satisfying value. - 1.0 / ndv - } else { - (numericLiteral - min) / (max - min) - } - case _: GreaterThan => - if (numericLiteral == min) { - 1.0 - 1.0 / ndv - } else { - (max - numericLiteral) / (max - min) - } - case _: GreaterThanOrEqual => - if (numericLiteral == max) { - 1.0 / ndv - } else { - (max - numericLiteral) / (max - min) - } + + if (colStat.histogram.isEmpty) { + // Without advanced statistics like histogram, we assume uniform data distribution. + // We just prorate the adjusted range over the initial range to compute filter selectivity. + assert(max > min) + percent = op match { + case _: LessThan => + if (numericLiteral == max) { + // If the literal value is right on the boundary, we can minus the part of the + // boundary value (1/ndv). + 1.0 - 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } + case _: LessThanOrEqual => + if (numericLiteral == min) { + // The boundary value is the only satisfying value. 
+ 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } + case _: GreaterThan => + if (numericLiteral == min) { + 1.0 - 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } + case _: GreaterThanOrEqual => + if (numericLiteral == max) { + 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } + } + } else { + val numericHistogram = colStat.histogram.get + val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble + val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble + percent = computePercentByEquiHeightHgm(op, numericHistogram, max, min, datum) } if (update) { @@ -513,10 +558,9 @@ case class FilterEstimation(plan: Filter) extends Logging { op match { case _: GreaterThan | _: GreaterThanOrEqual => - // If new ndv is 1, then new max must be equal to new min. - newMin = if (newNdv == 1) newMax else newValue + newMin = newValue case _: LessThan | _: LessThanOrEqual => - newMax = if (newNdv == 1) newMin else newValue + newMax = newValue } val newStats = @@ -529,6 +573,54 @@ case class FilterEstimation(plan: Filter) extends Logging { Some(percent) } + /** + * Returns the selectivity percentage for binary condition in the column's + * current valid range [min, max] + * + * @param op a binary comparison operator + * @param histogram a numeric equi-height histogram + * @param max the upper bound of the current valid range for a given column + * @param min the lower bound of the current valid range for a given column + * @param datumNumber the numeric value of a literal + * @return the selectivity percentage for a condition in the current range. + */ + + def computePercentByEquiHeightHgm( + op: BinaryComparison, + histogram: Histogram, + max: Double, + min: Double, + datumNumber: Double): Double = { + // find bins where column's current min and max locate. Note that a column's [min, max] + // range may change due to another condition applied earlier. + val minBinId = EstimationUtils.findFirstBinForValue(min, histogram.bins) + val maxBinId = EstimationUtils.findLastBinForValue(max, histogram.bins) + + // compute how many bins the column's current valid range [min, max] occupies. + // Note that a column's [min, max] range may vary after we apply some filter conditions. + val minToMaxLength = EstimationUtils.getOccupationBins(maxBinId, minBinId, max, min, histogram) + + val datumInBinId = op match { + case LessThan(_, _) | GreaterThanOrEqual(_, _) => + EstimationUtils.findFirstBinForValue(datumNumber, histogram.bins) + case LessThanOrEqual(_, _) | GreaterThan(_, _) => + EstimationUtils.findLastBinForValue(datumNumber, histogram.bins) + } + + op match { + // LessThan and LessThanOrEqual share the same logic, + // but their datumInBinId may be different + case LessThan(_, _) | LessThanOrEqual(_, _) => + EstimationUtils.getOccupationBins(datumInBinId, minBinId, datumNumber, min, + histogram) / minToMaxLength + // GreaterThan and GreaterThanOrEqual share the same logic, + // but their datumInBinId may be different + case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => + EstimationUtils.getOccupationBins(maxBinId, datumInBinId, max, datumNumber, + histogram) / minToMaxLength + } + } + /** * Returns a percentage of rows meeting a binary comparison expression containing two columns. 
* In SQL queries, we also see predicate expressions involving two columns @@ -784,11 +876,16 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { def outputColumnStats(rowsBeforeFilter: BigInt, rowsAfterFilter: BigInt) : AttributeMap[ColumnStat] = { val newColumnStats = originalMap.map { case (attr, oriColStat) => - // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows - // decreases; otherwise keep it unchanged. - val newNdv = EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, - newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat) + val newNdv = if (colStat.distinctCount > 1) { + // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows + // decreases; otherwise keep it unchanged. + EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, + newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) + } else { + // no need to scale down since it is already down to 1 (for skewed distribution case) + colStat.distinctCount + } attr -> colStat.copy(distinctCount = newNdv) } AttributeMap(newColumnStats.toSeq) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 455037e6c9952..2b1fe987a7960 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -22,7 +22,7 @@ import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.types._ */ class FilterEstimationSuite extends StatsEstimationTestBase { - // Suppose our test table has 10 rows and 6 columns. + // Suppose our test table has 10 rows and 10 columns. // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() @@ -91,6 +91,26 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) + // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram. + // Note that cintHgm has an even distribution with histogram information built. 
+ // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + val attrIntHgm = AttributeReference("cintHgm", IntegerType)() + val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2), + HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2), + HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2))) + val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt)) + + // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram. + // Note that cintSkewHgm has a skewed distribution with histogram information built. + // distinctCount:5, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + val attrIntSkewHgm = AttributeReference("cintSkewHgm", IntegerType)() + val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2), + HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1), + HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2))) + val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew)) + val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, attrBool -> colStatBool, @@ -100,7 +120,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attrString -> colStatString, attrInt2 -> colStatInt2, attrInt3 -> colStatInt3, - attrInt4 -> colStatInt4 + attrInt4 -> colStatInt4, + attrIntHgm -> colStatIntHgm, + attrIntSkewHgm -> colStatIntSkewHgm )) test("true") { @@ -359,7 +381,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cbool > false") { validateEstimatedStats( Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1)), expectedRowCount = 5) } @@ -578,6 +600,193 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 5) } + // The following test cases have histogram information collected for the test column with + // an even distribution + test("Not(cintHgm < 3 AND null)") { + val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)), + expectedRowCount = 7) + } + + test("cintHgm = 5") { + validateEstimatedStats( + Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + expectedRowCount = 1) + } + + test("cintHgm = 0") { + // This is an out-of-range case since 0 is outside the range [min, max] + validateEstimatedStats( + Filter(EqualTo(attrIntHgm, Literal(0)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cintHgm < 3") { + validateEstimatedStats( + Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + expectedRowCount = 3) + } + + test("cintHgm < 0") { + // This is a corner case since literal 0 is smaller than min. 
+ validateEstimatedStats( + Filter(LessThan(attrIntHgm, Literal(0)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cintHgm <= 3") { + validateEstimatedStats( + Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + expectedRowCount = 3) + } + + test("cintHgm > 6") { + validateEstimatedStats( + Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + expectedRowCount = 4) + } + + test("cintHgm > 10") { + // This is a corner case since max value is 10. + validateEstimatedStats( + Filter(GreaterThan(attrIntHgm, Literal(10)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cintHgm >= 6") { + validateEstimatedStats( + Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + expectedRowCount = 5) + } + + test("cintHgm > 3 AND cintHgm <= 6") { + val condition = And(GreaterThan(attrIntHgm, + Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + expectedRowCount = 4) + } + + test("cintHgm = 3 OR cintHgm = 6") { + val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)), + expectedRowCount = 3) + } + + // The following test cases have histogram information collected for the test column with + // a skewed distribution. + test("Not(cintSkewHgm < 3 AND null)") { + val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)), + expectedRowCount = 9) + } + + test("cintSkewHgm = 5") { + validateEstimatedStats( + Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + expectedRowCount = 4) + } + + test("cintSkewHgm = 0") { + // This is an out-of-range case since 0 is outside the range [min, max] + validateEstimatedStats( + Filter(EqualTo(attrIntSkewHgm, Literal(0)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cintSkewHgm < 3") { + validateEstimatedStats( + Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + expectedRowCount = 2) + } + + test("cintSkewHgm < 0") { + // This is a corner case since literal 0 is smaller than min. 
+ validateEstimatedStats( + Filter(LessThan(attrIntSkewHgm, Literal(0)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cintSkewHgm <= 3") { + validateEstimatedStats( + Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)), + childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + expectedRowCount = 2) + } + + test("cintSkewHgm > 6") { + validateEstimatedStats( + Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + expectedRowCount = 2) + } + + test("cintSkewHgm > 10") { + // This is a corner case since max value is 10. + validateEstimatedStats( + Filter(GreaterThan(attrIntSkewHgm, Literal(10)), + childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cintSkewHgm >= 6") { + validateEstimatedStats( + Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)), + childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + expectedRowCount = 3) + } + + test("cintSkewHgm > 3 AND cintSkewHgm <= 6") { + val condition = And(GreaterThan(attrIntSkewHgm, + Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + expectedRowCount = 8) + } + + test("cintSkewHgm = 3 OR cintSkewHgm = 6") { + val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)), + expectedRowCount = 3) + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, From bc8933faf238bcf14e7976bd1ac1465dc32b8e2b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 12 Dec 2017 17:02:04 +0900 Subject: [PATCH 396/712] [SPARK-3685][CORE] Prints explicit warnings when configured local directories are set to URIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR proposes to print warnings before creating local by `java.io.File`. I think we can't just simply disallow and throw an exception for such cases of `hdfs:/tmp/foo` case because it might break compatibility. Note that `hdfs:/tmp/foo` creates a directory called `hdfs:/`. There were many discussion here about whether we should support this in other file systems or not; however, since the JIRA targets "Spark's local dir should accept only local paths", here I tried to resolve it by simply printing warnings. I think we could open another JIRA and design doc if this is something we should support, separately. Another note, for your information, [SPARK-1529](https://issues.apache.org/jira/browse/SPARK-1529) is resolved as `Won't Fix`. 
**Before** ``` ./bin/spark-shell --conf spark.local.dir=file:/a/b/c ``` This creates a local directory as below: ``` file:/ └── a └── b └── c ... ``` **After** ```bash ./bin/spark-shell --conf spark.local.dir=file:/a/b/c ``` Now, it prints a warning as below: ``` ... 17/12/09 21:58:49 WARN Utils: The configured local directories are not expected to be URIs; however, got suspicious values [file:/a/b/c]. Please check your configured local directories. ... ``` ```bash ./bin/spark-shell --conf spark.local.dir=file:/a/b/c,/tmp/a/b/c,hdfs:/a/b/c ``` It also works with comma-separated ones: ``` ... 17/12/09 22:05:01 WARN Utils: The configured local directories are not expected to be URIs; however, got suspicious values [file:/a/b/c, hdfs:/a/b/c]. Please check your configured local directories. ... ``` ## How was this patch tested? Manually tested: ``` ./bin/spark-shell --conf spark.local.dir=C:\\a\\b\\c ./bin/spark-shell --conf spark.local.dir=/tmp/a/b/c ./bin/spark-shell --conf spark.local.dir=a/b/c ./bin/spark-shell --conf spark.local.dir=a/b/c,/tmp/a/b/c,C:\\a\\b\\c ``` Author: hyukjinkwon Closes #19934 from HyukjinKwon/SPARK-3685. --- .../main/scala/org/apache/spark/util/Utils.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1ed09dc489440..fe5b4ea24440b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -829,7 +829,18 @@ private[spark] object Utils extends Logging { } private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { - getConfiguredLocalDirs(conf).flatMap { root => + val configuredLocalDirs = getConfiguredLocalDirs(conf) + val uris = configuredLocalDirs.filter { root => + // Here, we guess if the given value is a URI at its best - check if scheme is set. + Try(new URI(root).getScheme != null).getOrElse(false) + } + if (uris.nonEmpty) { + logWarning( + "The configured local directories are not expected to be URIs; however, got suspicious " + + s"values [${uris.mkString(", ")}]. Please check your configured local directories.") + } + + configuredLocalDirs.flatMap { root => try { val rootDir = new File(root) if (rootDir.exists || rootDir.mkdirs()) { From d5007734b22fbee5eccb1682eef13479780423ab Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 12 Dec 2017 10:07:18 -0800 Subject: [PATCH 397/712] [SPARK-16986][WEB-UI] Converter Started, Completed and Last Updated to client time zone in history page ## What changes were proposed in this pull request? This PR is converted the ` Started`, `Completed` and `Last Updated` to client local time in the history page. ## How was this patch tested? Manual tests for Chrome, Firefox and Safari #### Before modifying: before-webui #### After modifying: after Author: Yuming Wang Closes #19640 from wangyum/SPARK-16986. 
--- .../spark/ui/static/historypage-common.js | 11 ++++---- .../org/apache/spark/ui/static/historypage.js | 11 ++------ .../org/apache/spark/ui/static/utils.js | 28 +++++++++++++++++++ .../spark/deploy/history/HistoryPage.scala | 4 +++ .../deploy/history/HistoryServerSuite.scala | 2 +- 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js index 55d540d8317a0..4cfe46ec914ae 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -16,9 +16,10 @@ */ $(document).ready(function() { - if ($('#last-updated').length) { - var lastUpdatedMillis = Number($('#last-updated').text()); - var updatedDate = new Date(lastUpdatedMillis); - $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) - } + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + $('#last-updated').text(formatTimeMillis(lastUpdatedMillis)); + } + + $('#time-zone').text(getTimeZone()); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index aa7e860372553..2cde66b081a1c 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -37,11 +37,6 @@ function makeIdNumeric(id) { return resl; } -function formatDate(date) { - if (date <= 0) return "-"; - else return date.split(".")[0].replace("T", " "); -} - function getParameterByName(name, searchString) { var regex = new RegExp("[\\?&]" + name + "=([^&#]*)"), results = regex.exec(searchString); @@ -129,9 +124,9 @@ $(document).ready(function() { var num = app["attempts"].length; for (j in app["attempts"]) { var attempt = app["attempts"][j]; - attempt["startTime"] = formatDate(attempt["startTime"]); - attempt["endTime"] = formatDate(attempt["endTime"]); - attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["startTime"] = formatTimeMillis(attempt["startTimeEpoch"]); + attempt["endTime"] = formatTimeMillis(attempt["endTimeEpoch"]); + attempt["lastUpdated"] = formatTimeMillis(attempt["lastUpdatedEpoch"]); attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + (attempt.hasOwnProperty("attemptId") ? 
attempt["attemptId"] + "/" : "") + "logs"; attempt["durationMillisec"] = attempt["duration"]; diff --git a/core/src/main/resources/org/apache/spark/ui/static/utils.js b/core/src/main/resources/org/apache/spark/ui/static/utils.js index edc0ee2ce181d..4f63f6413d6de 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/utils.js +++ b/core/src/main/resources/org/apache/spark/ui/static/utils.js @@ -46,3 +46,31 @@ function formatBytes(bytes, type) { var i = Math.floor(Math.log(bytes) / Math.log(k)); return parseFloat((bytes / Math.pow(k, i)).toFixed(dm)) + ' ' + sizes[i]; } + +function padZeroes(num) { + return ("0" + num).slice(-2); +} + +function formatTimeMillis(timeMillis) { + if (timeMillis <= 0) { + return "-"; + } else { + var dt = new Date(timeMillis); + return dt.getFullYear() + "-" + + padZeroes(dt.getMonth() + 1) + "-" + + padZeroes(dt.getDate()) + " " + + padZeroes(dt.getHours()) + ":" + + padZeroes(dt.getMinutes()) + ":" + + padZeroes(dt.getSeconds()); + } +} + +function getTimeZone() { + try { + return Intl.DateTimeFormat().resolvedOptions().timeZone; + } catch(ex) { + // Get time zone from a string representing the date, + // eg. "Thu Nov 16 2017 01:13:32 GMT+0800 (CST)" -> "CST" + return new Date().toString().match(/\((.*)\)/)[1]; + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d3dd58996a0bb..5d62a7d8bebb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -57,6 +57,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } } + { +

    Client local time zone:

    + } + { if (allAppsSize > 0) { ++ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 4e4395d0fb959..3a9790cd57270 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -345,7 +345,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .map(_.get) .filter(_.startsWith(url)).toList - // there are atleast some URL links that were generated via javascript, + // there are at least some URL links that were generated via javascript, // and they all contain the spark.ui.proxyBase (uiRoot) links.length should be > 4 all(links) should startWith(url + uiRoot) From e6dc5f280786bdb068abb7348b787324f76c0e4a Mon Sep 17 00:00:00 2001 From: Daniel van der Ende Date: Tue, 12 Dec 2017 10:41:37 -0800 Subject: [PATCH 398/712] [SPARK-22729][SQL] Add getTruncateQuery to JdbcDialect In order to enable truncate for PostgreSQL databases in Spark JDBC, a change is needed to the query used for truncating a PostgreSQL table. By default, PostgreSQL will automatically truncate any descendant tables if a TRUNCATE query is executed. As this may result in (unwanted) side-effects, the query used for the truncate should be specified separately for PostgreSQL, specifying only to TRUNCATE a single table. ## What changes were proposed in this pull request? Add `getTruncateQuery` function to `JdbcDialect.scala`, with default query. Overridden this function for PostgreSQL to only truncate a single table. Also sets `isCascadingTruncateTable` to false, as this will allow truncates for PostgreSQL. ## How was this patch tested? Existing tests all pass. Added test for `getTruncateQuery` Author: Daniel van der Ende Closes #19911 from danielvdende/SPARK-22717. --- .../datasources/jdbc/JdbcRelationProvider.scala | 2 +- .../execution/datasources/jdbc/JdbcUtils.scala | 7 ++++--- .../spark/sql/jdbc/AggregatedDialect.scala | 6 +++++- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 12 ++++++++++++ .../apache/spark/sql/jdbc/PostgresDialect.scala | 14 ++++++++++++-- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 16 ++++++++++++++++ 6 files changed, 50 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 37e7bb0a59bb6..cc506e51bd0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -68,7 +68,7 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.Overwrite => if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) { // In this case, we should truncate table and then load. 
- truncateTable(conn, options.table) + truncateTable(conn, options) val tableSchema = JdbcUtils.getSchemaOption(conn, options) saveTable(df, tableSchema, isCaseSensitive, options) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index bbc95df4d9dc4..e6dc2fda4eb1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -96,12 +96,13 @@ object JdbcUtils extends Logging { } /** - * Truncates a table from the JDBC database. + * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, table: String): Unit = { + def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { - statement.executeUpdate(s"TRUNCATE TABLE $table") + statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index f3bfea5f6bfc8..8b92c8b4f56b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, MetadataBuilder} /** * AggregatedDialect can unify multiple dialects into one virtual Dialect. * Dialects are tried in order, and the first dialect that does not return a - * neutral element will will. + * neutral element will win. * * @param dialects List of dialects. */ @@ -63,4 +63,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect case _ => Some(false) } } + + override def getTruncateQuery(table: String): String = { + dialects.head.getTruncateQuery(table) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 7c38ed68c0413..83d87a11810c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -116,6 +116,18 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The name of the table. + * @return The SQL query to use for truncating a table + */ + @Since("2.3.0") + def getTruncateQuery(table: String): String = { + s"TRUNCATE TABLE $table" + } + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 4f61a328f47ca..13a2035f4d0c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,6 +85,17 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + /** + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The name of the table. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery(table: String): String = { + s"TRUNCATE TABLE ONLY $table" + } + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { super.beforeFetch(connection, properties) @@ -97,8 +108,7 @@ private object PostgresDialect extends JdbcDialect { if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { connection.setAutoCommit(false) } - } - override def isCascadingTruncateTable(): Option[Boolean] = Some(true) + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0767ca1573a7f..cb2df0ac54f4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -854,6 +854,22 @@ class JDBCSuite extends SparkFunSuite assert(derby.getTableExistsQuery(table) == defaultQuery) } + test("truncate table query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table" + assert(MySQL.getTruncateQuery(table) == defaultQuery) + assert(Postgres.getTruncateQuery(table) == postgresQuery) + assert(db2.getTruncateQuery(table) == defaultQuery) + assert(h2.getTruncateQuery(table) == defaultQuery) + assert(derby.getTruncateQuery(table) == defaultQuery) + } + test("Test DataFrame.where for Date and Timestamp") { // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); From 10c27a6559803797e89c28ced11c1087127b82eb Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 12 Dec 2017 11:27:01 -0800 Subject: [PATCH 399/712] [SPARK-22289][ML] Add JSON support for Matrix parameters (LR with coefficients bound) ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-22289 add JSON encoding/decoding for Param[Matrix]. The issue was reported by Nic Eggert during saving LR model with LowerBoundsOnCoefficients. There're two ways to resolve this as I see: 1. Support save/load on LogisticRegressionParams, and also adjust the save/load in LogisticRegression and LogisticRegressionModel. 2. Directly support Matrix in Param.jsonEncode, similar to what we have done for Vector. 
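To make option 2 concrete, here is a sketch of the round trip through the `JsonMatrixConverter` added by this patch (the converter is `private[ml]`, so this is only an illustration of the wire format, and the rendered field order may differ):

```scala
import org.apache.spark.ml.linalg.{JsonMatrixConverter, Matrices}

val m = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))

val json = JsonMatrixConverter.toJson(m)
// roughly: {"class":"matrix","type":1,"numRows":2,"numCols":2,
//           "values":[1.0,2.0,3.0,4.0],"isTransposed":false}

// The "class" field is what lets Param.jsonDecode tell a Matrix apart from
// the pre-existing Vector encoding, which has no such field.
assert(JsonMatrixConverter.fromJson(json) == m)
```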
After some discussion in jira, we prefer the fix to support Matrix as a valid Param type, for simplicity and convenience for other classes. Note that in the implementation, I added a "class" field in the JSON object to match different JSON converters when loading, which is for preciseness and future extension. ## How was this patch tested? new unit test to cover the LR case and JsonMatrixConverter Author: Yuhao Yang Closes #19525 from hhbyyh/lrsave. --- .../org/apache/spark/ml/linalg/Matrices.scala | 7 ++ .../spark/ml/linalg/JsonMatrixConverter.scala | 79 +++++++++++++++++++ .../org/apache/spark/ml/param/params.scala | 36 +++++++-- .../LogisticRegressionSuite.scala | 11 +++ .../ml/linalg/JsonMatrixConverterSuite.scala | 45 +++++++++++ 5 files changed, 170 insertions(+), 8 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/linalg/JsonMatrixConverterSuite.scala diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index 66c53624419d9..14428c6f45cce 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -476,6 +476,9 @@ class DenseMatrix @Since("2.0.0") ( @Since("2.0.0") object DenseMatrix { + private[ml] def unapply(dm: DenseMatrix): Option[(Int, Int, Array[Double], Boolean)] = + Some((dm.numRows, dm.numCols, dm.values, dm.isTransposed)) + /** * Generate a `DenseMatrix` consisting of zeros. * @param numRows number of rows of the matrix @@ -827,6 +830,10 @@ class SparseMatrix @Since("2.0.0") ( @Since("2.0.0") object SparseMatrix { + private[ml] def unapply( + sm: SparseMatrix): Option[(Int, Int, Array[Int], Array[Int], Array[Double], Boolean)] = + Some((sm.numRows, sm.numCols, sm.colPtrs, sm.rowIndices, sm.values, sm.isTransposed)) + /** * Generate a `SparseMatrix` from Coordinate List (COO) format. Input must be an array of * (i, j, value) tuples. Entries that have duplicate values of i and j are diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala new file mode 100644 index 0000000000000..0bee643412b3f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.linalg + +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} + +private[ml] object JsonMatrixConverter { + + /** Unique class name for identifying JSON object encoded by this class. */ + val className = "matrix" + + /** + * Parses the JSON representation of a Matrix into a [[Matrix]]. + */ + def fromJson(json: String): Matrix = { + implicit val formats = DefaultFormats + val jValue = parseJson(json) + (jValue \ "type").extract[Int] match { + case 0 => // sparse + val numRows = (jValue \ "numRows").extract[Int] + val numCols = (jValue \ "numCols").extract[Int] + val colPtrs = (jValue \ "colPtrs").extract[Seq[Int]].toArray + val rowIndices = (jValue \ "rowIndices").extract[Seq[Int]].toArray + val values = (jValue \ "values").extract[Seq[Double]].toArray + val isTransposed = (jValue \ "isTransposed").extract[Boolean] + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) + case 1 => // dense + val numRows = (jValue \ "numRows").extract[Int] + val numCols = (jValue \ "numCols").extract[Int] + val values = (jValue \ "values").extract[Seq[Double]].toArray + val isTransposed = (jValue \ "isTransposed").extract[Boolean] + new DenseMatrix(numRows, numCols, values, isTransposed) + case _ => + throw new IllegalArgumentException(s"Cannot parse $json into a Matrix.") + } + } + + /** + * Coverts the Matrix to a JSON string. + */ + def toJson(m: Matrix): String = { + m match { + case SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) => + val jValue = ("class" -> className) ~ + ("type" -> 0) ~ + ("numRows" -> numRows) ~ + ("numCols" -> numCols) ~ + ("colPtrs" -> colPtrs.toSeq) ~ + ("rowIndices" -> rowIndices.toSeq) ~ + ("values" -> values.toSeq) ~ + ("isTransposed" -> isTransposed) + compact(render(jValue)) + case DenseMatrix(numRows, numCols, values, isTransposed) => + val jValue = ("class" -> className) ~ + ("type" -> 1) ~ + ("numRows" -> numRows) ~ + ("numCols" -> numCols) ~ + ("values" -> values.toSeq) ~ + ("isTransposed" -> isTransposed) + compact(render(jValue)) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 8985f2af90a9a..1b4b401ac4aa0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -28,9 +28,9 @@ import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkException import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.ml.linalg.JsonVectorConverter -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector} import org.apache.spark.ml.util.Identifiable /** @@ -94,9 +94,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali compact(render(JString(x))) case v: Vector => JsonVectorConverter.toJson(v) + case m: Matrix => + JsonMatrixConverter.toJson(m) case _ => throw new NotImplementedError( - "The default jsonEncode only supports string and vector. " + + "The default jsonEncode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } } @@ -122,17 +124,35 @@ private[ml] object Param { /** Decodes a param value from JSON. 
*/ def jsonDecode[T](json: String): T = { - parse(json) match { + val jValue = parse(json) + jValue match { case JString(x) => x.asInstanceOf[T] case JObject(v) => val keys = v.map(_._1) - assert(keys.contains("type") && keys.contains("values"), - s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.") - JsonVectorConverter.fromJson(json).asInstanceOf[T] + if (keys.contains("class")) { + implicit val formats = DefaultFormats + val className = (jValue \ "class").extract[String] + className match { + case JsonMatrixConverter.className => + val checkFields = Array("numRows", "numCols", "values", "isTransposed", "type") + require(checkFields.forall(keys.contains), s"Expect a JSON serialized Matrix" + + s" but cannot find fields ${checkFields.mkString(", ")} in $json.") + JsonMatrixConverter.fromJson(json).asInstanceOf[T] + + case s => throw new SparkException(s"unrecognized class $s in $json") + } + } else { + // "class" info in JSON was added in Spark 2.3(SPARK-22289). JSON support for Vector was + // implemented before that and does not have "class" attribute. + require(keys.contains("type") && keys.contains("values"), s"Expect a JSON serialized" + + s" vector/matrix but cannot find fields 'type' and 'values' in $json.") + JsonVectorConverter.fromJson(json).asInstanceOf[T] + } + case _ => throw new NotImplementedError( - "The default jsonDecode only supports string and vector. " + + "The default jsonDecode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonDecode to support its value type.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 14f550890d238..a5f81a38face9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2767,6 +2767,17 @@ class LogisticRegressionSuite val lr = new LogisticRegression() testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings, LogisticRegressionSuite.allParamSettings, checkModelData) + + // test lr with bounds on coefficients, need to set elasticNetParam to 0. + val numFeatures = smallBinaryDataset.select("features").head().getAs[Vector](0).size + val lowerBounds = new DenseMatrix(1, numFeatures, (1 to numFeatures).map(_ / 1000.0).toArray) + val upperBounds = new DenseMatrix(1, numFeatures, (1 to numFeatures).map(_ * 1000.0).toArray) + val paramSettings = Map("lowerBoundsOnCoefficients" -> lowerBounds, + "upperBoundsOnCoefficients" -> upperBounds, + "elasticNetParam" -> 0.0 + ) + testEstimatorAndModelReadWrite(lr, smallBinaryDataset, paramSettings, + paramSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/JsonMatrixConverterSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/JsonMatrixConverterSuite.scala new file mode 100644 index 0000000000000..4d83f945b4ba8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/JsonMatrixConverterSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg
+
+import org.json4s.jackson.JsonMethods.parse
+
+import org.apache.spark.SparkFunSuite
+
+class JsonMatrixConverterSuite extends SparkFunSuite {
+
+  test("toJson/fromJson") {
+    val denseMatrices = Seq(
+      Matrices.dense(0, 0, Array.empty[Double]),
+      Matrices.dense(1, 1, Array(0.1)),
+      new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0), true)
+    )
+
+    val sparseMatrices = denseMatrices.map(_.toSparse) ++ Seq(
+      Matrices.sparse(3, 2, Array(0, 1, 2), Array(1, 2), Array(3.1, 3.5))
+    )
+
+    for (m <- sparseMatrices ++ denseMatrices) {
+      val json = JsonMatrixConverter.toJson(m)
+      parse(json) // `json` should be a valid JSON string
+      val u = JsonMatrixConverter.fromJson(json)
+      assert(u.getClass === m.getClass, "toJson/fromJson should preserve Matrix types.")
+      assert(u === m, "toJson/fromJson should preserve Matrix values.")
+    }
+  }
+}

From 7a51e71355485bb176a1387d99ec430c5986cbec Mon Sep 17 00:00:00 2001
From: German Schiavon
Date: Tue, 12 Dec 2017 11:46:57 -0800
Subject: [PATCH 400/712] [SPARK-22574][MESOS][SUBMIT] Check submission request parameters

## What changes were proposed in this pull request?

It fixes a problem where submitting a malformed CreateSubmissionRequest to the Spark Dispatcher left the Dispatcher in a bad state and made it inactive as a Mesos framework.

https://issues.apache.org/jira/browse/SPARK-22574

## How was this patch tested?

All Spark tests passed successfully. It was also tested by sending a malformed request (without appArgs) before and after the change. The point is simple: check whether the value is null before it is accessed.
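In code, the guard is the usual `Option(...)` wrapping, the same shape the patch applies to `appResource`, `mainClass`, `appArgs` and `environmentVariables` in `MesosSubmitRequestServlet.buildDriverDescription`:

```scala
// Fail fast with a descriptive REST error instead of letting a null field
// surface later as a NullPointerException inside the Mesos scheduler.
val appArgs = Option(request.appArgs).getOrElse {
  throw new SubmitRestMissingFieldException("Application arguments 'appArgs' are missing.")
}
```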
This was before the change, leaving the dispatcher inactive: ``` Exception in thread "Thread-22" java.lang.NullPointerException at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.getDriverCommandValue(MesosClusterScheduler.scala:444) at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.buildDriverCommand(MesosClusterScheduler.scala:451) at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.org$apache$spark$scheduler$cluster$mesos$MesosClusterScheduler$$createTaskInfo(MesosClusterScheduler.scala:538) at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler$$anonfun$scheduleTasks$1.apply(MesosClusterScheduler.scala:570) at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler$$anonfun$scheduleTasks$1.apply(MesosClusterScheduler.scala:555) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.scheduleTasks(MesosClusterScheduler.scala:555) at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.resourceOffers(MesosClusterScheduler.scala:621) ``` And after: ``` "message" : "Malformed request: org.apache.spark.deploy.rest.SubmitRestProtocolException: Validation of message CreateSubmissionRequest failed!\n\torg.apache.spark.deploy.rest.SubmitRestProtocolMessage.validate(SubmitRestProtocolMessage.scala:70)\n\torg.apache.spark.deploy.rest.SubmitRequestServlet.doPost(RestSubmissionServer.scala:272)\n\tjavax.servlet.http.HttpServlet.service(HttpServlet.java:707)\n\tjavax.servlet.http.HttpServlet.service(HttpServlet.java:790)\n\torg.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:845)\n\torg.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:583)\n\torg.spark_project.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180)\n\torg.spark_project.jetty.servlet.ServletHandler.doScope(ServletHandler.java:511)\n\torg.spark_project.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112)\n\torg.spark_project.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)\n\torg.spark_project.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134)\n\torg.spark_project.jetty.server.Server.handle(Server.java:524)\n\torg.spark_project.jetty.server.HttpChannel.handle(HttpChannel.java:319)\n\torg.spark_project.jetty.server.HttpConnection.onFillable(HttpConnection.java:253)\n\torg.spark_project.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:273)\n\torg.spark_project.jetty.io.FillInterest.fillable(FillInterest.java:95)\n\torg.spark_project.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93)\n\torg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303)\n\torg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\n\torg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\n\torg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\n\torg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\n\tjava.lang.Thread.run(Thread.java:745)" ``` Author: German Schiavon Closes #19793 from Gschiavon/fix-submission-request. 
--- .../deploy/rest/SubmitRestProtocolRequest.scala | 2 ++ .../spark/deploy/rest/SubmitRestProtocolSuite.scala | 2 ++ .../spark/deploy/rest/mesos/MesosRestServer.scala | 12 ++++++++---- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 0d50a768942ed..86ddf954ca128 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -46,6 +46,8 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { super.doValidate() assert(sparkProperties != null, "No Spark properties set!") assertFieldIsSet(appResource, "appResource") + assertFieldIsSet(appArgs, "appArgs") + assertFieldIsSet(environmentVariables, "environmentVariables") assertPropertyIsSet("spark.app.name") assertPropertyIsBoolean("spark.driver.supervise") assertPropertyIsNumeric("spark.driver.cores") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 725b8848bc052..75c50af23c66a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -86,6 +86,8 @@ class SubmitRestProtocolSuite extends SparkFunSuite { message.clientSparkVersion = "1.2.3" message.appResource = "honey-walnut-cherry.jar" message.mainClass = "org.apache.spark.examples.SparkPie" + message.appArgs = Array("two slices") + message.environmentVariables = Map("PATH" -> "/dev/null") val conf = new SparkConf(false) conf.set("spark.app.name", "SparkPie") message.sparkProperties = conf.getAll.toMap diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index ff60b88c6d533..bb8dfee165183 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -77,10 +77,16 @@ private[mesos] class MesosSubmitRequestServlet( private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { // Required fields, including the main class because python is not yet supported val appResource = Option(request.appResource).getOrElse { - throw new SubmitRestMissingFieldException("Application jar is missing.") + throw new SubmitRestMissingFieldException("Application jar 'appResource' is missing.") } val mainClass = Option(request.mainClass).getOrElse { - throw new SubmitRestMissingFieldException("Main class is missing.") + throw new SubmitRestMissingFieldException("Main class 'mainClass' is missing.") + } + val appArgs = Option(request.appArgs).getOrElse { + throw new SubmitRestMissingFieldException("Application arguments 'appArgs' are missing.") + } + val environmentVariables = Option(request.environmentVariables).getOrElse { + throw new SubmitRestMissingFieldException("Environment variables 'environmentVariables' are missing.") } // Optional fields @@ -91,8 +97,6 @@ private[mesos] class MesosSubmitRequestServlet( val superviseDriver = sparkProperties.get("spark.driver.supervise") val driverMemory = 
sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") - val appArgs = request.appArgs - val environmentVariables = request.environmentVariables val name = request.sparkProperties.getOrElse("spark.app.name", mainClass) // Construct driver description From 704af4bd67d981a2efb5eb00110e19c5f7e7c924 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 12 Dec 2017 13:40:01 -0800 Subject: [PATCH 401/712] Revert "[SPARK-22574][MESOS][SUBMIT] Check submission request parameters" This reverts commit 7a51e71355485bb176a1387d99ec430c5986cbec. --- .../deploy/rest/SubmitRestProtocolRequest.scala | 2 -- .../spark/deploy/rest/SubmitRestProtocolSuite.scala | 2 -- .../spark/deploy/rest/mesos/MesosRestServer.scala | 12 ++++-------- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 86ddf954ca128..0d50a768942ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -46,8 +46,6 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { super.doValidate() assert(sparkProperties != null, "No Spark properties set!") assertFieldIsSet(appResource, "appResource") - assertFieldIsSet(appArgs, "appArgs") - assertFieldIsSet(environmentVariables, "environmentVariables") assertPropertyIsSet("spark.app.name") assertPropertyIsBoolean("spark.driver.supervise") assertPropertyIsNumeric("spark.driver.cores") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 75c50af23c66a..725b8848bc052 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -86,8 +86,6 @@ class SubmitRestProtocolSuite extends SparkFunSuite { message.clientSparkVersion = "1.2.3" message.appResource = "honey-walnut-cherry.jar" message.mainClass = "org.apache.spark.examples.SparkPie" - message.appArgs = Array("two slices") - message.environmentVariables = Map("PATH" -> "/dev/null") val conf = new SparkConf(false) conf.set("spark.app.name", "SparkPie") message.sparkProperties = conf.getAll.toMap diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index bb8dfee165183..ff60b88c6d533 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -77,16 +77,10 @@ private[mesos] class MesosSubmitRequestServlet( private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { // Required fields, including the main class because python is not yet supported val appResource = Option(request.appResource).getOrElse { - throw new SubmitRestMissingFieldException("Application jar 'appResource' is missing.") + throw new SubmitRestMissingFieldException("Application jar is missing.") } val mainClass = Option(request.mainClass).getOrElse { - throw new SubmitRestMissingFieldException("Main class 'mainClass' is missing.") - } - 
val appArgs = Option(request.appArgs).getOrElse { - throw new SubmitRestMissingFieldException("Application arguments 'appArgs' are missing.") - } - val environmentVariables = Option(request.environmentVariables).getOrElse { - throw new SubmitRestMissingFieldException("Environment variables 'environmentVariables' are missing.") + throw new SubmitRestMissingFieldException("Main class is missing.") } // Optional fields @@ -97,6 +91,8 @@ private[mesos] class MesosSubmitRequestServlet( val superviseDriver = sparkProperties.get("spark.driver.supervise") val driverMemory = sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables val name = request.sparkProperties.getOrElse("spark.app.name", mainClass) // Construct driver description From 17cdabb88761e67ca555299109f89afdf02a4280 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 13 Dec 2017 07:42:24 +0900 Subject: [PATCH 402/712] [SPARK-19809][SQL][TEST] NullPointerException on zero-size ORC file ## What changes were proposed in this pull request? Until 2.2.1, Spark raises `NullPointerException` on zero-size ORC files. Usually, these zero-size ORC files are generated by 3rd-party apps like Flume. ```scala scala> sql("create table empty_orc(a int) stored as orc location '/tmp/empty_orc'") $ touch /tmp/empty_orc/zero.orc scala> sql("select * from empty_orc").show java.lang.RuntimeException: serious problem at org.apache.hadoop.hive.ql.io.orc.OrcInputFormat.generateSplitsInfo(OrcInputFormat.java:1021) ... Caused by: java.lang.NullPointerException at org.apache.hadoop.hive.ql.io.orc.OrcInputFormat$BISplitStrategy.getSplits(OrcInputFormat.java:560) ``` After [SPARK-22279](https://github.com/apache/spark/pull/19499), Apache Spark with the default configuration doesn't have this bug. Although Hive 1.2.1 library code path still has the problem, we had better have a test coverage on what we have now in order to prevent future regression on it. ## How was this patch tested? Pass a newly added test case. Author: Dongjoon Hyun Closes #19948 from dongjoon-hyun/SPARK-19809-EMPTY-FILE. --- .../sql/hive/execution/SQLQuerySuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f2562c33e2a6e..93c91d3fcb727 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2172,4 +2172,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-19809 NullPointerException on zero-size ORC file") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { + withTempPath { dir => + withTable("spark_19809") { + sql(s"CREATE TABLE spark_19809(a int) STORED AS ORC LOCATION '$dir'") + Files.touch(new File(s"${dir.getCanonicalPath}", "zero.orc")) + + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // default since 2.3.0 + checkAnswer(sql("SELECT * FROM spark_19809"), Seq.empty) + } + } + } + } + } + } } From b03af8b582b9b71b09eaf3a1c01d1b3ef5f072e8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 12 Dec 2017 17:37:01 -0800 Subject: [PATCH 403/712] [SPARK-21087][ML][FOLLOWUP] Sync SharedParamsCodeGen and sharedParams. 
## What changes were proposed in this pull request? #19208 modified ```sharedParams.scala```, but didn't generated by ```SharedParamsCodeGen.scala```. This involves mismatch between them. ## How was this patch tested? Existing test. Author: Yanbo Liang Closes #19958 from yanboliang/spark-21087. --- .../spark/ml/param/shared/SharedParamsCodeGen.scala | 8 ++++---- .../apache/spark/ml/param/shared/sharedParams.scala | 10 ++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c54062921fce6..a267bbc874322 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -84,10 +84,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), isValid = "ParamValidators.gtEq(2)", isExpertParam = true), - ParamDesc[Boolean]("collectSubModels", "If set to false, then only the single best " + - "sub-model will be available after fitting. If set to true, then all sub-models will be " + - "available. Warning: For large models, collecting all sub-models can cause OOMs on the " + - "Spark driver.", + ParamDesc[Boolean]("collectSubModels", "whether to collect a list of sub-models trained " + + "during tuning. If set to false, then only the single best sub-model will be available " + + "after fitting. If set to true, then all sub-models will be available. Warning: For " + + "large models, collecting all sub-models can cause OOMs on the Spark driver", Some("false"), isExpertParam = true) ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 34aa38ac751fd..0004f085d10f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -470,15 +470,17 @@ trait HasAggregationDepth extends Params { } /** - * Trait for shared param collectSubModels (default: false). + * Trait for shared param collectSubModels (default: false). This trait may be changed or + * removed between minor versions. */ -private[ml] trait HasCollectSubModels extends Params { +@DeveloperApi +trait HasCollectSubModels extends Params { /** - * Param for whether to collect a list of sub-models trained during tuning. + * Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver. * @group expertParam */ - final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning") + final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. 
Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver") setDefault(collectSubModels, false) From 4117786a87f9d7631dec58a8e7aef09403b20a27 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 13 Dec 2017 10:29:14 +0800 Subject: [PATCH 404/712] [SPARK-22716][SQL] Avoid the creation of mutable states in addReferenceObj ## What changes were proposed in this pull request? We have two methods to reference an object `addReferenceMinorObj` and `addReferenceObj `. The latter creates a new global variable, which means new entries in the constant pool. The PR unifies the two method in a single `addReferenceObj` which returns the code to access the object in the `references` array and doesn't add new mutable states. ## How was this patch tested? added UTs. Author: Marco Gaido Closes #19916 from mgaido91/SPARK-22716. --- .../spark/sql/catalyst/expressions/Cast.scala | 8 +++---- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 22 ++++--------------- .../expressions/datetimeExpressions.scala | 20 ++++++++--------- .../sql/catalyst/expressions/literals.scala | 9 ++++---- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/objects/objects.scala | 8 +++---- .../expressions/CodeGenerationSuite.scala | 7 ++++++ .../expressions/MiscExpressionsSuite.scala | 1 - .../catalyst/expressions/ScalaUDFSuite.scala | 3 ++- .../aggregate/RowBasedHashMapGenerator.scala | 4 ++-- .../VectorizedHashMapGenerator.scala | 4 ++-- 12 files changed, 42 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b8d3661a00abc..5279d41278967 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -605,7 +605,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" case TimestampType => - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case _ => @@ -633,7 +633,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);" case _ => @@ -713,7 +713,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val longOpt = ctx.freshName("longOpt") (c, evPrim, evNull) => s""" @@ -730,7 +730,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case _: IntegralType => (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" case DateType => - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"$evPrim = 
org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;" case DecimalType() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 4d26d9819321b..a3cf7612016b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1000,7 +1000,7 @@ case class ScalaUDF( ctx: CodegenContext, ev: ExprCode): ExprCode = { val scalaUDF = ctx.freshName("scalaUDF") - val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName) + val scalaUDFRef = ctx.addReferenceObj("scalaUDFRef", this, scalaUDFClassName) // Object to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5c9e604a8d293..4b8b16ff71ecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -109,28 +109,14 @@ class CodegenContext { * * Returns the code to access it. * - * This is for minor objects not to store the object into field but refer it from the references - * field at the time of use because number of fields in class is limited so we should reduce it. + * This does not to store the object into field but refer it from the references field at the + * time of use because number of fields in class is limited so we should reduce it. */ - def addReferenceMinorObj(obj: Any, className: String = null): String = { + def addReferenceObj(objName: String, obj: Any, className: String = null): String = { val idx = references.length references += obj val clsName = Option(className).getOrElse(obj.getClass.getName) - s"(($clsName) references[$idx])" - } - - /** - * Add an object to `references`, create a class member to access it. - * - * Returns the name of class member. 
- */ - def addReferenceObj(name: String, obj: Any, className: String = null): String = { - val term = freshName(name) - val idx = references.length - references += obj - val clsName = Option(className).getOrElse(obj.getClass.getName) - addMutableState(clsName, term, s"$term = ($clsName) references[$idx];") - term + s"(($clsName) references[$idx] /* $objName */)" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index eaf8788888211..44d54a20844a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -226,7 +226,7 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getHours($c, $tz)") } @@ -257,7 +257,7 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c, $tz)") } @@ -288,7 +288,7 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c, $tz)") } @@ -529,7 +529,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz) .format(new java.util.Date($timestamp / 1000)))""" @@ -691,7 +691,7 @@ abstract class UnixTime }""") } case StringType => - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" @@ -715,7 +715,7 @@ abstract class UnixTime ${ev.value} = ${eval1.value} / 1000000L; }""") case DateType => - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) ev.copy(code = s""" @@ -827,7 +827,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ }""") } } else { - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" @@ -969,7 +969,7 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } override def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = { - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)""" @@ -1065,7 +1065,7 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)""" @@ -1143,7 +1143,7 @@ case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Optio } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceMinorObj(timeZone) + val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (l, r) => { s"""$dtu.monthsBetween($l, $r, $tz)""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index eaeaf08c37b4e..383203a209833 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -290,7 +290,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceMinorObj(v) + val boxedValue = ctx.addReferenceObj("boxedValue", v) val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" ev.copy(code = code) } else { @@ -299,7 +299,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceMinorObj(v) + val boxedValue = ctx.addReferenceObj("boxedValue", v) val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" ev.copy(code = code) } else { @@ -309,8 +309,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { ev.copy(code = "", value = s"($javaType)$value") case TimestampType | LongType => ev.copy(code = "", value = s"${value}L") - case other => - ev.copy(code = "", value = ctx.addReferenceMinorObj(value, ctx.javaType(dataType))) + case _ => + ev.copy(code = "", value = ctx.addReferenceObj("literal", value, + ctx.javaType(dataType))) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index b86e271fe2958..4b9006ab5b423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -81,7 +81,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. 
- val errMsgField = ctx.addReferenceMinorObj(errMsg) + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 349afece84d5c..4bd395eadcf19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1123,7 +1123,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) expressions = childrenCodes, funcName = "createExternalRow", extraArguments = "Object[]" -> values :: Nil) - val schemaField = ctx.addReferenceMinorObj(schema) + val schemaField = ctx.addReferenceObj("schema", schema) val code = s""" @@ -1310,7 +1310,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null. - val errMsgField = ctx.addReferenceMinorObj(errMsg) + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val code = s""" ${childGen.code} @@ -1347,7 +1347,7 @@ case class GetExternalRowField( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. - val errMsgField = ctx.addReferenceMinorObj(errMsg) + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) val code = s""" ${row.code} @@ -1387,7 +1387,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. 
- val errMsgField = ctx.addReferenceMinorObj(errMsg) + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val input = child.genCode(ctx) val obj = input.value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 40bf29bb3b573..a969811019161 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -394,4 +394,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Map("add" -> Literal(1))).genCode(ctx) assert(ctx.mutableStates.isEmpty) } + + test("SPARK-22716: addReferenceObj should not add mutable states") { + val ctx = new CodegenContext + val foo = new Object() + ctx.addReferenceObj("foo", foo) + assert(ctx.mutableStates.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 4fe7b436982b1..facc863081303 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -43,5 +43,4 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Uuid()), 36) assert(evaluate(Uuid()) !== evaluate(Uuid())) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 70dea4b39d55d..b0687fe900d3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -51,6 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + // ScalaUDF can be very verbose and trigger reduceCodeSize + assert(ctx.mutableStates.forall(_._2.startsWith("globalIsNull"))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 3718424931b40..fd25707dd4ca6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -47,7 +47,7 @@ class RowBasedHashMapGenerator( val generatedKeySchema: String = s"new org.apache.spark.sql.types.StructType()" + groupingKeySchema.map { key => - val keyName = ctx.addReferenceMinorObj(key.name) + val keyName = ctx.addReferenceObj("keyName", key.name) key.dataType match { case d: DecimalType => s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( @@ -60,7 +60,7 @@ class RowBasedHashMapGenerator( val generatedValueSchema: String = s"new org.apache.spark.sql.types.StructType()" + bufferSchema.map { key => - val keyName = ctx.addReferenceMinorObj(key.name) + val 
keyName = ctx.addReferenceObj("keyName", key.name) key.dataType match { case d: DecimalType => s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index f04cd48072f17..0380ee8b09d63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -54,7 +54,7 @@ class VectorizedHashMapGenerator( val generatedSchema: String = s"new org.apache.spark.sql.types.StructType()" + (groupingKeySchema ++ bufferSchema).map { key => - val keyName = ctx.addReferenceMinorObj(key.name) + val keyName = ctx.addReferenceObj("keyName", key.name) key.dataType match { case d: DecimalType => s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( @@ -67,7 +67,7 @@ class VectorizedHashMapGenerator( val generatedAggBufferSchema: String = s"new org.apache.spark.sql.types.StructType()" + bufferSchema.map { key => - val keyName = ctx.addReferenceMinorObj(key.name) + val keyName = ctx.addReferenceObj("keyName", key.name) key.dataType match { case d: DecimalType => s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( From c7d0148615c921dca782ee3785b5d0cd59e42262 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Dec 2017 10:40:05 +0800 Subject: [PATCH 405/712] [SPARK-22600][SQL] Fix 64kb limit for deeply nested expressions under wholestage codegen ## What changes were proposed in this pull request? SPARK-22543 fixes the 64kb compile error for deeply nested expression for non-wholestage codegen. This PR extends it to support wholestage codegen. This patch brings some util methods in to extract necessary parameters for an expression if it is split to a function. The util methods are put in object `ExpressionCodegen` under `codegen`. The main entry is `getExpressionInputParams` which returns all necessary parameters to evaluate the given expression in a split function. This util methods can be used to split expressions too. This is a TODO item later. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #19813 from viirya/reduce-expr-code-for-wholestage. 
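For illustration only (not part of this patch): a minimal sketch of the kind of query this change targets, mirroring the regression test added to `WholeStageCodegenSuite` in this patch. It is written spark-shell style, so it assumes a `SparkSession` named `spark` with its implicits in scope; the column names are arbitrary.

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, decode, encode}
import spark.implicits._  // already available when pasted into spark-shell

// Build a deeply nested expression by round-tripping a string column many times.
// A single generated method for such an expression can exceed Janino's 64KB
// bytecode limit when evaluated under whole-stage codegen.
val df = Seq(("abc", 1)).toDF("key", "int")
var nested: Column = col("key")
for (_ <- 1 to 150) {
  nested = decode(encode(nested, "utf-8"), "utf-8")
}

// With the generated code split into separate functions, this runs under
// whole-stage codegen instead of hitting the 64kb compile error.
df.select(nested.as("roundTripped")).show()
```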
--- .../sql/catalyst/expressions/Expression.scala | 37 ++- .../expressions/codegen/CodeGenerator.scala | 44 ++- .../codegen/ExpressionCodegen.scala | 269 ++++++++++++++++++ .../codegen/ExpressionCodegenSuite.scala | 220 ++++++++++++++ .../sql/execution/ColumnarBatchScan.scala | 5 +- .../execution/WholeStageCodegenSuite.scala | 23 +- 6 files changed, 585 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 743782a6453e9..329ea5d421509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -105,6 +105,12 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", isNull, value)) + eval.isNull = if (this.nullable) eval.isNull else "false" + + // Records current input row and variables of this expression. + eval.inputRow = ctx.INPUT_ROW + eval.inputVars = findInputVars(ctx, eval) + reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -115,9 +121,29 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Returns the input variables to this expression. + */ + private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = { + if (ctx.currentVars != null) { + this.collect { + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + ExprInputVar(exprCode = ctx.currentVars(ordinal), + dataType = b.dataType, nullable = b.nullable) + } + } else { + Seq.empty + } + } + + /** + * In order to prevent 64kb compile error, reducing the size of generated codes by + * separating it into a function if the size exceeds a threshold. 
+ */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { - // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this) + + if (eval.code.trim.length > 1024 && funcParams.isDefined) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { val globalIsNull = ctx.freshName("globalIsNull") ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) @@ -132,9 +158,12 @@ abstract class Expression extends TreeNode[Expression] { val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) + val callParams = funcParams.map(_._1.mkString(", ")).get + val declParams = funcParams.map(_._2.mkString(", ")).get + val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + |private $javaType $funcName($declParams) { | ${eval.code.trim} | $setIsNull | return ${eval.value}; @@ -142,7 +171,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = newValue - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = s"$javaType $newValue = $funcFullName($callParams);" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4b8b16ff71ecb..257c3f10fa08b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,8 +55,24 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * to null. * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. + * @param inputRow A term that holds the input row name when generating this code. + * @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode( + var code: String, + var isNull: String, + var value: String, + var inputRow: String = null, + var inputVars: Seq[ExprInputVar] = Seq.empty) + +/** + * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]]. + * + * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable. + * @param dataType The data type of the input variable. + * @param nullable Whether the input variable can be null or not. + */ +case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean) /** * State used for subexpression elimination. @@ -1012,16 +1028,25 @@ class CodegenContext { commonExprs.foreach { e => val expr = e.head val fnName = freshName("evalExpr") - val isNull = s"${fnName}IsNull" + val isNull = if (expr.nullable) { + s"${fnName}IsNull" + } else { + "" + } val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) + val assignIsNull = if (expr.nullable) { + s"$isNull = ${eval.isNull};" + } else { + "" + } val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { | ${eval.code.trim} - | $isNull = ${eval.isNull}; + | $assignIsNull | $value = ${eval.value}; |} """.stripMargin @@ -1039,12 +1064,17 @@ class CodegenContext { // 2. Less code. 
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") - addMutableState(javaType(expr.dataType), value, - s"$value = ${defaultValue(expr.dataType)};") + if (expr.nullable) { + addMutableState(JAVA_BOOLEAN, isNull) + } + addMutableState(javaType(expr.dataType), value) subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = if (expr.nullable) { + SubExprEliminationState(isNull, value) + } else { + SubExprEliminationState("false", value) + } e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala new file mode 100644 index 0000000000000..a2dda48e951d1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType + +/** + * Defines util methods used in expression code generation. + */ +object ExpressionCodegen { + + /** + * Given an expression, returns the all necessary parameters to evaluate it, so the generated + * code of this expression can be split in a function. + * The 1st string in returned tuple is the parameter strings used to call the function. + * The 2nd string in returned tuple is the parameter strings used to declare the function. + * + * Returns `None` if it can't produce valid parameters. + * + * Params to include: + * 1. Evaluated columns referred by this, children or deferred expressions. + * 2. Rows referred by this, children or deferred expressions. + * 3. Eliminated subexpressions referred by children expressions. 
+ */ + def getExpressionInputParams( + ctx: CodegenContext, + expr: Expression): Option[(Seq[String], Seq[String])] = { + val subExprs = getSubExprInChildren(ctx, expr) + val subExprCodes = getSubExprCodes(ctx, subExprs) + val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) => + ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable) + } + val paramsFromSubExprs = prepareFunctionParams(ctx, subVars) + + val inputVars = getInputVarsForChildren(ctx, expr) + val paramsFromColumns = prepareFunctionParams(ctx, inputVars) + + val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) + val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => + (row, s"InternalRow $row") + } + + val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length + // Maximum allowed parameter number for Java's method descriptor. + if (paramsLength > 255) { + None + } else { + val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip + val callParams = allParams._1.distinct + val declParams = allParams._2.distinct + Some((callParams, declParams)) + } + } + + /** + * Returns the eliminated subexpressions in the children expressions. + */ + def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = { + expr.children.flatMap { child => + child.collect { + case e if ctx.subExprEliminationExprs.contains(e) => e + } + }.distinct + } + + /** + * A small helper function to return `ExprCode`s that represent subexpressions. + */ + def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = { + subExprs.map { subExpr => + val state = ctx.subExprEliminationExprs(subExpr) + ExprCode(code = "", value = state.value, isNull = state.isNull) + } + } + + /** + * Retrieves previous input rows referred by children and deferred expressions. + */ + def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = { + expr.children.flatMap(getInputRows(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previous input rows referred by it or deferred expressions + * which are needed to evaluate it. + */ + def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { + child.flatMap { + // An expression directly evaluates on current input row. + case BoundReference(ordinal, _, _) if ctx.currentVars == null || + ctx.currentVars(ordinal) == null => + Seq(ctx.INPUT_ROW) + + // An expression which is not evaluated yet. Tracks down to find input rows. + case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) => + trackDownRow(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down input rows referred by the generated code snippet. + */ + def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { + val exprCodes = mutable.Queue[ExprCode](exprCode) + val inputRows = mutable.ArrayBuffer.empty[String] + + while (exprCodes.nonEmpty) { + val curExprCode = exprCodes.dequeue() + if (curExprCode.inputRow != null) { + inputRows += curExprCode.inputRow + } + curExprCode.inputVars.foreach { inputVar => + if (!isEvaluated(inputVar.exprCode)) { + exprCodes.enqueue(inputVar.exprCode) + } + } + } + inputRows + } + + /** + * Retrieves previously evaluated columns referred by children and deferred expressions. + * Returned tuple contains the list of expressions and the list of generated codes. 
+ */ + def getInputVarsForChildren( + ctx: CodegenContext, + expr: Expression): Seq[ExprInputVar] = { + expr.children.flatMap(getInputVars(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previously evaluated columns referred by it or + * deferred expressions which are needed to evaluate it. + */ + def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = { + if (ctx.currentVars == null) { + return Seq.empty + } + + child.flatMap { + // An evaluated variable. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && + isEvaluated(ctx.currentVars(ordinal)) => + Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable)) + + // An input variable which is not evaluated yet. Tracks down to find any evaluated variables + // in the expression path. + // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to + // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so + // to include them into parameters, if not, we track down further. + case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + trackDownVar(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down previously evaluated columns referred by the generated code snippet. + */ + def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = { + val exprCodes = mutable.Queue[ExprCode](exprCode) + val inputVars = mutable.ArrayBuffer.empty[ExprInputVar] + + while (exprCodes.nonEmpty) { + exprCodes.dequeue().inputVars.foreach { inputVar => + if (isEvaluated(inputVar.exprCode)) { + inputVars += inputVar + } else { + exprCodes.enqueue(inputVar.exprCode) + } + } + } + inputVars + } + + /** + * Helper function to calculate the size of an expression as function parameter. + */ + def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = { + (if (input.nullable) 1 else 0) + ctx.javaType(input.dataType) match { + case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 2 + case _ => 1 + } + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length of 255 or less. `this` contributes one unit and a parameter of type long or double + * contributes two units. + */ + def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = { + // Initial value is 1 for `this`. + 1 + inputs.map(calculateParamLength(ctx, _)).sum + } + + /** + * Given the lists of input attributes and variables to this expression, returns the strings of + * funtion parameters. The first is the variable names used to call the function, the second is + * the parameters used to declare the function in generated code. + */ + def prepareFunctionParams( + ctx: CodegenContext, + inputVars: Seq[ExprInputVar]): Seq[(String, String)] = { + inputVars.flatMap { inputVar => + val params = mutable.ArrayBuffer.empty[(String, String)] + val ev = inputVar.exprCode + + // Only include the expression value if it is not a literal. + if (!isLiteral(ev)) { + val argType = ctx.javaType(inputVar.dataType) + params += ((ev.value, s"$argType ${ev.value}")) + } + + // If it is a nullable expression and `isNull` is not a literal. + if (inputVar.nullable && ev.isNull != "true" && ev.isNull != "false") { + params += ((ev.isNull, s"boolean ${ev.isNull}")) + } + + params + }.distinct + } + + /** + * Only applied to the `ExprCode` in `ctx.currentVars`. + * Returns true if this value is a literal. 
+ */ + def isLiteral(exprCode: ExprCode): Boolean = { + assert(exprCode.value.nonEmpty, "ExprCode.value can't be empty string.") + + if (exprCode.value == "true" || exprCode.value == "false" || exprCode.value == "null") { + true + } else { + // The valid characters for the first character of a Java variable is [a-zA-Z_$]. + exprCode.value.head match { + case v if v >= 'a' && v <= 'z' => false + case v if v >= 'A' && v <= 'Z' => false + case '_' | '$' => false + case _ => true + } + } + } + + /** + * Only applied to the `ExprCode` in `ctx.currentVars`. + * The code is emptied after evaluation. + */ + def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == "" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala new file mode 100644 index 0000000000000..39d58cabff228 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionCodegenSuite extends SparkFunSuite { + + test("Returns eliminated subexpressions for expression") { + val ctx = new CodegenContext() + val subExpr = Add(Literal(1), Literal(2)) + val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4))) + + ctx.generateExpressions(exprs, doSubexpressionElimination = true) + val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0)) + assert(subexpressions.length == 1 && subexpressions(0) == subExpr) + } + + test("Gets parameters for subexpressions") { + val ctx = new CodegenContext() + val subExprs = Seq( + Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable + Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)())) // nullable + + ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1")) + ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2")) + + val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs) + val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) => + ExprInputVar(exprCode, expr.dataType, expr.nullable) + } + val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: current variables") { + val ctx = new CodegenContext() + val currentVars = Seq( + ExprCode("", isNull = "false", value = "value1"), // evaluated + ExprCode("", isNull = "isNull2", value = "value2"), // evaluated + ExprCode("fake code;", isNull = "isNull3", value = "value3")) // not evaluated + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = If(Literal(false), + Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)), + BoundReference(2, IntegerType, nullable = true)) + + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + // Only two evaluated variables included. + assert(inputVars.length == 2) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false) + assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true) + assert(inputVars(0).exprCode == currentVars(0)) + assert(inputVars(1).exprCode == currentVars(1)) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: deferred variables") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an evaluated column from + // other operator. + val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) + + // currentVars(0) depends on this evaluated column. 
+ currentVars(0).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull2", value = "value2"), + dataType = IntegerType, nullable = true)) + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputVars.length == 1) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(params.length == 2) + assert(params(0) == Tuple2("value2", "int value2")) + assert(params(1) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input rows for expression") { + val ctx = new CodegenContext() + ctx.currentVars = null + ctx.INPUT_ROW = "i" + + val expr = Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "i") + } + + test("Returns input rows for expression: deferred expression") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an input row from + // other operator. + val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) + currentVars(0).inputRow = "inputadaptor_row1" + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "inputadaptor_row1") + } + + test("Returns both input rows and variables for expression") { + val ctx = new CodegenContext() + // 5 input variables in currentVars: + // 1 evaluated variable (value1). + // 3 not evaluated variables. + // value2 depends on an evaluated column from other operator. + // value3 depends on an input row from other operator. + // value4 depends on a not evaluated yet column from other operator. + // 1 null indicating to use input row "i". + val currentVars = Seq( + ExprCode("", isNull = "false", value = "value1"), + ExprCode("fake code;", isNull = "isNull2", value = "value2"), + ExprCode("fake code;", isNull = "isNull3", value = "value3"), + ExprCode("fake code;", isNull = "isNull4", value = "value4"), + null) + // value2 depends on this evaluated column. + currentVars(1).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull5", value = "value5"), + dataType = IntegerType, nullable = true)) + // value3 depends on an input row "inputadaptor_row1". + currentVars(2).inputRow = "inputadaptor_row1" + // value4 depends on another not evaluated yet column. + currentVars(3).inputVars = Seq(ExprInputVar(ExprCode("fake code;", + isNull = "isNull6", value = "value6"), dataType = IntegerType, nullable = true)) + ctx.currentVars = currentVars + ctx.INPUT_ROW = "i" + + // expr: if (false) { value1 + value2 } else { (value3 + value4) + i[5] } + val expr = If(Literal(false), + Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)), + Add(Add(BoundReference(2, IntegerType, nullable = true), + BoundReference(3, IntegerType, nullable = true)), + BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i". + + // input rows: "i", "inputadaptor_row1". 
+ val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 2) + assert(inputRows(0) == "inputadaptor_row1") + assert(inputRows(1) == "i") + + // input variables: value1 and value5 + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputVars.length == 2) + + // value1 has inlined isNull "false", so don't need to include it in the params. + val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(inputVarParams.length == 3) + assert(inputVarParams(0) == Tuple2("value1", "int value1")) + assert(inputVarParams(1) == Tuple2("value5", "int value5")) + assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5")) + } + + test("isLiteral: literals") { + val literals = Seq( + ExprCode("", "", "true"), + ExprCode("", "", "false"), + ExprCode("", "", "1"), + ExprCode("", "", "-1"), + ExprCode("", "", "1L"), + ExprCode("", "", "-1L"), + ExprCode("", "", "1.0f"), + ExprCode("", "", "-1.0f"), + ExprCode("", "", "0.1f"), + ExprCode("", "", "-0.1f"), + ExprCode("", "", """"string""""), + ExprCode("", "", "(byte)-1"), + ExprCode("", "", "(short)-1"), + ExprCode("", "", "null")) + + literals.foreach(l => assert(ExpressionCodegen.isLiteral(l) == true)) + } + + test("isLiteral: non literals") { + val variables = Seq( + ExprCode("", "", "var1"), + ExprCode("", "", "_var2"), + ExprCode("", "", "$var3"), + ExprCode("", "", "v1a2r3"), + ExprCode("", "", "_1v2a3r"), + ExprCode("", "", "$1v2a3r")) + + variables.foreach(v => assert(ExpressionCodegen.isLiteral(v) == false)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index a9bfb634fbdea..05186c4472566 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -108,7 +108,10 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { |}""".stripMargin) ctx.currentVars = null + // `rowIdx` isn't in `ctx.currentVars`. If the expressions are split later, we can't track it. + // So making it as global variable. 
val rowidx = ctx.freshName("rowIdx") + ctx.addMutableState(ctx.JAVA_INT, rowidx) val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } @@ -128,7 +131,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | int $numRows = $batch.numRows(); | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | int $rowidx = $idx + $localIdx; + | $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c47..1281169b607c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -236,4 +237,24 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = Seq(("abc", 1)).toDF("key", "int") + df.write.parquet(path) + + var strExpr: Expression = col("key").expr + for (_ <- 1 to 150) { + strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8")) + } + val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) + + val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*) + val plan = df2.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) + df2.collect() + } + } } From 0e36ba6212bc24b3185e385914fbf2d62cbfb6da Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 12 Dec 2017 21:28:24 -0800 Subject: [PATCH 406/712] [SPARK-22644][ML][TEST] Make ML testsuite support StructuredStreaming test ## What changes were proposed in this pull request? We need to add some helper code to make testing ML transformers & models easier with streaming data. These tests might help us catch any remaining issues and we could encourage future PRs to use these tests to prevent new Models & Transformers from having issues. I add a `MLTest` trait which extends `StreamTest` trait, and override `createSparkSession`. So ML testsuite can only extend `MLTest`, to use both ML & Stream test util functions. I only modify one testcase in `LinearRegressionSuite`, for first pass review. Link to #19746 ## How was this patch tested? `MLTestSuite` added. Author: WeichenXu Closes #19843 from WeichenXu123/ml_stream_test_helper. 
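For illustration only (not part of this patch): a rough sketch of how a suite could use the new helper, adapted from the `MLTestSuite` added below; the suite name and test name are made up.

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

// Hypothetical suite: it only needs to extend the new MLTest trait to get
// both the ML and the streaming test utilities.
class MyStreamingTransformerSuite extends MLTest {

  import testImplicits._

  test("StringIndexer model gives the same result on batch and streaming input") {
    val data = Seq((0, "a"), (1, "b"), (2, "c")).toDF("id", "label")
    val model = new StringIndexer()
      .setStringOrderType("alphabetAsc")
      .setInputCol("label")
      .setOutputCol("indexed")
      .fit(data)

    // testTransformer applies the model to the batch DataFrame and to a
    // MemoryStream-backed streaming DataFrame built from the same rows,
    // running the same per-row check on both outputs.
    testTransformer[(Int, String)](data, model, "id", "indexed") {
      case Row(id: Int, indexed: Double) => assert(id === indexed.toInt)
    }
  }
}
```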
--- mllib/pom.xml | 14 +++ .../ml/regression/LinearRegressionSuite.scala | 8 +- .../org/apache/spark/ml/util/MLTest.scala | 91 +++++++++++++++++++ .../apache/spark/ml/util/MLTestSuite.scala | 47 ++++++++++ .../spark/sql/streaming/StreamTest.scala | 67 +++++++++----- .../spark/sql/test/TestSQLContext.scala | 2 +- 6 files changed, 203 insertions(+), 26 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala diff --git a/mllib/pom.xml b/mllib/pom.xml index 925b5422a54cc..a906c9e02cd4c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -60,6 +60,20 @@ spark-sql_${scala.binary.version} ${project.version}
    + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-graphx_${scala.binary.version} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 0e0be58dbf022..aec5ac0e75896 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -233,7 +232,8 @@ class LinearRegressionSuite assert(model2.intercept ~== interceptR relTol 1E-3) assert(model2.coefficients ~= coefficientsR relTol 1E-3) - model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { + testTransformer[(Double, Vector)](datasetWithDenseFeature, model1, + "features", "prediction") { case Row(features: DenseVector, prediction1: Double) => val prediction2 = features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala new file mode 100644 index 0000000000000..7a5426ebadaa5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import java.io.File + +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{PipelineModel, Transformer} +import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.util.Utils + +trait MLTest extends StreamTest with TempDirectory { self: Suite => + + @transient var sc: SparkContext = _ + @transient var checkpointDir: String = _ + + protected override def createSparkSession: TestSparkSession = { + new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf)) + } + + override def beforeAll(): Unit = { + super.beforeAll() + sc = spark.sparkContext + checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString + sc.setCheckpointDir(checkpointDir) + } + + override def afterAll() { + try { + Utils.deleteRecursively(new File(checkpointDir)) + } finally { + super.afterAll() + } + } + + def testTransformerOnStreamData[A : Encoder]( + dataframe: DataFrame, + transformer: Transformer, + firstResultCol: String, + otherResultCols: String*) + (checkFunction: Row => Unit): Unit = { + + val columnNames = dataframe.schema.fieldNames + val stream = MemoryStream[A] + val streamDF = stream.toDS().toDF(columnNames: _*) + + val data = dataframe.as[A].collect() + + val streamOutput = transformer.transform(streamDF) + .select(firstResultCol, otherResultCols: _*) + testStream(streamOutput) ( + AddData(stream, data: _*), + CheckAnswer(checkFunction) + ) + } + + def testTransformer[A : Encoder]( + dataframe: DataFrame, + transformer: Transformer, + firstResultCol: String, + otherResultCols: String*) + (checkFunction: Row => Unit): Unit = { + testTransformerOnStreamData(dataframe, transformer, firstResultCol, + otherResultCols: _*)(checkFunction) + + val dfOutput = transformer.transform(dataframe) + dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row => + checkFunction(row) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala new file mode 100644 index 0000000000000..56217ec4f3b0c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.{PipelineModel, Transformer} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.Row + +class MLTestSuite extends MLTest { + + import testImplicits._ + + test("test transformer on stream data") { + + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f")) + .toDF("id", "label") + val indexer = new StringIndexer().setStringOrderType("alphabetAsc") + .setInputCol("label").setOutputCol("indexed") + val indexerModel = indexer.fit(data) + testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") { + case Row(id: Int, indexed: Double) => + assert(id === indexed.toInt) + } + + intercept[Exception] { + testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") { + case Row(id: Int, indexed: Double) => + assert(id != indexed.toInt) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e68fca050571f..dc5b998ad68b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false) + + def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc = + CheckAnswerRowsByFunc(checkFunction, false) } /** @@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false) + + def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc = + CheckAnswerRowsByFunc(checkFunction, true) } case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean) @@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } + case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${checkFunction.toString()}" + private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + } + /** Stops the stream. It must currently be running. 
*/ case object StopStream extends StreamAction with StreamMustBeRunning @@ -352,6 +364,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } + def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + verify(currentStream != null, "stream not running") + // Get the map of source index to the current source objects + val indexToSource = currentStream + .logicalPlan + .collect { case StreamingExecutionRelation(s, _) => s } + .zipWithIndex + .map(_.swap) + .toMap + + // Block until all data added has been processed for all the source + awaiting.foreach { case (sourceIndex, offset) => + failAfter(streamingTimeout) { + currentStream.awaitOffset(indexToSource(sourceIndex), offset) + } + } + + try if (lastOnly) sink.latestBatchData else sink.allData catch { + case e: Exception => + failTest("Exception while getting data from sink", e) + } + } + var manualClockExpectedTime = -1L val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -552,30 +587,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be e.runAction() case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => - verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap - - // Block until all data added has been processed for all the source - awaiting.foreach { case (sourceIndex, offset) => - failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) - } - } - - val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch { - case e: Exception => - failTest("Exception while getting data from sink", e) - } - + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { error => failTest(error) } + + case CheckAnswerRowsByFunc(checkFunction, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + sparkAnswer.foreach { row => + try { + checkFunction(row) + } catch { + case e: Throwable => failTest(e.toString) + } + } } pos += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 959edf9a49371..4286e8a6ca2c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf /** * A special `SparkSession` prepared for testing. */ -private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => +private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { this(new SparkContext("local[2]", "test-sql-context", sparkConf.set("spark.sql.testkey", "true"))) From 6b80ce4fb20da57c9513b94ab02b53a5fd7571d0 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 12 Dec 2017 22:41:38 -0800 Subject: [PATCH 407/712] [SPARK-19809][SQL][TEST][FOLLOWUP] Move the test case to HiveOrcQuerySuite ## What changes were proposed in this pull request? As a follow-up of #19948 , this PR moves the test case and adds comments. ## How was this patch tested? Pass the Jenkins. 
Author: Dongjoon Hyun Closes #19960 from dongjoon-hyun/SPARK-19809-2. --- .../sql/hive/execution/SQLQuerySuite.scala | 36 -------------- .../sql/hive/orc/HiveOrcQuerySuite.scala | 48 ++++++++++++++++++- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 93c91d3fcb727..c11e37a516646 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2153,40 +2153,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } - - test("SPARK-22267 Spark SQL incorrectly reads ORC files when column order is different") { - Seq("native", "hive").foreach { orcImpl => - withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { - withTempPath { f => - val path = f.getCanonicalPath - Seq(1 -> 2).toDF("c1", "c2").write.orc(path) - checkAnswer(spark.read.orc(path), Row(1, 2)) - - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // default since 2.3.0 - withTable("t") { - sql(s"CREATE EXTERNAL TABLE t(c2 INT, c1 INT) STORED AS ORC LOCATION '$path'") - checkAnswer(spark.table("t"), Row(2, 1)) - } - } - } - } - } - } - - test("SPARK-19809 NullPointerException on zero-size ORC file") { - Seq("native", "hive").foreach { orcImpl => - withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { - withTempPath { dir => - withTable("spark_19809") { - sql(s"CREATE TABLE spark_19809(a int) STORED AS ORC LOCATION '$dir'") - Files.touch(new File(s"${dir.getCanonicalPath}", "zero.orc")) - - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // default since 2.3.0 - checkAnswer(sql("SELECT * FROM spark_19809"), Seq.empty) - } - } - } - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index 7244c369bd3f4..92b2f069cacd6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.hive.orc -import org.apache.spark.sql.AnalysisException +import java.io.File + +import com.google.common.io.Files + +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.orc.OrcQueryTest @@ -162,4 +166,46 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { } } } + + // Since Hive 1.2.1 library code path still has this problem, users may hit this + // when spark.sql.hive.convertMetastoreOrc=false. However, after SPARK-22279, + // Apache Spark with the default configuration doesn't hit this bug. 
+ test("SPARK-22267 Spark SQL incorrectly reads ORC files when column order is different") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { + withTempPath { f => + val path = f.getCanonicalPath + Seq(1 -> 2).toDF("c1", "c2").write.orc(path) + checkAnswer(spark.read.orc(path), Row(1, 2)) + + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // default since 2.3.0 + withTable("t") { + sql(s"CREATE EXTERNAL TABLE t(c2 INT, c1 INT) STORED AS ORC LOCATION '$path'") + checkAnswer(spark.table("t"), Row(2, 1)) + } + } + } + } + } + } + + // Since Hive 1.2.1 library code path still has this problem, users may hit this + // when spark.sql.hive.convertMetastoreOrc=false. However, after SPARK-22279, + // Apache Spark with the default configuration doesn't hit this bug. + test("SPARK-19809 NullPointerException on zero-size ORC file") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> orcImpl) { + withTempPath { dir => + withTable("spark_19809") { + sql(s"CREATE TABLE spark_19809(a int) STORED AS ORC LOCATION '$dir'") + Files.touch(new File(s"${dir.getCanonicalPath}", "zero.orc")) + + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { // default since 2.3.0 + checkAnswer(spark.table("spark_19809"), Seq.empty) + } + } + } + } + } + } } From 13e489b6754f4d3569dad99bf5be2d5b0914dd68 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 12 Dec 2017 22:48:31 -0800 Subject: [PATCH 408/712] [SPARK-22759][SQL] Filters can be combined iff both are deterministic ## What changes were proposed in this pull request? The query execution/optimization does not guarantee the expressions are evaluated in order. We only can combine them if and only if both are deterministic. We need to update the optimizer rule: CombineFilters. ## How was this patch tested? Updated the existing tests. Author: gatorsmile Closes #19947 from gatorsmile/combineFilters. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 4 +++- .../sql/catalyst/optimizer/FilterPushdownSuite.scala | 12 ++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 484cd8c2475f5..577693506ed34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -695,7 +695,9 @@ object CombineUnions extends Rule[LogicalPlan] { */ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(fc, nf @ Filter(nc, grandChild)) => + // The query execution/optimization does not guarantee the expressions are evaluated in order. + // We only can combine them if and only if both are deterministic. 
+ case Filter(fc, nf @ Filter(nc, grandChild)) if fc.deterministic && nc.deterministic => (ExpressionSet(splitConjunctivePredicates(fc)) -- ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match { case Some(ac) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index de0e7c7ee49ac..641824e6698f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -94,19 +94,15 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("combine redundant deterministic filters") { + test("do not combine non-deterministic filters even if they are identical") { val originalQuery = testRelation .where(Rand(0) > 0.1 && 'a === 1) - .where(Rand(0) > 0.1 && 'a === 1) + .where(Rand(0) > 0.1 && 'a === 1).analyze - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .where(Rand(0) > 0.1 && 'a === 1 && Rand(0) > 0.1) - .analyze + val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + comparePlans(optimized, originalQuery) } test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") { From bdb5e55c2a67d16a36ad6baa22296d714d3525af Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 13 Dec 2017 14:49:15 +0800 Subject: [PATCH 409/712] [SPARK-21322][SQL][FOLLOWUP] support histogram in filter cardinality estimation ## What changes were proposed in this pull request? some code cleanup/refactor and naming improvement. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19952 from cloud-fan/minor. --- .../statsEstimation/EstimationUtils.scala | 109 +++++++------ .../statsEstimation/FilterEstimation.scala | 152 +++++++++--------- 2 files changed, 134 insertions(+), 127 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 2f416f2af91f1..6f868cbd072c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -115,14 +115,10 @@ object EstimationUtils { } /** - * Returns the number of the first bin into which a column value falls for a specified + * Returns the index of the first bin into which the given value falls for a specified * numeric equi-height histogram. - * - * @param value a literal value of a column - * @param bins an array of bins for a given numeric equi-height histogram - * @return the id of the first bin into which a column value falls. */ - def findFirstBinForValue(value: Double, bins: Array[HistogramBin]): Int = { + private def findFirstBinForValue(value: Double, bins: Array[HistogramBin]): Int = { var i = 0 while ((i < bins.length) && (value > bins(i).hi)) { i += 1 @@ -131,14 +127,10 @@ object EstimationUtils { } /** - * Returns the number of the last bin into which a column value falls for a specified + * Returns the index of the last bin into which the given value falls for a specified * numeric equi-height histogram. 
- * - * @param value a literal value of a column - * @param bins an array of bins for a given numeric equi-height histogram - * @return the id of the last bin into which a column value falls. */ - def findLastBinForValue(value: Double, bins: Array[HistogramBin]): Int = { + private def findLastBinForValue(value: Double, bins: Array[HistogramBin]): Int = { var i = bins.length - 1 while ((i >= 0) && (value < bins(i).lo)) { i -= 1 @@ -147,65 +139,76 @@ object EstimationUtils { } /** - * Returns a percentage of a bin holding values for column value in the range of - * [lowerValue, higherValue] - * - * @param higherValue a given upper bound value of a specified column value range - * @param lowerValue a given lower bound value of a specified column value range - * @param bin a single histogram bin - * @return the percentage of a single bin holding values in [lowerValue, higherValue]. + * Returns the possibility of the given histogram bin holding values within the given range + * [lowerBound, upperBound]. */ - private def getOccupation( - higherValue: Double, - lowerValue: Double, + private def binHoldingRangePossibility( + upperBound: Double, + lowerBound: Double, bin: HistogramBin): Double = { - assert(bin.lo <= lowerValue && lowerValue <= higherValue && higherValue <= bin.hi) + assert(bin.lo <= lowerBound && lowerBound <= upperBound && upperBound <= bin.hi) if (bin.hi == bin.lo) { // the entire bin is covered in the range 1.0 - } else if (higherValue == lowerValue) { + } else if (upperBound == lowerBound) { // set percentage to 1/NDV 1.0 / bin.ndv.toDouble } else { // Use proration since the range falls inside this bin. - math.min((higherValue - lowerValue) / (bin.hi - bin.lo), 1.0) + math.min((upperBound - lowerBound) / (bin.hi - bin.lo), 1.0) } } /** - * Returns the number of bins for column values in [lowerValue, higherValue]. - * The column value distribution is saved in an equi-height histogram. The return values is a - * double value is because we may return a portion of a bin. For example, a predicate - * "column = 8" may return the number of bins 0.2 if the holding bin has 5 distinct values. + * Returns the number of histogram bins holding values within the given range + * [lowerBound, upperBound]. + * + * Note that the returned value is double type, because the range boundaries usually occupy a + * portion of a bin. An extreme case is [value, value] which is generated by equal predicate + * `col = value`, we can get higher accuracy by allowing returning portion of histogram bins. * - * @param higherId id of the high end bin holding the high end value of a column range - * @param lowerId id of the low end bin holding the low end value of a column range - * @param higherEnd a given upper bound value of a specified column value range - * @param lowerEnd a given lower bound value of a specified column value range - * @param histogram a numeric equi-height histogram - * @return the number of bins for column values in [lowerEnd, higherEnd]. 
+ * @param upperBound the highest value of the given range + * @param upperBoundInclusive whether the upperBound is included in the range + * @param lowerBound the lowest value of the given range + * @param lowerBoundInclusive whether the lowerBound is included in the range + * @param bins an array of bins for a given numeric equi-height histogram */ - def getOccupationBins( - higherId: Int, - lowerId: Int, - higherEnd: Double, - lowerEnd: Double, - histogram: Histogram): Double = { - assert(lowerId <= higherId) - - if (lowerId == higherId) { - val curBin = histogram.bins(lowerId) - getOccupation(higherEnd, lowerEnd, curBin) + def numBinsHoldingRange( + upperBound: Double, + upperBoundInclusive: Boolean, + lowerBound: Double, + lowerBoundInclusive: Boolean, + bins: Array[HistogramBin]): Double = { + assert(bins.head.lo <= lowerBound && lowerBound <= upperBound && upperBound <= bins.last.hi, + "Given range does not fit in the given histogram.") + assert(upperBound != lowerBound || upperBoundInclusive || lowerBoundInclusive, + s"'$lowerBound < value < $upperBound' is an invalid range.") + + val upperBinIndex = if (upperBoundInclusive) { + findLastBinForValue(upperBound, bins) + } else { + findFirstBinForValue(upperBound, bins) + } + val lowerBinIndex = if (lowerBoundInclusive) { + findFirstBinForValue(lowerBound, bins) + } else { + findLastBinForValue(lowerBound, bins) + } + assert(lowerBinIndex <= upperBinIndex, "Invalid histogram data.") + + + if (lowerBinIndex == upperBinIndex) { + binHoldingRangePossibility(upperBound, lowerBound, bins(lowerBinIndex)) } else { - // compute how much lowerEnd/higherEnd occupies its bin - val lowerCurBin = histogram.bins(lowerId) - val lowerPart = getOccupation(lowerCurBin.hi, lowerEnd, lowerCurBin) + // Computes the occupied portion of bins of the upperBound and lowerBound. + val lowerBin = bins(lowerBinIndex) + val lowerPart = binHoldingRangePossibility(lowerBin.hi, lowerBound, lowerBin) - val higherCurBin = histogram.bins(higherId) - val higherPart = getOccupation(higherEnd, higherCurBin.lo, higherCurBin) + val higherBin = bins(upperBinIndex) + val higherPart = binHoldingRangePossibility(upperBound, higherBin.lo, higherBin) - // the total length is lowerPart + higherPart + bins between them - lowerPart + higherPart + higherId - lowerId - 1 + // The total number of bins is lowerPart + higherPart + bins between them + lowerPart + higherPart + upperBinIndex - lowerBinIndex - 1 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index f52a15edbfe4a..850dd1ba724a0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -336,43 +336,12 @@ case class FilterEstimation(plan: Filter) extends Logging { // returns 1/ndv if there is no histogram Some(1.0 / BigDecimal(ndv)) } else { - // We compute filter selectivity using Histogram information. - val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble - val histogram = colStat.histogram.get - val hgmBins = histogram.bins - - // find bins where column's current min and max locate. Note that a column's [min, max] - // range may change due to another condition applied earlier. 
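A rough worked example of the renamed helper (the bin boundaries and ndv counts are invented, the direct call to EstimationUtils is only for illustration, and the catalyst HistogramBin class is assumed to be importable as below):

import org.apache.spark.sql.catalyst.plans.logical.HistogramBin
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils

// Four equi-height bins covering [1.0, 100.0].
val bins = Array(
  HistogramBin(1.0, 20.0, 5L),
  HistogramBin(20.0, 45.0, 8L),
  HistogramBin(45.0, 70.0, 10L),
  HistogramBin(70.0, 100.0, 7L))

// `col = 10`: both bounds fall into the first bin, so the helper returns 1 / ndv = 0.2 bins.
val binsForDatum = EstimationUtils.numBinsHoldingRange(
  upperBound = 10.0, upperBoundInclusive = true,
  lowerBound = 10.0, lowerBoundInclusive = true, bins = bins)

// The column's current valid range [1, 100] spans all four bins, so this returns 4.0.
val binsForRange = EstimationUtils.numBinsHoldingRange(
  upperBound = 100.0, upperBoundInclusive = true,
  lowerBound = 1.0, lowerBoundInclusive = true, bins = bins)

// The selectivity of `col = 10` is then estimated as 0.2 / 4.0 = 0.05.
val selectivity = binsForDatum / binsForRange
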
- val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble - val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble - val minBinId = EstimationUtils.findFirstBinForValue(min, hgmBins) - val maxBinId = EstimationUtils.findLastBinForValue(max, hgmBins) - - // compute how many bins the column's current valid range [min, max] occupies. - // Note that a column's [min, max] range may vary after we apply some filter conditions. - val validRangeBins = EstimationUtils.getOccupationBins(maxBinId, minBinId, max, - min, histogram) - - val lowerBinId = EstimationUtils.findFirstBinForValue(datum, hgmBins) - val higherBinId = EstimationUtils.findLastBinForValue(datum, hgmBins) - assert(lowerBinId <= higherBinId) - val lowerBinNdv = hgmBins(lowerBinId).ndv - val higherBinNdv = hgmBins(higherBinId).ndv - // assume uniform distribution in each bin - val occupiedBins = if (lowerBinId == higherBinId) { - 1.0 / lowerBinNdv - } else { - (1.0 / lowerBinNdv) + // lowest bin - (higherBinId - lowerBinId - 1) + // middle bins - (1.0 / higherBinNdv) // highest bin - } - Some(occupiedBins / validRangeBins) + Some(computeEqualityPossibilityByHistogram(literal, colStat)) } } else { // not in interval Some(0.0) } - } /** @@ -542,11 +511,7 @@ case class FilterEstimation(plan: Filter) extends Logging { } } } else { - val numericHistogram = colStat.histogram.get - val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble - val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble - val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble - percent = computePercentByEquiHeightHgm(op, numericHistogram, max, min, datum) + percent = computeComparisonPossibilityByHistogram(op, literal, colStat) } if (update) { @@ -574,51 +539,90 @@ case class FilterEstimation(plan: Filter) extends Logging { } /** - * Returns the selectivity percentage for binary condition in the column's - * current valid range [min, max] - * - * @param op a binary comparison operator - * @param histogram a numeric equi-height histogram - * @param max the upper bound of the current valid range for a given column - * @param min the lower bound of the current valid range for a given column - * @param datumNumber the numeric value of a literal - * @return the selectivity percentage for a condition in the current range. + * Computes the possibility of an equality predicate using histogram. */ + private def computeEqualityPossibilityByHistogram( + literal: Literal, colStat: ColumnStat): Double = { + val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val histogram = colStat.histogram.get - def computePercentByEquiHeightHgm( - op: BinaryComparison, - histogram: Histogram, - max: Double, - min: Double, - datumNumber: Double): Double = { // find bins where column's current min and max locate. Note that a column's [min, max] // range may change due to another condition applied earlier. - val minBinId = EstimationUtils.findFirstBinForValue(min, histogram.bins) - val maxBinId = EstimationUtils.findLastBinForValue(max, histogram.bins) + val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble + val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble // compute how many bins the column's current valid range [min, max] occupies. - // Note that a column's [min, max] range may vary after we apply some filter conditions. 
- val minToMaxLength = EstimationUtils.getOccupationBins(maxBinId, minBinId, max, min, histogram) - - val datumInBinId = op match { - case LessThan(_, _) | GreaterThanOrEqual(_, _) => - EstimationUtils.findFirstBinForValue(datumNumber, histogram.bins) - case LessThanOrEqual(_, _) | GreaterThan(_, _) => - EstimationUtils.findLastBinForValue(datumNumber, histogram.bins) - } + val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange( + upperBound = max, + upperBoundInclusive = true, + lowerBound = min, + lowerBoundInclusive = true, + histogram.bins) + + val numBinsHoldingDatum = EstimationUtils.numBinsHoldingRange( + upperBound = datum, + upperBoundInclusive = true, + lowerBound = datum, + lowerBoundInclusive = true, + histogram.bins) + + numBinsHoldingDatum / numBinsHoldingEntireRange + } - op match { - // LessThan and LessThanOrEqual share the same logic, - // but their datumInBinId may be different - case LessThan(_, _) | LessThanOrEqual(_, _) => - EstimationUtils.getOccupationBins(datumInBinId, minBinId, datumNumber, min, - histogram) / minToMaxLength - // GreaterThan and GreaterThanOrEqual share the same logic, - // but their datumInBinId may be different - case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => - EstimationUtils.getOccupationBins(maxBinId, datumInBinId, max, datumNumber, - histogram) / minToMaxLength + /** + * Computes the possibility of a comparison predicate using histogram. + */ + private def computeComparisonPossibilityByHistogram( + op: BinaryComparison, literal: Literal, colStat: ColumnStat): Double = { + val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val histogram = colStat.histogram.get + + // find bins where column's current min and max locate. Note that a column's [min, max] + // range may change due to another condition applied earlier. + val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble + val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble + + // compute how many bins the column's current valid range [min, max] occupies. + val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange( + max, upperBoundInclusive = true, min, lowerBoundInclusive = true, histogram.bins) + + val numBinsHoldingRange = op match { + // LessThan and LessThanOrEqual share the same logic, the only difference is whether to + // include the upperBound in the range. + case _: LessThan => + EstimationUtils.numBinsHoldingRange( + upperBound = datum, + upperBoundInclusive = false, + lowerBound = min, + lowerBoundInclusive = true, + histogram.bins) + case _: LessThanOrEqual => + EstimationUtils.numBinsHoldingRange( + upperBound = datum, + upperBoundInclusive = true, + lowerBound = min, + lowerBoundInclusive = true, + histogram.bins) + + // GreaterThan and GreaterThanOrEqual share the same logic, the only difference is whether to + // include the lowerBound in the range. 
+ case _: GreaterThan => + EstimationUtils.numBinsHoldingRange( + upperBound = max, + upperBoundInclusive = true, + lowerBound = datum, + lowerBoundInclusive = false, + histogram.bins) + case _: GreaterThanOrEqual => + EstimationUtils.numBinsHoldingRange( + upperBound = max, + upperBoundInclusive = true, + lowerBound = datum, + lowerBoundInclusive = true, + histogram.bins) } + + numBinsHoldingRange / numBinsHoldingEntireRange } /** From 874350905ff41ffdd8dc07548aa3c4f5782a3b35 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 13 Dec 2017 09:10:03 +0200 Subject: [PATCH 410/712] [SPARK-22700][ML] Bucketizer.transform incorrectly drops row containing NaN ## What changes were proposed in this pull request? only drops the rows containing NaN in the input columns ## How was this patch tested? existing tests and added tests Author: Ruifeng Zheng Author: Zheng RuiFeng Closes #19894 from zhengruifeng/bucketizer_nan. --- .../org/apache/spark/ml/feature/Bucketizer.scala | 14 ++++++++------ .../apache/spark/ml/feature/BucketizerSuite.scala | 9 +++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index e07f2a107badb..8299a3e95d822 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -155,10 +155,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) + val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } + val (filteredDataset, keepInvalid) = { if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN values in the dataset - (dataset.na.drop().toDF(), false) + (dataset.na.drop(inputColumns).toDF(), false) } else { (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) } @@ -176,11 +182,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String }.withName(s"bucketizer_$idx") } - val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { - ($(inputCols).toSeq, $(outputCols).toSeq) - } else { - (Seq($(inputCol)), Seq($(outputCol))) - } + val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) => bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 748dbd1b995d3..d9c97ae8067d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -123,6 +123,15 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } + test("Bucketizer should only drop NaN in input columns, with handleInvalid=skip") { + val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN))) + .toDF("a", "b") + val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity) + val bucketizer = new Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits) + bucketizer.setHandleInvalid("skip") + assert(bucketizer.transform(df).count() == 2) + } + test("Bucket continuous features, with NaN 
splits") { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) withClue("Invalid NaN split was not caught during Bucketizer initialization") { From 682eb4f2ea152ce1043fbe689ea95318926b91b0 Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Tue, 12 Dec 2017 23:30:06 -0800 Subject: [PATCH 411/712] [SPARK-22042][SQL] ReorderJoinPredicates can break when child's partitioning is not decided ## What changes were proposed in this pull request? See jira description for the bug : https://issues.apache.org/jira/browse/SPARK-22042 Fix done in this PR is: In `EnsureRequirements`, apply `ReorderJoinPredicates` over the input tree before doing its core logic. Since the tree is transformed bottom-up, we can assure that the children are resolved before doing `ReorderJoinPredicates`. Theoretically this will guarantee to cover all such cases while keeping the code simple. My small grudge is for cosmetic reasons. This PR will look weird given that we don't call rules from other rules (not to my knowledge). I could have moved all the logic for `ReorderJoinPredicates` into `EnsureRequirements` but that will make it a but crowded. I am happy to discuss if there are better options. ## How was this patch tested? Added a new test case Author: Tejas Patil Closes #19257 from tejasapatil/SPARK-22042_ReorderJoinPredicates. --- .../spark/sql/execution/QueryExecution.scala | 2 - .../exchange/EnsureRequirements.scala | 76 ++++++++++++++- .../joins/ReorderJoinPredicates.scala | 94 ------------------- .../spark/sql/sources/BucketedReadSuite.scala | 31 ++++++ 4 files changed, 106 insertions(+), 97 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f404621399cea..946475a1e9751 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} -import org.apache.spark.sql.execution.joins.ReorderJoinPredicates import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils @@ -104,7 +103,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { protected def preparations: Seq[Rule[SparkPlan]] = Seq( python.ExtractPythonUDFs, PlanSubqueries(sparkSession), - new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4e2ca37bc1a59..82f0b9f5cd060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.execution.exchange +import 
scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, + SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf /** @@ -248,6 +252,75 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { operator.withNewChildren(children) } + /** + * When the physical operators are created for JOIN, the ordering of join keys is based on order + * in which the join keys appear in the user query. That might not match with the output + * partitioning of the join node's children (thus leading to extra sort / shuffle being + * introduced). This rule will change the ordering of the join keys to match with the + * partitioning of the join nodes' children. + */ + def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { + def reorderJoinKeys( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + + def reorder(expectedOrderOfKeys: Seq[Expression], + currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val leftKeysBuffer = ArrayBuffer[Expression]() + val rightKeysBuffer = ArrayBuffer[Expression]() + + expectedOrderOfKeys.foreach(expression => { + val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + leftKeysBuffer.append(leftKeys(index)) + rightKeysBuffer.append(rightKeys(index)) + }) + (leftKeysBuffer, rightKeysBuffer) + } + + if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { + leftPartitioning match { + case HashPartitioning(leftExpressions, _) + if leftExpressions.length == leftKeys.length && + leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => + reorder(leftExpressions, leftKeys) + + case _ => rightPartitioning match { + case HashPartitioning(rightExpressions, _) + if rightExpressions.length == rightKeys.length && + rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => + reorder(rightExpressions, rightKeys) + + case _ => (leftKeys, rightKeys) + } + } + } else { + (leftKeys, rightKeys) + } + } + + plan.transformUp { + case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, + right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + } + } + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator @ ShuffleExchangeExec(partitioning, child, _) => child.children match { @@ -255,6 +328,7 @@ case class 
EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (childPartitioning.guarantees(partitioning)) child else operator case _ => operator } - case operator: SparkPlan => ensureDistributionAndOrdering(operator) + case operator: SparkPlan => + ensureDistributionAndOrdering(reorderJoinPredicates(operator)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala deleted file mode 100644 index 534d8c5689c27..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -/** - * When the physical operators are created for JOIN, the ordering of join keys is based on order - * in which the join keys appear in the user query. That might not match with the output - * partitioning of the join node's children (thus leading to extra sort / shuffle being - * introduced). This rule will change the ordering of the join keys to match with the - * partitioning of the join nodes' children. 
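A plan-level sketch of the situation this rule targets (the attribute names, types and bucket count are made up for illustration):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

// Output partitioning of a child that was bucketed by (j, k).
val j = AttributeReference("j", IntegerType)()
val k = AttributeReference("k", IntegerType)()
val childPartitioning = HashPartitioning(Seq(j, k), 8)

// A join written as `... ON a.k = b.k AND a.j = b.j` yields leftKeys = Seq(k, j).
// Compared as written against HashPartitioning(Seq(j, k), 8), the ordering does not line
// up, so EnsureRequirements would insert an extra ShuffleExchange. reorderJoinKeys rewrites
// the keys on both sides to Seq(j, k) so the existing partitioning can be reused.
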
- */ -class ReorderJoinPredicates extends Rule[SparkPlan] { - private def reorderJoinKeys( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - leftPartitioning: Partitioning, - rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { - - def reorder( - expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - val leftKeysBuffer = ArrayBuffer[Expression]() - val rightKeysBuffer = ArrayBuffer[Expression]() - - expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) - leftKeysBuffer.append(leftKeys(index)) - rightKeysBuffer.append(rightKeys(index)) - }) - (leftKeysBuffer, rightKeysBuffer) - } - - if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - leftPartitioning match { - case HashPartitioning(leftExpressions, _) - if leftExpressions.length == leftKeys.length && - leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => - reorder(leftExpressions, leftKeys) - - case _ => rightPartitioning match { - case HashPartitioning(rightExpressions, _) - if rightExpressions.length == rightKeys.length && - rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => - reorder(rightExpressions, rightKeys) - - case _ => (leftKeys, rightKeys) - } - } - } else { - (leftKeys, rightKeys) - } - } - - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - - case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - - case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index ab18905e2ddb2..9025859e91066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -602,6 +602,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-22042 ReorderJoinPredicates can break when child's partitioning is not decided") { + withTable("bucketed_table", "table1", "table2") { + df.write.format("parquet").saveAsTable("table1") + df.write.format("parquet").saveAsTable("table2") + df.write.format("parquet").bucketBy(8, "j", "k").saveAsTable("bucketed_table") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + checkAnswer( + sql(""" + |SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k + |FROM ( + | SELECT a.i, a.j, a.k + | FROM bucketed_table a + | JOIN table1 b + | ON a.i = b.i + |) ab + |JOIN table2 c + |ON ab.i = c.i + 
|""".stripMargin), + sql(""" + |SELECT a.i, a.j, a.k, c.i, c.j, c.k + |FROM bucketed_table a + |JOIN table1 b + |ON a.i = b.i + |JOIN table2 c + |ON a.i = c.i + |""".stripMargin)) + } + } + } + test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") From 7453ab0243dc04db1b586b0e5f588f9cdc9f72dd Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 13 Dec 2017 16:27:29 +0800 Subject: [PATCH 412/712] [SPARK-22745][SQL] read partition stats from Hive ## What changes were proposed in this pull request? Currently Spark can read table stats (e.g. `totalSize, numRows`) from Hive, we can also support to read partition stats from Hive using the same logic. ## How was this patch tested? Added a new test case and modified an existing test case. Author: Zhenhua Wang Author: Zhenhua Wang Closes #19932 from wzhfy/read_hive_partition_stats. --- .../spark/sql/hive/HiveExternalCatalog.scala | 6 +- .../sql/hive/client/HiveClientImpl.scala | 66 +++++++++++-------- .../spark/sql/hive/StatisticsSuite.scala | 32 ++++++--- 3 files changed, 62 insertions(+), 42 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 44e680dbd2f93..632e3e0c4c3f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -668,7 +668,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val schema = restoreTableMetadata(rawTable).schema // convert table statistics to properties so that we can persist them through hive client - var statsProperties = + val statsProperties = if (stats.isDefined) { statsToProperties(stats.get, schema) } else { @@ -1098,14 +1098,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val schema = restoreTableMetadata(rawTable).schema // convert partition statistics to properties so that we can persist them through hive api - val withStatsProps = lowerCasedParts.map(p => { + val withStatsProps = lowerCasedParts.map { p => if (p.stats.isDefined) { val statsProperties = statsToProperties(p.stats.get, schema) p.copy(parameters = p.parameters ++ statsProperties) } else { p } - }) + } // Note: Before altering table partitions in Hive, you *must* set the current database // to the one that contains the table of interest. Otherwise you will end up with the diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 08eb5c74f06d7..7233944dc96dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -414,32 +414,6 @@ private[hive] class HiveClientImpl( } val comment = properties.get("comment") - // Here we are reading statistics from Hive. - // Note that this statistics could be overridden by Spark's statistics if that's available. - val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_)) - val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_)) - val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)) - // TODO: check if this estimate is valid for tables after partition pruning. 
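For illustration, the kind of partition parameters Hive writes after an ANALYZE command, and what the shared stats-reading logic turns them into (the values are invented, consistent with the test below):

// Parameters Hive attaches to a partition after
//   ANALYZE TABLE t PARTITION (ds='2017-01-01') COMPUTE STATISTICS
val hiveParams = Map(
  "totalSize"   -> "5812",  // StatsSetupConst.TOTAL_SIZE
  "rawDataSize" -> "0",     // StatsSetupConst.RAW_DATA_SIZE
  "numRows"     -> "500")   // StatsSetupConst.ROW_COUNT

// readHiveStats(hiveParams) then yields roughly
//   Some(CatalogStatistics(sizeInBytes = 5812, rowCount = Some(500)))
// because a positive totalSize wins over rawDataSize and zero counters are ignored.
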
- // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be - // relatively cheap if parameters for the table are populated into the metastore. - // Currently, only totalSize, rawDataSize, and rowCount are used to build the field `stats` - // TODO: stats should include all the other two fields (`numFiles` and `numPartitions`). - // (see StatsSetupConst in Hive) - val stats = - // When table is external, `totalSize` is always zero, which will influence join strategy. - // So when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, - // return None. - // In Hive, when statistics gathering is disabled, `rawDataSize` and `numRows` is always - // zero after INSERT command. So they are used here only if they are larger than zero. - if (totalSize.isDefined && totalSize.get > 0L) { - Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount.filter(_ > 0))) - } else if (rawDataSize.isDefined && rawDataSize.get > 0) { - Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount.filter(_ > 0))) - } else { - // TODO: still fill the rowCount even if sizeInBytes is empty. Might break anything? - None - } - CatalogTable( identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), tableType = h.getTableType match { @@ -478,7 +452,7 @@ private[hive] class HiveClientImpl( // For EXTERNAL_TABLE, the table properties has a particular field "EXTERNAL". This is added // in the function toHiveTable. properties = filteredProperties, - stats = stats, + stats = readHiveStats(properties), comment = comment, // In older versions of Spark(before 2.2.0), we expand the view original text and store // that into `viewExpandedText`, and that should be used in view resolution. So we get @@ -1011,6 +985,11 @@ private[hive] object HiveClientImpl { */ def fromHivePartition(hp: HivePartition): CatalogTablePartition = { val apiPartition = hp.getTPartition + val properties: Map[String, String] = if (hp.getParameters != null) { + hp.getParameters.asScala.toMap + } else { + Map.empty + } CatalogTablePartition( spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty), storage = CatalogStorageFormat( @@ -1021,8 +1000,37 @@ private[hive] object HiveClientImpl { compressed = apiPartition.getSd.isCompressed, properties = Option(apiPartition.getSd.getSerdeInfo.getParameters) .map(_.asScala.toMap).orNull), - parameters = - if (hp.getParameters() != null) hp.getParameters().asScala.toMap else Map.empty) + parameters = properties, + stats = readHiveStats(properties)) + } + + /** + * Reads statistics from Hive. + * Note that this statistics could be overridden by Spark's statistics if that's available. + */ + private def readHiveStats(properties: Map[String, String]): Option[CatalogStatistics] = { + val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_)) + val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_)) + val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)) + // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. + // Currently, only totalSize, rawDataSize, and rowCount are used to build the field `stats` + // TODO: stats should include all the other two fields (`numFiles` and `numPartitions`). + // (see StatsSetupConst in Hive) + + // When table is external, `totalSize` is always zero, which will influence join strategy. 
+ // So when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, + // return None. + // In Hive, when statistics gathering is disabled, `rawDataSize` and `numRows` is always + // zero after INSERT command. So they are used here only if they are larger than zero. + if (totalSize.isDefined && totalSize.get > 0L) { + Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount.filter(_ > 0))) + } else if (rawDataSize.isDefined && rawDataSize.get > 0) { + Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount.filter(_ > 0))) + } else { + // TODO: still fill the rowCount even if sizeInBytes is empty. Might break anything? + None + } } // Below is the key of table properties for storing Hive-generated statistics diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 13f06a2503656..3af8af0814bb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -213,6 +213,27 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("SPARK-22745 - read Hive's statistics for partition") { + val tableName = "hive_stats_part_table" + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='2017-01-01') SELECT * FROM src") + var partition = spark.sessionState.catalog + .getPartition(TableIdentifier(tableName), Map("ds" -> "2017-01-01")) + + assert(partition.stats.get.sizeInBytes == 5812) + assert(partition.stats.get.rowCount.isEmpty) + + hiveClient + .runSqlHive(s"ANALYZE TABLE $tableName PARTITION (ds='2017-01-01') COMPUTE STATISTICS") + partition = spark.sessionState.catalog + .getPartition(TableIdentifier(tableName), Map("ds" -> "2017-01-01")) + + assert(partition.stats.get.sizeInBytes == 5812) + assert(partition.stats.get.rowCount == Some(500)) + } + } + test("SPARK-21079 - analyze table with location different than that of individual partitions") { val tableName = "analyzeTable_part" withTable(tableName) { @@ -353,15 +374,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto createPartition("2010-01-02", 11, "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src") - sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS NOSCAN") - - assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) - assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) - assert(queryStats("2010-01-02", "10") === None) - assert(queryStats("2010-01-02", "11") === None) - - sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS NOSCAN") - assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000) @@ -631,7 +643,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto """.stripMargin) sql(s"INSERT INTO TABLE $tabName SELECT * FROM src") if (analyzedBySpark) sql(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") - // This is to mimic the scenario in which Hive genrates statistics before we reading it + // This is to mimic the scenario in which Hive generates statistics before we read it if 
(analyzedByHive) hiveClient.runSqlHive(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") val describeResult1 = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") From 58f7c825ae9dd2e3f548a9b7b4a9704f970dde5b Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 13 Dec 2017 07:52:21 -0600 Subject: [PATCH 413/712] [SPARK-20849][DOC][FOLLOWUP] Document R DecisionTree - Link Classification Example ## What changes were proposed in this pull request? in https://github.com/apache/spark/pull/18067, only the regression example is linked this pr link decision tree classification example to the doc ping felixcheung ## How was this patch tested? local build of docs ![default](https://user-images.githubusercontent.com/7322292/33922857-9b00fdd0-e008-11e7-92c2-85a3de52ea8f.png) Author: Zheng RuiFeng Closes #19963 from zhengruifeng/r_examples. --- docs/ml-classification-regression.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 083df2e405d62..bf979f3c73a52 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -223,6 +223,14 @@ More details on parameters can be found in the [Python API documentation](api/py
    +
+
+Refer to the [R API docs](api/R/spark.decisionTree.html) for more details.
+
+{% include_example classification r/ml/decisionTree.R %}
+
+
    + ## Random forest classifier From f6bcd3e53fe55bab38383cce8d0340f6c6191972 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 Dec 2017 01:16:44 +0800 Subject: [PATCH 414/712] [SPARK-22767][SQL] use ctx.addReferenceObj in InSet and ScalaUDF ## What changes were proposed in this pull request? We should not operate on `references` directly in `Expression.doGenCode`, instead we should use the high-level API `addReferenceObj`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19962 from cloud-fan/codegen. --- .../sql/catalyst/expressions/ScalaUDF.scala | 76 +++++-------------- .../sql/catalyst/expressions/predicates.scala | 20 ++--- .../catalyst/expressions/ScalaUDFSuite.scala | 3 +- .../catalyst/optimizer/OptimizeInSuite.scala | 2 +- 4 files changed, 28 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index a3cf7612016b5..388ef42883ad3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -76,11 +76,6 @@ case class ScalaUDF( }.foreach(println) */ - - // Accessors used in genCode - def userDefinedFunc(): AnyRef = function - def getChildren(): Seq[Expression] = children - private[this] val f = children.size match { case 0 => val func = function.asInstanceOf[() => Any] @@ -981,41 +976,19 @@ case class ScalaUDF( } // scalastyle:on line.size.limit - - private val converterClassName = classOf[Any => Any].getName - private val scalaUDFClassName = classOf[ScalaUDF].getName - private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - - // Generate codes used to convert the arguments to Scala type for user-defined functions - private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = { - val converterTerm = ctx.freshName("converter") - val expressionIdx = ctx.references.size - 1 - (converterTerm, - s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToScalaConverter(((Expression)((($scalaUDFClassName)" + - s"references[$expressionIdx]).getChildren().apply($index))).dataType());") - } - override def doGenCode( ctx: CodegenContext, ev: ExprCode): ExprCode = { - val scalaUDF = ctx.freshName("scalaUDF") - val scalaUDFRef = ctx.addReferenceObj("scalaUDFRef", this, scalaUDFClassName) - - // Object to convert the returned value of user-defined functions to Catalyst type - val catalystConverterTerm = ctx.freshName("catalystConverter") + val converterClassName = classOf[Any => Any].getName + // The type converters for inputs and the result. 
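The two codegen styles side by side, condensed from the InSet change further below (a sketch only; `ctx`, `set` and `childGen` are assumed to be in scope as in the real doGenCode):

// Before: the expression appends itself to the references array and bakes the slot index
// plus a cast into the generated Java source.
ctx.references += this
val oldAccess = s"((InSet) references[${ctx.references.size - 1}]).getSet()"

// After: addReferenceObj registers the object and returns the access expression, so
// doGenCode never tracks reference indices by hand.
val setTerm = ctx.addReferenceObj("set", set)
val newAccess = s"$setTerm.contains(${childGen.value})"
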
+ val converters: Array[Any => Any] = children.map { c => + CatalystTypeConverters.createToScalaConverter(c.dataType) + }.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType) + val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]") + val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage) val resultTerm = ctx.freshName("result") - // This must be called before children expressions' codegen - // because ctx.references is used in genCodeForConverter - val converterTerms = children.indices.map(genCodeForConverter(ctx, _)) - - // Initialize user-defined function - val funcClassName = s"scala.Function${children.size}" - - val funcTerm = ctx.freshName("udf") - // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1023,38 +996,31 @@ case class ScalaUDF( // We need to get the boxedType of dataType's javaType here. Because for the dataType // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. - val evalCode = evals.map(_.code).mkString - val (converters, funcArguments) = converterTerms.zipWithIndex.map { - case ((convName, convInit), i) => - val eval = evals(i) - val argTerm = ctx.freshName("arg") - val convert = - s""" - |$convInit - |Object $argTerm = ${eval.isNull} ? null : $convName.apply(${eval.value}); - """.stripMargin - (convert, argTerm) + val evalCode = evals.map(_.code).mkString("\n") + val (funcArgs, initArgs) = evals.zipWithIndex.map { case (eval, i) => + val argTerm = ctx.freshName("arg") + val convert = s"$convertersTerm[$i].apply(${eval.value})" + val initArg = s"Object $argTerm = ${eval.isNull} ? null : $convert;" + (argTerm, initArg) }.unzip - val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" + val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") + val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" + val resultConverter = s"$convertersTerm[${children.length}]" val callFunc = s""" |${ctx.boxedType(dataType)} $resultTerm = null; - |$scalaUDFClassName $scalaUDF = $scalaUDFRef; |try { - | $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc(); - | $converterClassName $catalystConverterTerm = ($converterClassName) - | $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType()); - | $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult); |} catch (Exception e) { - | throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + | throw new org.apache.spark.SparkException($errorMsgTerm, e); |} """.stripMargin ev.copy(code = s""" |$evalCode - |${converters.mkString("\n")} + |${initArgs.mkString("\n")} |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; @@ -1065,7 +1031,7 @@ case class ScalaUDF( """.stripMargin) } - private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType) lazy val udfErrorMessage = { val funcCls = function.getClass.getSimpleName @@ -1081,6 +1047,6 @@ case class ScalaUDF( throw new SparkException(udfErrorMessage, e) } - converter(result) + resultConverter(result) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 8eb41addaf689..ac9f56f78eb2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -328,7 +328,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } - @transient private[this] lazy val set = child.dataType match { + @transient lazy val set: Set[Any] = child.dataType match { case _: AtomicType => hset case _: NullType => hset case _ => @@ -336,20 +336,11 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset } - def getSet(): Set[Any] = set - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val setName = classOf[Set[Any]].getName - val InSetName = classOf[InSet].getName + val setTerm = ctx.addReferenceObj("set", set) val childGen = child.genCode(ctx) - ctx.references += this - val setTerm = ctx.freshName("set") - val setNull = if (hasNull) { - s""" - |if (!${ev.value}) { - | ${ev.isNull} = true; - |} - """.stripMargin + val setIsNull = if (hasNull) { + s"${ev.isNull} = !${ev.value};" } else { "" } @@ -359,9 +350,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${ctx.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { - | $setName $setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet(); | ${ev.value} = $setTerm.contains(${childGen.value}); - | $setNull + | $setIsNull |} """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index b0687fe900d3d..70dea4b39d55d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -51,7 +51,6 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) - // ScalaUDF can be very verbose and trigger reduceCodeSize - assert(ctx.mutableStates.forall(_._2.startsWith("globalIsNull"))) + assert(ctx.mutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index d7acd139225cd..478118ed709f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest { val optimizedPlan = OptimizeIn(plan) optimizedPlan match { case Filter(cond, _) - if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 => + if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].set.size == 3 => // pass case _ => fail("Unexpected result for OptimizedIn") } From 8eb5609d8d961e54aa1ed0632f15f5e570fa627a Mon Sep 17 00:00:00 2001 From: zhoukang Date: Wed, 13 Dec 2017 11:47:33 -0800 Subject: [PATCH 415/712] 
=?UTF-8?q?[SPARK-22754][DEPLOY]=20Check=20whether?= =?UTF-8?q?=20spark.executor.heartbeatInterval=20bigger=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … than spark.network.timeout or not ## What changes were proposed in this pull request? If spark.executor.heartbeatInterval bigger than spark.network.timeout,it will almost always cause exception below. `Job aborted due to stage failure: Task 4763 in stage 3.0 failed 4 times, most recent failure: Lost task 4763.3 in stage 3.0 (TID 22383, executor id: 4761, host: xxx): ExecutorLostFailure (executor 4761 exited caused by one of the running tasks) Reason: Executor heartbeat timed out after 154022 ms` Since many users do not get that point.He will set spark.executor.heartbeatInterval incorrectly. This patch check this case when submit applications. ## How was this patch tested? Test in cluster Author: zhoukang Closes #19942 from caneGuy/zhoukang/check-heartbeat. --- core/src/main/scala/org/apache/spark/SparkConf.scala | 8 ++++++++ .../test/scala/org/apache/spark/SparkConfSuite.scala | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4b1286d91e8f3..d77303e6fdf8b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -564,6 +564,14 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") + + val executorTimeoutThreshold = getTimeAsSeconds("spark.network.timeout", "120s") + val executorHeartbeatInterval = getTimeAsSeconds("spark.executor.heartbeatInterval", "10s") + // If spark.executor.heartbeatInterval bigger than spark.network.timeout, + // it will almost always cause ExecutorLostFailure. See SPARK-22754. + require(executorTimeoutThreshold > executorHeartbeatInterval, "The value of " + + s"spark.network.timeout=${executorTimeoutThreshold}s must be no less than the value of " + + s"spark.executor.heartbeatInterval=${executorHeartbeatInterval}s.") } /** diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index c771eb4ee3ef5..bff808eb540ac 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -329,6 +329,16 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst conf.validateSettings() } + test("spark.network.timeout should bigger than spark.executor.heartbeatInterval") { + val conf = new SparkConf() + conf.validateSettings() + + conf.set("spark.network.timeout", "5s") + intercept[IllegalArgumentException] { + conf.validateSettings() + } + } + } class Class1 {} From c5a4701acc6972ed7ccb11c506fe718d5503f140 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 13 Dec 2017 11:50:04 -0800 Subject: [PATCH 416/712] Revert "[SPARK-21417][SQL] Infer join conditions using propagated constraints" This reverts commit 6ac57fd0d1c82b834eb4bf0dd57596b92a99d6de. 
--- .../expressions/EquivalentExpressionMap.scala | 66 ----- .../catalyst/expressions/ExpressionSet.scala | 2 - .../sql/catalyst/optimizer/Optimizer.scala | 1 - .../spark/sql/catalyst/optimizer/joins.scala | 60 ----- .../EquivalentExpressionMapSuite.scala | 56 ----- .../optimizer/EliminateCrossJoinSuite.scala | 238 ------------------ 6 files changed, 423 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala deleted file mode 100644 index cf1614afb1a76..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.expressions.EquivalentExpressionMap.SemanticallyEqualExpr - -/** - * A class that allows you to map an expression into a set of equivalent expressions. The keys are - * handled based on their semantic meaning and ignoring cosmetic differences. The values are - * represented as [[ExpressionSet]]s. - * - * The underlying representation of keys depends on the [[Expression.semanticHash]] and - * [[Expression.semanticEquals]] methods. 
- * - * {{{ - * val map = new EquivalentExpressionMap() - * - * map.put(1 + 2, a) - * map.put(rand(), b) - * - * map.get(2 + 1) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent - * map.get(1 + 2) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent - * map.get(rand()) => Set() // non-deterministic expressions are not equivalent - * }}} - */ -class EquivalentExpressionMap { - - private val equivalenceMap = mutable.HashMap.empty[SemanticallyEqualExpr, ExpressionSet] - - def put(expression: Expression, equivalentExpression: Expression): Unit = { - val equivalentExpressions = equivalenceMap.getOrElseUpdate(expression, ExpressionSet.empty) - equivalenceMap(expression) = equivalentExpressions + equivalentExpression - } - - def get(expression: Expression): Set[Expression] = - equivalenceMap.getOrElse(expression, ExpressionSet.empty) -} - -object EquivalentExpressionMap { - - private implicit class SemanticallyEqualExpr(val expr: Expression) { - override def equals(obj: Any): Boolean = obj match { - case other: SemanticallyEqualExpr => expr.semanticEquals(other.expr) - case _ => false - } - - override def hashCode: Int = expr.semanticHash() - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index e9890837af07d..7e8e7b8cd5f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -27,8 +27,6 @@ object ExpressionSet { expressions.foreach(set.add) set } - - val empty: ExpressionSet = ExpressionSet(Nil) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 577693506ed34..5acadf8cf330e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -87,7 +87,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PushProjectionThroughUnion, ReorderJoin, EliminateOuterJoin, - EliminateCrossJoin, InferFiltersFromConstraints, BooleanSimplification, PushPredicateThroughJoin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 29a3a7f109b80..edbeaf273fd6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec -import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins @@ -153,62 +152,3 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } - -/** - * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints. - * - * The optimization is applicable only to CROSS joins. 
For other join types, adding inferred join - * conditions would potentially shuffle children as child node's partitioning won't satisfy the JOIN - * node's requirements which otherwise could have. - * - * For instance, given a CROSS join with the constraint 'a = 1' from the left child and the - * constraint 'b = 1' from the right child, this rule infers a new join predicate 'a = b' and - * converts it to an Inner join. - */ -object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = { - if (SQLConf.get.constraintPropagationEnabled) { - eliminateCrossJoin(plan) - } else { - plan - } - } - - private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform { - case join @ Join(leftPlan, rightPlan, Cross, None) => - val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet)) - val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet)) - val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints) - val joinConditionOpt = inferredJoinPredicates.reduceOption(And) - if (joinConditionOpt.isDefined) Join(leftPlan, rightPlan, Inner, joinConditionOpt) else join - } - - private def inferJoinPredicates( - leftConstraints: Set[Expression], - rightConstraints: Set[Expression]): mutable.Set[EqualTo] = { - - val equivalentExpressionMap = new EquivalentExpressionMap() - - leftConstraints.foreach { - case EqualTo(attr: Attribute, expr: Expression) => - equivalentExpressionMap.put(expr, attr) - case EqualTo(expr: Expression, attr: Attribute) => - equivalentExpressionMap.put(expr, attr) - case _ => - } - - val joinConditions = mutable.Set.empty[EqualTo] - - rightConstraints.foreach { - case EqualTo(attr: Attribute, expr: Expression) => - joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _)) - case EqualTo(expr: Expression, attr: Attribute) => - joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _)) - case _ => - } - - joinConditions - } - -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala deleted file mode 100644 index bad7e17bb6cf2..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ - -class EquivalentExpressionMapSuite extends SparkFunSuite { - - private val onePlusTwo = Literal(1) + Literal(2) - private val twoPlusOne = Literal(2) + Literal(1) - private val rand = Rand(10) - - test("behaviour of the equivalent expression map") { - val equivalentExpressionMap = new EquivalentExpressionMap() - equivalentExpressionMap.put(onePlusTwo, 'a) - equivalentExpressionMap.put(Literal(1) + Literal(3), 'b) - equivalentExpressionMap.put(rand, 'c) - - // 1 + 2 should be equivalent to 2 + 1 - assertResult(ExpressionSet(Seq('a)))(equivalentExpressionMap.get(twoPlusOne)) - // non-deterministic expressions should not be equivalent - assertResult(ExpressionSet.empty)(equivalentExpressionMap.get(rand)) - - // if the same (key, value) is added several times, the map still returns only one entry - equivalentExpressionMap.put(onePlusTwo, 'a) - equivalentExpressionMap.put(twoPlusOne, 'a) - assertResult(ExpressionSet(Seq('a)))(equivalentExpressionMap.get(twoPlusOne)) - - // get several equivalent attributes - equivalentExpressionMap.put(onePlusTwo, 'e) - assertResult(ExpressionSet(Seq('a, 'e)))(equivalentExpressionMap.get(onePlusTwo)) - assertResult(2)(equivalentExpressionMap.get(onePlusTwo).size) - - // several non-deterministic expressions should not be equivalent - equivalentExpressionMap.put(rand, 'd) - assertResult(ExpressionSet.empty)(equivalentExpressionMap.get(rand)) - assertResult(0)(equivalentExpressionMap.get(rand).size) - } - -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala deleted file mode 100644 index e04dd28ee36a0..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, Not, Rand} -import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED -import org.apache.spark.sql.types.IntegerType - -class EliminateCrossJoinSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Eliminate cross joins", FixedPoint(10), - EliminateCrossJoin, - PushPredicateThroughJoin) :: Nil - } - - val testRelation1 = LocalRelation('a.int, 'b.int) - val testRelation2 = LocalRelation('c.int, 'd.int) - - test("successful elimination of cross joins (1)") { - checkJoinOptimization( - originalFilter = 'a === 1 && 'c === 1 && 'd === 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a === 1, - expectedRightRelationFilter = 'c === 1 && 'd === 1, - expectedJoinType = Inner, - expectedJoinCondition = Some('a === 'c && 'a === 'd)) - } - - test("successful elimination of cross joins (2)") { - checkJoinOptimization( - originalFilter = 'a === 1 && 'b === 2 && 'd === 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a === 1 && 'b === 2, - expectedRightRelationFilter = 'd === 1, - expectedJoinType = Inner, - expectedJoinCondition = Some('a === 'd)) - } - - test("successful elimination of cross joins (3)") { - // PushPredicateThroughJoin will push 'd === 'a into the join condition - // EliminateCrossJoin will NOT apply because the condition will be already present - // therefore, the join type will stay the same (i.e., CROSS) - checkJoinOptimization( - originalFilter = 'a === 1 && Literal(1) === 'd && 'd === 'a, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a === 1, - expectedRightRelationFilter = Literal(1) === 'd, - expectedJoinType = Cross, - expectedJoinCondition = Some('a === 'd)) - } - - test("successful elimination of cross joins (4)") { - // Literal(1) * Literal(2) and Literal(2) * Literal(1) are semantically equal - checkJoinOptimization( - originalFilter = 'a === Literal(1) * Literal(2) && Literal(2) * Literal(1) === 'c, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a === Literal(1) * Literal(2), - expectedRightRelationFilter = Literal(2) * Literal(1) === 'c, - expectedJoinType = Inner, - expectedJoinCondition = Some('a === 'c)) - } - - test("successful elimination of cross joins (5)") { - checkJoinOptimization( - originalFilter = 'a === 1 && Literal(1) === 'a && 'c === 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, - expectedRightRelationFilter = 'c === 1, - expectedJoinType = Inner, - expectedJoinCondition = Some('a === 'c)) - } - - test("successful elimination of cross joins (6)") { - checkJoinOptimization( - originalFilter = 'a === Cast("1", IntegerType) && 'c === Cast("1", IntegerType) && 'd === 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - 
expectedLeftRelationFilter = 'a === Cast("1", IntegerType), - expectedRightRelationFilter = 'c === Cast("1", IntegerType) && 'd === 1, - expectedJoinType = Inner, - expectedJoinCondition = Some('a === 'c)) - } - - test("successful elimination of cross joins (7)") { - // The join condition appears due to PushPredicateThroughJoin - checkJoinOptimization( - originalFilter = (('a >= 1 && 'c === 1) || 'd === 10) && 'b === 10 && 'c === 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'b === 10, - expectedRightRelationFilter = 'c === 1, - expectedJoinType = Cross, - expectedJoinCondition = Some(('a >= 1 && 'c === 1) || 'd === 10)) - } - - test("successful elimination of cross joins (8)") { - checkJoinOptimization( - originalFilter = 'a === 1 && 'c === 1 && Literal(1) === 'a && Literal(1) === 'c, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, - expectedRightRelationFilter = 'c === 1 && Literal(1) === 'c, - expectedJoinType = Inner, - expectedJoinCondition = Some('a === 'c)) - } - - test("inability to detect join conditions when constant propagation is disabled") { - withSQLConf(CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { - checkJoinOptimization( - originalFilter = 'a === 1 && 'c === 1 && 'd === 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a === 1, - expectedRightRelationFilter = 'c === 1 && 'd === 1, - expectedJoinType = Cross, - expectedJoinCondition = None) - } - } - - test("inability to detect join conditions (1)") { - checkJoinOptimization( - originalFilter = 'a >= 1 && 'c === 1 && 'd >= 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = 'a >= 1, - expectedRightRelationFilter = 'c === 1 && 'd >= 1, - expectedJoinType = Cross, - expectedJoinCondition = None) - } - - test("inability to detect join conditions (2)") { - checkJoinOptimization( - originalFilter = Literal(1) === 'b && ('c === 1 || 'd === 1), - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = Literal(1) === 'b, - expectedRightRelationFilter = 'c === 1 || 'd === 1, - expectedJoinType = Cross, - expectedJoinCondition = None) - } - - test("inability to detect join conditions (3)") { - checkJoinOptimization( - originalFilter = Literal(1) === 'b && 'c === 1, - originalJoinType = Cross, - originalJoinCondition = Some('c === 'b), - expectedFilter = None, - expectedLeftRelationFilter = Literal(1) === 'b, - expectedRightRelationFilter = 'c === 1, - expectedJoinType = Cross, - expectedJoinCondition = Some('c === 'b)) - } - - test("inability to detect join conditions (4)") { - checkJoinOptimization( - originalFilter = Not('a === 1) && 'd === 1, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = None, - expectedLeftRelationFilter = Not('a === 1), - expectedRightRelationFilter = 'd === 1, - expectedJoinType = Cross, - expectedJoinCondition = None) - } - - test("inability to detect join conditions (5)") { - checkJoinOptimization( - originalFilter = 'a === Rand(10) && 'b === 1 && 'd === Rand(10) && 'c === 3, - originalJoinType = Cross, - originalJoinCondition = None, - expectedFilter = Some('a === Rand(10) && 'd === Rand(10)), - expectedLeftRelationFilter = 'b === 1, - expectedRightRelationFilter = 'c === 3, - expectedJoinType = Cross, - 
expectedJoinCondition = None) - } - - private def checkJoinOptimization( - originalFilter: Expression, - originalJoinType: JoinType, - originalJoinCondition: Option[Expression], - expectedFilter: Option[Expression], - expectedLeftRelationFilter: Expression, - expectedRightRelationFilter: Expression, - expectedJoinType: JoinType, - expectedJoinCondition: Option[Expression]): Unit = { - - val originalQuery = testRelation1 - .join(testRelation2, originalJoinType, originalJoinCondition) - .where(originalFilter) - val optimizedQuery = Optimize.execute(originalQuery.analyze) - - val left = testRelation1.where(expectedLeftRelationFilter) - val right = testRelation2.where(expectedRightRelationFilter) - val join = left.join(right, expectedJoinType, expectedJoinCondition) - val expectedQuery = expectedFilter.fold(join)(join.where(_)).analyze - - comparePlans(optimizedQuery, expectedQuery) - } -} From 1abcbed678c2bc4f05640db2791fd2d84267d740 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 13 Dec 2017 11:54:22 -0800 Subject: [PATCH 417/712] [SPARK-22763][CORE] SHS: Ignore unknown events and parse through the file ## What changes were proposed in this pull request? While spark code changes, there are new events in event log: #19649 And we used to maintain a whitelist to avoid exceptions: #15663 Currently Spark history server will stop parsing on unknown events or unrecognized properties. We may still see part of the UI data. For better compatibility, we can ignore unknown events and parse through the log file. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #19953 from gengliangwang/ReplayListenerBus. --- .../spark/scheduler/ReplayListenerBus.scala | 37 +++++++++---------- .../spark/scheduler/ReplayListenerSuite.scala | 29 +++++++++++++++ 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 26a6a3effc9ac..c9cd662f5709d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -69,6 +69,8 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { eventsFilter: ReplayEventsFilter): Unit = { var currentLine: String = null var lineNumber: Int = 0 + val unrecognizedEvents = new scala.collection.mutable.HashSet[String] + val unrecognizedProperties = new scala.collection.mutable.HashSet[String] try { val lineEntries = lines @@ -84,16 +86,22 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine))) } catch { - case e: ClassNotFoundException if KNOWN_REMOVED_CLASSES.contains(e.getMessage) => - // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1. - // It's safe since no place uses them. - logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") - case e: UnrecognizedPropertyException if e.getMessage != null && e.getMessage.startsWith( - "Unrecognized field \"queryStatus\" " + - "(class org.apache.spark.sql.streaming.StreamingQueryListener$") => - // Ignore events generated by Structured Streaming in Spark 2.0.2 - // It's safe since no place uses them. - logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") + case e: ClassNotFoundException => + // Ignore unknown events, parse through the event log file. 
+ // To avoid spamming, warnings are only displayed once for each unknown event. + if (!unrecognizedEvents.contains(e.getMessage)) { + logWarning(s"Drop unrecognized event: ${e.getMessage}") + unrecognizedEvents.add(e.getMessage) + } + logDebug(s"Drop incompatible event log: $currentLine") + case e: UnrecognizedPropertyException => + // Ignore unrecognized properties, parse through the event log file. + // To avoid spamming, warnings are only displayed once for each unrecognized property. + if (!unrecognizedProperties.contains(e.getMessage)) { + logWarning(s"Drop unrecognized property: ${e.getMessage}") + unrecognizedProperties.add(e.getMessage) + } + logDebug(s"Drop incompatible event log: $currentLine") case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated // the last entry may not be the very last line in the event log, but we treat it @@ -125,13 +133,4 @@ private[spark] object ReplayListenerBus { // utility filter that selects all event logs during replay val SELECT_ALL_FILTER: ReplayEventsFilter = { (eventString: String) => true } - - /** - * Classes that were removed. Structured Streaming doesn't use them any more. However, parsing - * old json may fail and we can just ignore these failures. - */ - val KNOWN_REMOVED_CLASSES = Set( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress", - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated" - ) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index d17e3864854a8..73e7b3fe8c1de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -128,6 +128,35 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } } + test("Replay incompatible event log") { + val logFilePath = Utils.getFilePath(testDir, "incompatible.txt") + val fstream = fileSystem.create(logFilePath) + val writer = new PrintWriter(fstream) + val applicationStart = SparkListenerApplicationStart("Incompatible App", None, + 125L, "UserUsingIncompatibleVersion", None) + val applicationEnd = SparkListenerApplicationEnd(1000L) + // scalastyle:off println + writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) + writer.println("""{"Event":"UnrecognizedEventOnlyForTest","Timestamp":1477593059313}""") + writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) + // scalastyle:on println + writer.close() + + val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath) + val logData = fileSystem.open(logFilePath) + val eventMonster = new EventMonster(conf) + try { + val replayer = new ReplayListenerBus() + replayer.addListener(eventMonster) + replayer.replay(logData, logFilePath.toString) + } finally { + logData.close() + } + assert(eventMonster.loggedEvents.size === 2) + assert(eventMonster.loggedEvents(0) === JsonProtocol.sparkEventToJson(applicationStart)) + assert(eventMonster.loggedEvents(1) === JsonProtocol.sparkEventToJson(applicationEnd)) + } + // This assumes the correctness of EventLoggingListener test("End-to-end replay") { testApplicationReplay() From 0bdb4e516c425ea7bf941106ac6449b5a0a289e3 Mon Sep 17 00:00:00 2001 From: German Schiavon Date: Wed, 13 Dec 2017 13:37:25 -0800 Subject: [PATCH 418/712] [SPARK-22574][MESOS][SUBMIT] Check submission request parameters ## What changes 
were proposed in this pull request?

PR closed with all the comments -> https://github.com/apache/spark/pull/19793

It solves the problem where submitting a malformed CreateSubmissionRequest to the Spark Dispatcher left the Dispatcher in a bad state and made it inactive as a Mesos framework.

https://issues.apache.org/jira/browse/SPARK-22574

## How was this patch tested?

All Spark tests passed successfully. It was tested by sending a malformed request (without appArgs) before and after the change. The fix is simple: check whether a value is null before it is accessed.

This was before the change, leaving the dispatcher inactive:

```
Exception in thread "Thread-22" java.lang.NullPointerException
	at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.getDriverCommandValue(MesosClusterScheduler.scala:444)
	at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.buildDriverCommand(MesosClusterScheduler.scala:451)
	at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.org$apache$spark$scheduler$cluster$mesos$MesosClusterScheduler$$createTaskInfo(MesosClusterScheduler.scala:538)
	at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler$$anonfun$scheduleTasks$1.apply(MesosClusterScheduler.scala:570)
	at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler$$anonfun$scheduleTasks$1.apply(MesosClusterScheduler.scala:555)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.scheduleTasks(MesosClusterScheduler.scala:555)
	at org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler.resourceOffers(MesosClusterScheduler.scala:621)
```

And after:

```
"message" : "Malformed request: org.apache.spark.deploy.rest.SubmitRestProtocolException: Validation of message CreateSubmissionRequest
failed!\n\torg.apache.spark.deploy.rest.SubmitRestProtocolMessage.validate(SubmitRestProtocolMessage.scala:70)\n\torg.apache.spark.deploy.rest.SubmitRequestServlet.doPost(RestSubmissionServer.scala:272)\n\tjavax.servlet.http.HttpServlet.service(HttpServlet.java:707)\n\tjavax.servlet.http.HttpServlet.service(HttpServlet.java:790)\n\torg.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:845)\n\torg.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:583)\n\torg.spark_project.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180)\n\torg.spark_project.jetty.servlet.ServletHandler.doScope(ServletHandler.java:511)\n\torg.spark_project.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112)\n\torg.spark_project.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)\n\torg.spark_project.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134)\n\torg.spark_project.jetty.server.Server.handle(Server.java:524)\n\torg.spark_project.jetty.server.HttpChannel.handle(HttpChannel.java:319)\n\torg.spark_project.jetty.server.HttpConnection.onFillable(HttpConnection.java:253)\n\torg.spark_project.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:273)\n\torg.spark_project.jetty.io.FillInterest.fillable(FillInterest.java:95)\n\torg.spark_project.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93)\n\torg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303)\n\torg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\n\torg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\n\torg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\n\torg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\n\tjava.lang.Thread.run(Thread.java:745)" ``` Author: German Schiavon Closes #19966 from Gschiavon/fix-submission-request. 
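
For illustration, a minimal, self-contained Scala sketch of the fail-fast pattern this change applies: wrap each possibly-null request field in `Option` and reject the submission with a descriptive error instead of letting a `NullPointerException` surface later in the scheduler. The names here (`Request`, `MissingFieldException`, `RequestValidator`) are hypothetical stand-ins for this sketch only, not the actual Spark REST protocol classes.

```
import scala.util.Try

// Hypothetical stand-ins, not the real Spark classes.
class MissingFieldException(msg: String) extends IllegalArgumentException(msg)

case class Request(
    appResource: String,
    appArgs: Array[String],
    environmentVariables: Map[String, String])

object RequestValidator {
  // Fail fast with a descriptive message when a required field was omitted (i.e. is null).
  private def required[T](value: T, name: String): T =
    Option(value).getOrElse(throw new MissingFieldException(s"Field '$name' is missing."))

  def validate(req: Request): Request = {
    required(req.appResource, "appResource")
    required(req.appArgs, "appArgs")
    required(req.environmentVariables, "environmentVariables")
    req
  }
}

object RequestValidatorExample extends App {
  // A malformed submission: appArgs was omitted, so the deserialized field is null.
  val bad = Request("spark-examples.jar", null, Map("PATH" -> "/dev/null"))
  println(Try(RequestValidator.validate(bad))) // Failure(... Field 'appArgs' is missing.)
}
```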
--- .../deploy/rest/SubmitRestProtocolRequest.scala | 2 ++ .../spark/deploy/rest/SubmitRestProtocolSuite.scala | 2 ++ .../spark/deploy/rest/mesos/MesosRestServer.scala | 13 +++++++++---- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 0d50a768942ed..86ddf954ca128 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -46,6 +46,8 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { super.doValidate() assert(sparkProperties != null, "No Spark properties set!") assertFieldIsSet(appResource, "appResource") + assertFieldIsSet(appArgs, "appArgs") + assertFieldIsSet(environmentVariables, "environmentVariables") assertPropertyIsSet("spark.app.name") assertPropertyIsBoolean("spark.driver.supervise") assertPropertyIsNumeric("spark.driver.cores") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 725b8848bc052..75c50af23c66a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -86,6 +86,8 @@ class SubmitRestProtocolSuite extends SparkFunSuite { message.clientSparkVersion = "1.2.3" message.appResource = "honey-walnut-cherry.jar" message.mainClass = "org.apache.spark.examples.SparkPie" + message.appArgs = Array("two slices") + message.environmentVariables = Map("PATH" -> "/dev/null") val conf = new SparkConf(false) conf.set("spark.app.name", "SparkPie") message.sparkProperties = conf.getAll.toMap diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index ff60b88c6d533..68f6921153d89 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -77,10 +77,17 @@ private[mesos] class MesosSubmitRequestServlet( private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { // Required fields, including the main class because python is not yet supported val appResource = Option(request.appResource).getOrElse { - throw new SubmitRestMissingFieldException("Application jar is missing.") + throw new SubmitRestMissingFieldException("Application jar 'appResource' is missing.") } val mainClass = Option(request.mainClass).getOrElse { - throw new SubmitRestMissingFieldException("Main class is missing.") + throw new SubmitRestMissingFieldException("Main class 'mainClass' is missing.") + } + val appArgs = Option(request.appArgs).getOrElse { + throw new SubmitRestMissingFieldException("Application arguments 'appArgs' are missing.") + } + val environmentVariables = Option(request.environmentVariables).getOrElse { + throw new SubmitRestMissingFieldException("Environment variables 'environmentVariables' " + + "are missing.") } // Optional fields @@ -91,8 +98,6 @@ private[mesos] class MesosSubmitRequestServlet( val superviseDriver = sparkProperties.get("spark.driver.supervise") val driverMemory = 
sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") - val appArgs = request.appArgs - val environmentVariables = request.environmentVariables val name = request.sparkProperties.getOrElse("spark.app.name", mainClass) // Construct driver description From ba0e79f57caa279773fb014b7883ee5d69dd0a68 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Dec 2017 13:54:16 -0800 Subject: [PATCH 419/712] [SPARK-22772][SQL] Use splitExpressionsWithCurrentInputs to split codes in elt ## What changes were proposed in this pull request? In SPARK-22550 which fixes 64KB JVM bytecode limit problem with elt, `buildCodeBlocks` is used to split codes. However, we should use `splitExpressionsWithCurrentInputs` because it considers both normal and wholestage codgen (it is not supported yet, so it simply doesn't split the codes). ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #19964 from viirya/SPARK-22772. --- .../expressions/codegen/CodeGenerator.scala | 2 +- .../expressions/stringExpressions.scala | 81 ++++++++++--------- 2 files changed, 43 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 257c3f10fa08b..b1d931117f99b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -878,7 +878,7 @@ class CodegenContext { * * @param expressions the codes to evaluate expressions. */ - def buildCodeBlocks(expressions: Seq[String]): Seq[String] = { + private def buildCodeBlocks(expressions: Seq[String]): Seq[String] = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() var length = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 47f0b5741f67f..8c4d2fd686be5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -289,53 +289,56 @@ case class Elt(children: Seq[Expression]) val index = indexExpr.genCode(ctx) val strings = stringExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") + val indexMatched = ctx.freshName("eltIndexMatched") + val stringVal = ctx.freshName("stringVal") + ctx.addMutableState(ctx.javaType(dataType), stringVal) + val assignStringValue = strings.zipWithIndex.map { case (eval, index) => s""" - case ${index + 1}: - ${eval.code} - $stringVal = ${eval.isNull} ? null : ${eval.value}; - break; - """ + |if ($indexVal == ${index + 1}) { + | ${eval.code} + | $stringVal = ${eval.isNull} ? 
null : ${eval.value}; + | $indexMatched = true; + | continue; + |} + """.stripMargin } - val cases = ctx.buildCodeBlocks(assignStringValue) - val codes = if (cases.length == 1) { - s""" - UTF8String $stringVal = null; - switch ($indexVal) { - ${cases.head} - } - """ - } else { - var prevFunc = "null" - for (c <- cases.reverse) { - val funcName = ctx.freshName("eltFunc") - val funcBody = s""" - private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) { - UTF8String $stringVal = null; - switch ($indexVal) { - $c - default: - return $prevFunc; - } - return $stringVal; - } - """ - val fullFuncName = ctx.addNewFunction(funcName, funcBody) - prevFunc = s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)" - } - s"UTF8String $stringVal = $prevFunc;" - } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignStringValue, + funcName = "eltFunc", + extraArguments = ("int", indexVal) :: Nil, + returnType = ctx.JAVA_BOOLEAN, + makeSplitFunction = body => + s""" + |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |do { + | $body + |} while (false); + |return $indexMatched; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$indexMatched = $funcCall; + |if ($indexMatched) { + | continue; + |} + """.stripMargin + }.mkString) ev.copy( s""" - ${index.code} - final int $indexVal = ${index.value}; - $codes - UTF8String ${ev.value} = $stringVal; - final boolean ${ev.isNull} = ${ev.value} == null; - """) + |${index.code} + |final int $indexVal = ${index.value}; + |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |$stringVal = null; + |do { + | $codes + |} while (false); + |final UTF8String ${ev.value} = $stringVal; + |final boolean ${ev.isNull} = ${ev.value} == null; + """.stripMargin) } } From a83e8e6c223df8b819335cbabbfff9956942f2ad Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 13 Dec 2017 16:06:16 -0600 Subject: [PATCH 420/712] [SPARK-22764][CORE] Fix flakiness in SparkContextSuite. Use a semaphore to synchronize the tasks with the listener code that is trying to cancel the job or stage, so that the listener won't try to cancel a job or stage that has already finished. Author: Marcelo Vanzin Closes #19956 from vanzin/SPARK-22764. 
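
To make the synchronization idea concrete, here is a minimal, self-contained Scala sketch of the semaphore handshake described above (illustration only, not part of the patch; the object and message names are made up): the simulated tasks block on a semaphore, and the simulated listener releases one permit per task only after it has done its cancellation work, so the cancellation can never race with tasks that have already finished.

```
import java.util.concurrent.Semaphore

object SemaphoreHandshakeSketch {
  def main(args: Array[String]): Unit = {
    val slices = 4
    val semaphore = new Semaphore(0)

    // Simulated tasks: each one blocks until the listener has released a permit.
    val tasks = (1 to slices).map { i =>
      new Thread(new Runnable {
        override def run(): Unit = {
          semaphore.acquire()
          println(s"task $i finished after the listener acted")
        }
      })
    }
    tasks.foreach(_.start())

    // Simulated listener: perform the "cancellation" first, then let the tasks proceed.
    println("listener: cancelling the stage, then releasing the tasks")
    semaphore.release(slices)

    tasks.foreach(_.join())
  }
}
```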
--- .../org/apache/spark/SparkContextSuite.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 37fcc93c62fa8..b30bd74812b36 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.File import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets -import java.util.concurrent.TimeUnit +import java.util.concurrent.{Semaphore, TimeUnit} import scala.concurrent.duration._ @@ -499,6 +499,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Cancelling stages/jobs with custom reasons.") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) val REASON = "You shall not pass" + val slices = 10 val listener = new SparkListener { override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { @@ -508,6 +509,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } sc.cancelStage(taskStart.stageId, REASON) SparkContextSuite.cancelStage = false + SparkContextSuite.semaphore.release(slices) } } @@ -518,21 +520,25 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } sc.cancelJob(jobStart.jobId, REASON) SparkContextSuite.cancelJob = false + SparkContextSuite.semaphore.release(slices) } } } sc.addSparkListener(listener) for (cancelWhat <- Seq("stage", "job")) { + SparkContextSuite.semaphore.drainPermits() SparkContextSuite.isTaskStarted = false SparkContextSuite.cancelStage = (cancelWhat == "stage") SparkContextSuite.cancelJob = (cancelWhat == "job") val ex = intercept[SparkException] { - sc.range(0, 10000L).mapPartitions { x => - org.apache.spark.SparkContextSuite.isTaskStarted = true + sc.range(0, 10000L, numSlices = slices).mapPartitions { x => + SparkContextSuite.isTaskStarted = true + // Block waiting for the listener to cancel the stage or job. + SparkContextSuite.semaphore.acquire() x - }.cartesian(sc.range(0, 10L))count() + }.count() } ex.getCause() match { @@ -636,4 +642,5 @@ object SparkContextSuite { @volatile var isTaskStarted = false @volatile var taskKilled = false @volatile var taskSucceeded = false + val semaphore = new Semaphore(0) } From ef92999653f0e2a47752379a867647445d849aab Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 13 Dec 2017 15:55:16 -0800 Subject: [PATCH 421/712] [SPARK-22600][SQL][FOLLOW-UP] Fix a compilation error in TPCDS q75/q77 ## What changes were proposed in this pull request? This pr fixed a compilation error of TPCDS `q75`/`q77` caused by #19813; ``` java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 371, Column 16: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 371, Column 16: Expression "bhj_matched" is not an rvalue at com.google.common.util.concurrent.AbstractFuture$Sync.getValue(AbstractFuture.java:306) at com.google.common.util.concurrent.AbstractFuture$Sync.get(AbstractFuture.java:293) at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:116) at com.google.common.util.concurrent.Uninterruptibles.getUninterruptibly(Uninterruptibles.java:135) ``` ## How was this patch tested? 
Manually checked `q75`/`q77` can be properly compiled Author: Takeshi Yamamuro Closes #19969 from maropu/SPARK-22600-FOLLOWUP. --- .../sql/catalyst/expressions/codegen/ExpressionCodegen.scala | 1 - .../spark/sql/execution/joins/BroadcastHashJoinExec.scala | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index a2dda48e951d1..807cb94734b2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.DataType /** * Defines util methods used in expression code generation. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index c96ed6ef41016..634014af16e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -192,7 +192,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, isNull, value, inputRow = matched) } } } From bc7e4a90c0c91970a94aa385971daac48db6264e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 Dec 2017 11:21:34 +0800 Subject: [PATCH 422/712] Revert "[SPARK-22600][SQL][FOLLOW-UP] Fix a compilation error in TPCDS q75/q77" This reverts commit ef92999653f0e2a47752379a867647445d849aab. --- .../sql/catalyst/expressions/codegen/ExpressionCodegen.scala | 1 + .../spark/sql/execution/joins/BroadcastHashJoinExec.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index 807cb94734b2a..a2dda48e951d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType /** * Defines util methods used in expression code generation. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 634014af16e8b..c96ed6ef41016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -192,7 +192,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value, inputRow = matched) + ExprCode(code, isNull, value) } } } From 2a29a60da32a5ccd04cc7c5c22c4075b159389e3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 Dec 2017 11:22:23 +0800 Subject: [PATCH 423/712] Revert "[SPARK-22600][SQL] Fix 64kb limit for deeply nested expressions under wholestage codegen" This reverts commit c7d0148615c921dca782ee3785b5d0cd59e42262. --- .../sql/catalyst/expressions/Expression.scala | 37 +-- .../expressions/codegen/CodeGenerator.scala | 44 +-- .../codegen/ExpressionCodegen.scala | 269 ------------------ .../codegen/ExpressionCodegenSuite.scala | 220 -------------- .../sql/execution/ColumnarBatchScan.scala | 5 +- .../execution/WholeStageCodegenSuite.scala | 23 +- 6 files changed, 13 insertions(+), 585 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 329ea5d421509..743782a6453e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -105,12 +105,6 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", isNull, value)) - eval.isNull = if (this.nullable) eval.isNull else "false" - - // Records current input row and variables of this expression. - eval.inputRow = ctx.INPUT_ROW - eval.inputVars = findInputVars(ctx, eval) - reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -121,29 +115,9 @@ abstract class Expression extends TreeNode[Expression] { } } - /** - * Returns the input variables to this expression. - */ - private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = { - if (ctx.currentVars != null) { - this.collect { - case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => - ExprInputVar(exprCode = ctx.currentVars(ordinal), - dataType = b.dataType, nullable = b.nullable) - } - } else { - Seq.empty - } - } - - /** - * In order to prevent 64kb compile error, reducing the size of generated codes by - * separating it into a function if the size exceeds a threshold. 
- */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { - lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this) - - if (eval.code.trim.length > 1024 && funcParams.isDefined) { + // TODO: support whole stage codegen too + if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { val globalIsNull = ctx.freshName("globalIsNull") ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) @@ -158,12 +132,9 @@ abstract class Expression extends TreeNode[Expression] { val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) - val callParams = funcParams.map(_._1.mkString(", ")).get - val declParams = funcParams.map(_._2.mkString(", ")).get - val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName($declParams) { + |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { | ${eval.code.trim} | $setIsNull | return ${eval.value}; @@ -171,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = newValue - eval.code = s"$javaType $newValue = $funcFullName($callParams);" + eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b1d931117f99b..3a03a65e1af92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,24 +55,8 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * to null. * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. - * @param inputRow A term that holds the input row name when generating this code. - * @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code. */ -case class ExprCode( - var code: String, - var isNull: String, - var value: String, - var inputRow: String = null, - var inputVars: Seq[ExprInputVar] = Seq.empty) - -/** - * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]]. - * - * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable. - * @param dataType The data type of the input variable. - * @param nullable Whether the input variable can be null or not. - */ -case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean) +case class ExprCode(var code: String, var isNull: String, var value: String) /** * State used for subexpression elimination. @@ -1028,25 +1012,16 @@ class CodegenContext { commonExprs.foreach { e => val expr = e.head val fnName = freshName("evalExpr") - val isNull = if (expr.nullable) { - s"${fnName}IsNull" - } else { - "" - } + val isNull = s"${fnName}IsNull" val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) - val assignIsNull = if (expr.nullable) { - s"$isNull = ${eval.isNull};" - } else { - "" - } val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { | ${eval.code.trim} - | $assignIsNull + | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} """.stripMargin @@ -1064,17 +1039,12 @@ class CodegenContext { // 2. Less code. 
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - if (expr.nullable) { - addMutableState(JAVA_BOOLEAN, isNull) - } - addMutableState(javaType(expr.dataType), value) + addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") + addMutableState(javaType(expr.dataType), value, + s"$value = ${defaultValue(expr.dataType)};") subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = if (expr.nullable) { - SubExprEliminationState(isNull, value) - } else { - SubExprEliminationState("false", value) - } + val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala deleted file mode 100644 index a2dda48e951d1..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.DataType - -/** - * Defines util methods used in expression code generation. - */ -object ExpressionCodegen { - - /** - * Given an expression, returns the all necessary parameters to evaluate it, so the generated - * code of this expression can be split in a function. - * The 1st string in returned tuple is the parameter strings used to call the function. - * The 2nd string in returned tuple is the parameter strings used to declare the function. - * - * Returns `None` if it can't produce valid parameters. - * - * Params to include: - * 1. Evaluated columns referred by this, children or deferred expressions. - * 2. Rows referred by this, children or deferred expressions. - * 3. Eliminated subexpressions referred by children expressions. 
- */ - def getExpressionInputParams( - ctx: CodegenContext, - expr: Expression): Option[(Seq[String], Seq[String])] = { - val subExprs = getSubExprInChildren(ctx, expr) - val subExprCodes = getSubExprCodes(ctx, subExprs) - val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) => - ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable) - } - val paramsFromSubExprs = prepareFunctionParams(ctx, subVars) - - val inputVars = getInputVarsForChildren(ctx, expr) - val paramsFromColumns = prepareFunctionParams(ctx, inputVars) - - val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) - val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => - (row, s"InternalRow $row") - } - - val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length - // Maximum allowed parameter number for Java's method descriptor. - if (paramsLength > 255) { - None - } else { - val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip - val callParams = allParams._1.distinct - val declParams = allParams._2.distinct - Some((callParams, declParams)) - } - } - - /** - * Returns the eliminated subexpressions in the children expressions. - */ - def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = { - expr.children.flatMap { child => - child.collect { - case e if ctx.subExprEliminationExprs.contains(e) => e - } - }.distinct - } - - /** - * A small helper function to return `ExprCode`s that represent subexpressions. - */ - def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = { - subExprs.map { subExpr => - val state = ctx.subExprEliminationExprs(subExpr) - ExprCode(code = "", value = state.value, isNull = state.isNull) - } - } - - /** - * Retrieves previous input rows referred by children and deferred expressions. - */ - def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = { - expr.children.flatMap(getInputRows(ctx, _)).distinct - } - - /** - * Given a child expression, retrieves previous input rows referred by it or deferred expressions - * which are needed to evaluate it. - */ - def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { - child.flatMap { - // An expression directly evaluates on current input row. - case BoundReference(ordinal, _, _) if ctx.currentVars == null || - ctx.currentVars(ordinal) == null => - Seq(ctx.INPUT_ROW) - - // An expression which is not evaluated yet. Tracks down to find input rows. - case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) => - trackDownRow(ctx, ctx.currentVars(ordinal)) - - case _ => Seq.empty - }.distinct - } - - /** - * Tracks down input rows referred by the generated code snippet. - */ - def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { - val exprCodes = mutable.Queue[ExprCode](exprCode) - val inputRows = mutable.ArrayBuffer.empty[String] - - while (exprCodes.nonEmpty) { - val curExprCode = exprCodes.dequeue() - if (curExprCode.inputRow != null) { - inputRows += curExprCode.inputRow - } - curExprCode.inputVars.foreach { inputVar => - if (!isEvaluated(inputVar.exprCode)) { - exprCodes.enqueue(inputVar.exprCode) - } - } - } - inputRows - } - - /** - * Retrieves previously evaluated columns referred by children and deferred expressions. - * Returned tuple contains the list of expressions and the list of generated codes. 
- */ - def getInputVarsForChildren( - ctx: CodegenContext, - expr: Expression): Seq[ExprInputVar] = { - expr.children.flatMap(getInputVars(ctx, _)).distinct - } - - /** - * Given a child expression, retrieves previously evaluated columns referred by it or - * deferred expressions which are needed to evaluate it. - */ - def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = { - if (ctx.currentVars == null) { - return Seq.empty - } - - child.flatMap { - // An evaluated variable. - case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && - isEvaluated(ctx.currentVars(ordinal)) => - Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable)) - - // An input variable which is not evaluated yet. Tracks down to find any evaluated variables - // in the expression path. - // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to - // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so - // to include them into parameters, if not, we track down further. - case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => - trackDownVar(ctx, ctx.currentVars(ordinal)) - - case _ => Seq.empty - }.distinct - } - - /** - * Tracks down previously evaluated columns referred by the generated code snippet. - */ - def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = { - val exprCodes = mutable.Queue[ExprCode](exprCode) - val inputVars = mutable.ArrayBuffer.empty[ExprInputVar] - - while (exprCodes.nonEmpty) { - exprCodes.dequeue().inputVars.foreach { inputVar => - if (isEvaluated(inputVar.exprCode)) { - inputVars += inputVar - } else { - exprCodes.enqueue(inputVar.exprCode) - } - } - } - inputVars - } - - /** - * Helper function to calculate the size of an expression as function parameter. - */ - def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = { - (if (input.nullable) 1 else 0) + ctx.javaType(input.dataType) match { - case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 2 - case _ => 1 - } - } - - /** - * In Java, a method descriptor is valid only if it represents method parameters with a total - * length of 255 or less. `this` contributes one unit and a parameter of type long or double - * contributes two units. - */ - def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = { - // Initial value is 1 for `this`. - 1 + inputs.map(calculateParamLength(ctx, _)).sum - } - - /** - * Given the lists of input attributes and variables to this expression, returns the strings of - * funtion parameters. The first is the variable names used to call the function, the second is - * the parameters used to declare the function in generated code. - */ - def prepareFunctionParams( - ctx: CodegenContext, - inputVars: Seq[ExprInputVar]): Seq[(String, String)] = { - inputVars.flatMap { inputVar => - val params = mutable.ArrayBuffer.empty[(String, String)] - val ev = inputVar.exprCode - - // Only include the expression value if it is not a literal. - if (!isLiteral(ev)) { - val argType = ctx.javaType(inputVar.dataType) - params += ((ev.value, s"$argType ${ev.value}")) - } - - // If it is a nullable expression and `isNull` is not a literal. - if (inputVar.nullable && ev.isNull != "true" && ev.isNull != "false") { - params += ((ev.isNull, s"boolean ${ev.isNull}")) - } - - params - }.distinct - } - - /** - * Only applied to the `ExprCode` in `ctx.currentVars`. - * Returns true if this value is a literal. 
- */ - def isLiteral(exprCode: ExprCode): Boolean = { - assert(exprCode.value.nonEmpty, "ExprCode.value can't be empty string.") - - if (exprCode.value == "true" || exprCode.value == "false" || exprCode.value == "null") { - true - } else { - // The valid characters for the first character of a Java variable is [a-zA-Z_$]. - exprCode.value.head match { - case v if v >= 'a' && v <= 'z' => false - case v if v >= 'A' && v <= 'Z' => false - case '_' | '$' => false - case _ => true - } - } - } - - /** - * Only applied to the `ExprCode` in `ctx.currentVars`. - * The code is emptied after evaluation. - */ - def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == "" -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala deleted file mode 100644 index 39d58cabff228..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.IntegerType - -class ExpressionCodegenSuite extends SparkFunSuite { - - test("Returns eliminated subexpressions for expression") { - val ctx = new CodegenContext() - val subExpr = Add(Literal(1), Literal(2)) - val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4))) - - ctx.generateExpressions(exprs, doSubexpressionElimination = true) - val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0)) - assert(subexpressions.length == 1 && subexpressions(0) == subExpr) - } - - test("Gets parameters for subexpressions") { - val ctx = new CodegenContext() - val subExprs = Seq( - Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable - Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)())) // nullable - - ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1")) - ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2")) - - val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs) - val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) => - ExprInputVar(exprCode, expr.dataType, expr.nullable) - } - val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars) - assert(params.length == 3) - assert(params(0) == Tuple2("value1", "int value1")) - assert(params(1) == Tuple2("value2", "int value2")) - assert(params(2) == Tuple2("isNull2", "boolean isNull2")) - } - - test("Returns input variables for expression: current variables") { - val ctx = new CodegenContext() - val currentVars = Seq( - ExprCode("", isNull = "false", value = "value1"), // evaluated - ExprCode("", isNull = "isNull2", value = "value2"), // evaluated - ExprCode("fake code;", isNull = "isNull3", value = "value3")) // not evaluated - ctx.currentVars = currentVars - ctx.INPUT_ROW = null - - val expr = If(Literal(false), - Add(BoundReference(0, IntegerType, nullable = false), - BoundReference(1, IntegerType, nullable = true)), - BoundReference(2, IntegerType, nullable = true)) - - val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) - // Only two evaluated variables included. - assert(inputVars.length == 2) - assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false) - assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true) - assert(inputVars(0).exprCode == currentVars(0)) - assert(inputVars(1).exprCode == currentVars(1)) - - val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) - assert(params.length == 3) - assert(params(0) == Tuple2("value1", "int value1")) - assert(params(1) == Tuple2("value2", "int value2")) - assert(params(2) == Tuple2("isNull2", "boolean isNull2")) - } - - test("Returns input variables for expression: deferred variables") { - val ctx = new CodegenContext() - - // The referred column is not evaluated yet. But it depends on an evaluated column from - // other operator. - val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) - - // currentVars(0) depends on this evaluated column. 
- currentVars(0).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull2", value = "value2"), - dataType = IntegerType, nullable = true)) - ctx.currentVars = currentVars - ctx.INPUT_ROW = null - - val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) - val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) - assert(inputVars.length == 1) - assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true) - - val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) - assert(params.length == 2) - assert(params(0) == Tuple2("value2", "int value2")) - assert(params(1) == Tuple2("isNull2", "boolean isNull2")) - } - - test("Returns input rows for expression") { - val ctx = new CodegenContext() - ctx.currentVars = null - ctx.INPUT_ROW = "i" - - val expr = Add(BoundReference(0, IntegerType, nullable = false), - BoundReference(1, IntegerType, nullable = true)) - val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) - assert(inputRows.length == 1) - assert(inputRows(0) == "i") - } - - test("Returns input rows for expression: deferred expression") { - val ctx = new CodegenContext() - - // The referred column is not evaluated yet. But it depends on an input row from - // other operator. - val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) - currentVars(0).inputRow = "inputadaptor_row1" - ctx.currentVars = currentVars - ctx.INPUT_ROW = null - - val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) - val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) - assert(inputRows.length == 1) - assert(inputRows(0) == "inputadaptor_row1") - } - - test("Returns both input rows and variables for expression") { - val ctx = new CodegenContext() - // 5 input variables in currentVars: - // 1 evaluated variable (value1). - // 3 not evaluated variables. - // value2 depends on an evaluated column from other operator. - // value3 depends on an input row from other operator. - // value4 depends on a not evaluated yet column from other operator. - // 1 null indicating to use input row "i". - val currentVars = Seq( - ExprCode("", isNull = "false", value = "value1"), - ExprCode("fake code;", isNull = "isNull2", value = "value2"), - ExprCode("fake code;", isNull = "isNull3", value = "value3"), - ExprCode("fake code;", isNull = "isNull4", value = "value4"), - null) - // value2 depends on this evaluated column. - currentVars(1).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull5", value = "value5"), - dataType = IntegerType, nullable = true)) - // value3 depends on an input row "inputadaptor_row1". - currentVars(2).inputRow = "inputadaptor_row1" - // value4 depends on another not evaluated yet column. - currentVars(3).inputVars = Seq(ExprInputVar(ExprCode("fake code;", - isNull = "isNull6", value = "value6"), dataType = IntegerType, nullable = true)) - ctx.currentVars = currentVars - ctx.INPUT_ROW = "i" - - // expr: if (false) { value1 + value2 } else { (value3 + value4) + i[5] } - val expr = If(Literal(false), - Add(BoundReference(0, IntegerType, nullable = false), - BoundReference(1, IntegerType, nullable = true)), - Add(Add(BoundReference(2, IntegerType, nullable = true), - BoundReference(3, IntegerType, nullable = true)), - BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i". - - // input rows: "i", "inputadaptor_row1". 
- val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) - assert(inputRows.length == 2) - assert(inputRows(0) == "inputadaptor_row1") - assert(inputRows(1) == "i") - - // input variables: value1 and value5 - val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) - assert(inputVars.length == 2) - - // value1 has inlined isNull "false", so don't need to include it in the params. - val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) - assert(inputVarParams.length == 3) - assert(inputVarParams(0) == Tuple2("value1", "int value1")) - assert(inputVarParams(1) == Tuple2("value5", "int value5")) - assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5")) - } - - test("isLiteral: literals") { - val literals = Seq( - ExprCode("", "", "true"), - ExprCode("", "", "false"), - ExprCode("", "", "1"), - ExprCode("", "", "-1"), - ExprCode("", "", "1L"), - ExprCode("", "", "-1L"), - ExprCode("", "", "1.0f"), - ExprCode("", "", "-1.0f"), - ExprCode("", "", "0.1f"), - ExprCode("", "", "-0.1f"), - ExprCode("", "", """"string""""), - ExprCode("", "", "(byte)-1"), - ExprCode("", "", "(short)-1"), - ExprCode("", "", "null")) - - literals.foreach(l => assert(ExpressionCodegen.isLiteral(l) == true)) - } - - test("isLiteral: non literals") { - val variables = Seq( - ExprCode("", "", "var1"), - ExprCode("", "", "_var2"), - ExprCode("", "", "$var3"), - ExprCode("", "", "v1a2r3"), - ExprCode("", "", "_1v2a3r"), - ExprCode("", "", "$1v2a3r")) - - variables.foreach(v => assert(ExpressionCodegen.isLiteral(v) == false)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 05186c4472566..a9bfb634fbdea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -108,10 +108,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { |}""".stripMargin) ctx.currentVars = null - // `rowIdx` isn't in `ctx.currentVars`. If the expressions are split later, we can't track it. - // So making it as global variable. 
val rowidx = ctx.freshName("rowIdx") - ctx.addMutableState(ctx.JAVA_INT, rowidx) val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } @@ -131,7 +128,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | int $numRows = $batch.numRows(); | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | $rowidx = $idx + $localIdx; + | int $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 1281169b607c9..bc05dca578c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -237,24 +236,4 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } - - test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") { - import testImplicits._ - withTempPath { dir => - val path = dir.getCanonicalPath - val df = Seq(("abc", 1)).toDF("key", "int") - df.write.parquet(path) - - var strExpr: Expression = col("key").expr - for (_ <- 1 to 150) { - strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8")) - } - val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) - - val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*) - val plan = df2.queryExecution.executedPlan - assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) - df2.collect() - } - } } From 1e44dd004425040912f2cf16362d2c13f12e1689 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Dec 2017 21:19:14 -0800 Subject: [PATCH 424/712] [SPARK-3181][ML] Implement huber loss for LinearRegression. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? MLlib ```LinearRegression``` supports _huber_ loss in addition to _leastSquares_ loss. The huber loss objective function is: ![image](https://user-images.githubusercontent.com/1962026/29554124-9544d198-8750-11e7-8afa-33579ec419d5.png) Refer to Eq.(6) and Eq.(8) in [A robust hybrid of lasso and ridge regression](http://statweb.stanford.edu/~owen/reports/hhu.pdf). This objective is jointly convex as a function of (w, σ) ∈ R × (0,∞), so we can use L-BFGS-B to solve it. The current implementation is a straightforward port of the Python scikit-learn [```HuberRegressor```](http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.HuberRegressor.html). There are some differences: * We use mean loss (```lossSum/weightSum```), but sklearn uses total loss (```lossSum```). * We multiply the loss function and L2 regularization by 1/2. Multiplying the whole formula by a constant factor does not affect the result; we do this only to stay consistent with the _leastSquares_ loss.
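For reference, the objective in the image above, written out (it is the same formula added to the ```HuberAggregator``` and ```LinearRegression``` scaladoc in this patch), is:

$$
\min_{w, \sigma}\ \frac{1}{2n}\sum_{i=1}^n\left(\sigma + H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda \|w\|_2^2,
\qquad
H_m(z) = \begin{cases} z^2, & \text{if } |z| < \epsilon \\ 2\epsilon|z| - \epsilon^2, & \text{otherwise} \end{cases}
$$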
So if fitting w/o regularization, MLlib and sklearn produce the same output. If fitting w/ regularization, MLlib should set ```regParam``` to the sklearn value divided by the number of instances to match the output of sklearn. ## How was this patch tested? Unit tests. Author: Yanbo Liang Closes #19020 from yanboliang/spark-3181. --- .../ml/optim/aggregator/HuberAggregator.scala | 150 +++++++++ .../ml/param/shared/SharedParamsCodeGen.scala | 3 +- .../spark/ml/param/shared/sharedParams.scala | 17 + .../ml/regression/LinearRegression.scala | 299 ++++++++++++++---- .../aggregator/HuberAggregatorSuite.scala | 170 ++++++++++ .../ml/regression/LinearRegressionSuite.scala | 244 +++++++++++++- project/MimaExcludes.scala | 5 + 7 files changed, 823 insertions(+), 65 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala new file mode 100644 index 0000000000000..13f64d2d50424 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.Vector + +/** + * HuberAggregator computes the gradient and loss for a huber loss function, + * as used in robust regression for samples in sparse or dense vector in an online fashion. + * + * The huber loss function based on: + * Art B. Owen (2006), + * A robust hybrid of lasso and ridge regression. + * + * Two HuberAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * The huber loss function is given by + * + *
    + * $$ + * \begin{align} + * \min_{w, \sigma}\frac{1}{2n}{\sum_{i=1}^n\left(\sigma + + * H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda {||w||_2}^2} + * \end{align} + * $$ + *
    + * + * where + * + *
    + * $$ + * \begin{align} + * H_m(z) = \begin{cases} + * z^2, & \text {if } |z| < \epsilon, \\ + * 2\epsilon|z| - \epsilon^2, & \text{otherwise} + * \end{cases} + * \end{align} + * $$ + *
    + * + * It is advised to set the parameter $\epsilon$ to 1.35 to achieve 95% statistical efficiency + * for normally distributed data. Please refer to chapter 2 of + * + * A robust hybrid of lasso and ridge regression for more detail. + * + * @param fitIntercept Whether to fit an intercept term. + * @param epsilon The shape parameter to control the amount of robustness. + * @param bcFeaturesStd The broadcast standard deviation values of the features. + * @param bcParameters including three parts: the regression coefficients corresponding + * to the features, the intercept (if fitIntercept is ture) + * and the scale parameter (sigma). + */ +private[ml] class HuberAggregator( + fitIntercept: Boolean, + epsilon: Double, + bcFeaturesStd: Broadcast[Array[Double]])(bcParameters: Broadcast[Vector]) + extends DifferentiableLossAggregator[Instance, HuberAggregator] { + + protected override val dim: Int = bcParameters.value.size + private val numFeatures: Int = if (fitIntercept) dim - 2 else dim - 1 + private val sigma: Double = bcParameters.value(dim - 1) + private val intercept: Double = if (fitIntercept) { + bcParameters.value(dim - 2) + } else { + 0.0 + } + + /** + * Add a new training instance to this HuberAggregator, and update the loss and gradient + * of the objective function. + * + * @param instance The instance of data point to be added. + * @return This HuberAggregator object. + */ + def add(instance: Instance): HuberAggregator = { + instance match { case Instance(label, weight, features) => + require(numFeatures == features.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $numFeatures but got ${features.size}.") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + + if (weight == 0.0) return this + val localFeaturesStd = bcFeaturesStd.value + val localCoefficients = bcParameters.value.toArray.slice(0, numFeatures) + val localGradientSumArray = gradientSumArray + + val margin = { + var sum = 0.0 + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + sum += localCoefficients(index) * (value / localFeaturesStd(index)) + } + } + if (fitIntercept) sum += intercept + sum + } + val linearLoss = label - margin + + if (math.abs(linearLoss) <= sigma * epsilon) { + lossSum += 0.5 * weight * (sigma + math.pow(linearLoss, 2.0) / sigma) + val linearLossDivSigma = linearLoss / sigma + + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += + -1.0 * weight * linearLossDivSigma * (value / localFeaturesStd(index)) + } + } + if (fitIntercept) { + localGradientSumArray(dim - 2) += -1.0 * weight * linearLossDivSigma + } + localGradientSumArray(dim - 1) += 0.5 * weight * (1.0 - math.pow(linearLossDivSigma, 2.0)) + } else { + val sign = if (linearLoss >= 0) -1.0 else 1.0 + lossSum += 0.5 * weight * + (sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon) + + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += + weight * sign * epsilon * (value / localFeaturesStd(index)) + } + } + if (fitIntercept) { + localGradientSumArray(dim - 2) += weight * sign * epsilon + } + localGradientSumArray(dim - 1) += 0.5 * weight * (1.0 - epsilon * epsilon) + } + + weightSum += weight + this + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a267bbc874322..a5d57a15317e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -88,7 +88,8 @@ private[shared] object SharedParamsCodeGen { "during tuning. If set to false, then only the single best sub-model will be available " + "after fitting. If set to true, then all sub-models will be available. Warning: For " + "large models, collecting all sub-models can cause OOMs on the Spark driver", - Some("false"), isExpertParam = true) + Some("false"), isExpertParam = true), + ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false) ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 0004f085d10f2..13425dacc9f18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -487,4 +487,21 @@ trait HasCollectSubModels extends Params { /** @group expertGetParam */ final def getCollectSubModels: Boolean = $(collectSubModels) } + +/** + * Trait for shared param loss. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasLoss extends Params { + + /** + * Param for the loss function to be optimized. + * @group param + */ + val loss: Param[String] = new Param[String](this, "loss", "the loss function to be optimized") + + /** @group getParam */ + final def getLoss: String = $(loss) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index da6bcf07e4742..a5873d03b4161 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path @@ -32,9 +32,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ import org.apache.spark.ml.optim.WeightedLeastSquares -import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator +import org.apache.spark.ml.optim.aggregator.{HuberAggregator, LeastSquaresAggregator} import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} -import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics @@ -44,8 +44,9 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.{DataType, DoubleType, 
StructType} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.VersionUtils.majorMinorVersion /** * Params for linear regression. @@ -53,7 +54,7 @@ import org.apache.spark.storage.StorageLevel private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver - with HasAggregationDepth { + with HasAggregationDepth with HasLoss { import LinearRegression._ @@ -69,25 +70,105 @@ private[regression] trait LinearRegressionParams extends PredictorParams "The solver algorithm for optimization. Supported options: " + s"${supportedSolvers.mkString(", ")}. (Default auto)", ParamValidators.inArray[String](supportedSolvers)) + + /** + * The loss function to be optimized. + * Supported options: "squaredError" and "huber". + * Default: "squaredError" + * + * @group param + */ + @Since("2.3.0") + final override val loss: Param[String] = new Param[String](this, "loss", "The loss function to" + + s" be optimized. Supported options: ${supportedLosses.mkString(", ")}. (Default squaredError)", + ParamValidators.inArray[String](supportedLosses)) + + /** + * The shape parameter to control the amount of robustness. Must be > 1.0. + * At larger values of epsilon, the huber criterion becomes more similar to least squares + * regression; for small values of epsilon, the criterion is more similar to L1 regression. + * Default is 1.35 to get as much robustness as possible while retaining + * 95% statistical efficiency for normally distributed data. It matches sklearn + * HuberRegressor and is "M" from + * A robust hybrid of lasso and ridge regression. + * Only valid when "loss" is "huber". + * + * @group expertParam + */ + @Since("2.3.0") + final val epsilon = new DoubleParam(this, "epsilon", "The shape parameter to control the " + + "amount of robustness. Must be > 1.0.", ParamValidators.gt(1.0)) + + /** @group getExpertParam */ + @Since("2.3.0") + def getEpsilon: Double = $(epsilon) + + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + if ($(loss) == Huber) { + require($(solver)!= Normal, "LinearRegression with huber loss doesn't support " + + "normal solver, please change solver to auto or l-bfgs.") + require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " + + s"L2 regularization, but got elasticNetParam = $getElasticNetParam.") + + } + super.validateAndTransformSchema(schema, fitting, featuresDataType) + } } /** * Linear regression. * - * The learning objective is to minimize the squared error, with regularization. - * The specific squared error loss function used is: - * - *
    - * $$ - * L = 1/2n ||A coefficients - y||^2^ - * $$ - *
    + * The learning objective is to minimize the specified loss function, with regularization. + * This supports two kinds of loss: + * - squaredError (a.k.a squared loss) + * - huber (a hybrid of squared error for relatively small errors and absolute error for + * relatively large ones, and we estimate the scale parameter from training data) * * This supports multiple types of regularization: * - none (a.k.a. ordinary least squares) * - L2 (ridge regression) * - L1 (Lasso) * - L2 + L1 (elastic net) + * + * The squared error objective function is: + * + *
    + * $$ + * \begin{align} + * \min_{w}\frac{1}{2n}{\sum_{i=1}^n(X_{i}w - y_{i})^{2} + + * \lambda\left[\frac{1-\alpha}{2}{||w||_{2}}^{2} + \alpha{||w||_{1}}\right]} + * \end{align} + * $$ + *
    + * + * The huber objective function is: + * + *
    + * $$ + * \begin{align} + * \min_{w, \sigma}\frac{1}{2n}{\sum_{i=1}^n\left(\sigma + + * H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda {||w||_2}^2} + * \end{align} + * $$ + *
    + * + * where + * + *
    + * $$ + * \begin{align} + * H_m(z) = \begin{cases} + * z^2, & \text {if } |z| < \epsilon, \\ + * 2\epsilon|z| - \epsilon^2, & \text{otherwise} + * \end{cases} + * \end{align} + * $$ + *
    + * + * Note: Fitting with huber loss only supports none and L2 regularization. */ @Since("1.3.0") class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) @@ -142,6 +223,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String * For alpha in (0,1), the penalty is a combination of L1 and L2. * Default is 0.0 which is an L2 penalty. * + * Note: Fitting with huber loss only supports None and L2 regularization, + * so throws exception if this param is non-zero value. + * * @group setParam */ @Since("1.4.0") @@ -190,6 +274,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String * The Normal Equations solver will be used when possible, but this will automatically fall * back to iterative optimization methods when needed. * + * Note: Fitting with huber loss doesn't support normal solver, + * so throws exception if this param was set with "normal". * @group setParam */ @Since("1.6.0") @@ -208,6 +294,26 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) setDefault(aggregationDepth -> 2) + /** + * Sets the value of param [[loss]]. + * Default is "squaredError". + * + * @group setParam + */ + @Since("2.3.0") + def setLoss(value: String): this.type = set(loss, value) + setDefault(loss -> SquaredError) + + /** + * Sets the value of param [[epsilon]]. + * Default is 1.35. + * + * @group setExpertParam + */ + @Since("2.3.0") + def setEpsilon(value: Double): this.type = set(epsilon, value) + setDefault(epsilon -> 1.35) + override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size @@ -220,12 +326,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, solver, tol, - elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth) + instr.logParams(labelCol, featuresCol, weightCol, predictionCol, solver, tol, elasticNetParam, + fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, epsilon) instr.logNumFeatures(numFeatures) - if (($(solver) == Auto && - numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) { + if ($(loss) == SquaredError && (($(solver) == Auto && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal)) { // For low dimensional data, WeightedLeastSquares is more efficient since the // training algorithm only requires one pass through the data. (SPARK-10668) @@ -330,12 +436,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. 
- val effectiveRegParam = $(regParam) / yStd + val effectiveRegParam = $(loss) match { + case SquaredError => $(regParam) / yStd + case Huber => $(regParam) + } val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam - val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept), - bcFeaturesStd, bcFeaturesMean)(_) val getFeaturesStd = (j: Int) => if (j >= 0 && j < numFeatures) featuresStd(j) else 0.0 val regularization = if (effectiveL2RegParam != 0.0) { val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures @@ -344,33 +451,58 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } else { None } - val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization, - $(aggregationDepth)) - val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { - new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - } else { - val standardizationParam = $(standardization) - def effectiveL1RegFun = (index: Int) => { - if (standardizationParam) { - effectiveL1RegParam + val costFun = $(loss) match { + case SquaredError => + val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept), + bcFeaturesStd, bcFeaturesMean)(_) + new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth)) + case Huber => + val getAggregatorFunc = new HuberAggregator($(fitIntercept), $(epsilon), bcFeaturesStd)(_) + new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth)) + } + + val optimizer = $(loss) match { + case SquaredError => + if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { - // If `standardization` is false, we still standardize the data - // to improve the rate of convergence; as a result, we have to - // perform this reverse standardization by penalizing each component - // differently to get effectively the same objective function when - // the training dataset is not standardized. - if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0 + val standardizationParam = $(standardization) + def effectiveL1RegFun = (index: Int) => { + if (standardizationParam) { + effectiveL1RegParam + } else { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. 
+ if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0 + } + } + new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol)) } - } - new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol)) + case Huber => + val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1 + val lowerBounds = BDV[Double](Array.fill(dim)(Double.MinValue)) + // Optimize huber loss in space "\sigma > 0" + lowerBounds(dim - 1) = Double.MinPositiveValue + val upperBounds = BDV[Double](Array.fill(dim)(Double.MaxValue)) + new BreezeLBFGSB(lowerBounds, upperBounds, $(maxIter), 10, $(tol)) + } + + val initialValues = $(loss) match { + case SquaredError => + Vectors.zeros(numFeatures) + case Huber => + val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1 + Vectors.dense(Array.fill(dim)(1.0)) } - val initialCoefficients = Vectors.zeros(numFeatures) val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficients.asBreeze.toDenseVector) + initialValues.asBreeze.toDenseVector) - val (coefficients, objectiveHistory) = { + val (coefficients, intercept, scale, objectiveHistory) = { /* Note that in Linear Regression, the objective history (loss + regularization) returned from optimizer is computed in the scaled space given by the following formula. @@ -396,35 +528,54 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String bcFeaturesMean.destroy(blocking = false) bcFeaturesStd.destroy(blocking = false) + val parameters = state.x.toArray.clone() + /* The coefficients are trained in the scaled space; we're converting them back to the original space. */ - val rawCoefficients = state.x.toArray.clone() + val rawCoefficients: Array[Double] = $(loss) match { + case SquaredError => parameters + case Huber => parameters.slice(0, numFeatures) + } + var i = 0 val len = rawCoefficients.length + val multiplier = $(loss) match { + case SquaredError => yStd + case Huber => 1.0 + } while (i < len) { - rawCoefficients(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) multiplier / featuresStd(i) else 0.0 } i += 1 } - (Vectors.dense(rawCoefficients).compressed, arrayBuilder.result()) - } + val interceptValue: Double = if ($(fitIntercept)) { + $(loss) match { + case SquaredError => + /* + The intercept of squared error in R's GLMNET is computed using closed form + after the coefficients are converged. See the following discussion for detail. + http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + */ + yMean - dot(Vectors.dense(rawCoefficients), Vectors.dense(featuresMean)) + case Huber => parameters(numFeatures) + } + } else { + 0.0 + } - /* - The intercept in R's GLMNET is computed using closed form after the coefficients are - converged. See the following discussion for detail. 
- http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - */ - val intercept = if ($(fitIntercept)) { - yMean - dot(coefficients, Vectors.dense(featuresMean)) - } else { - 0.0 + val scaleValue: Double = $(loss) match { + case SquaredError => 1.0 + case Huber => parameters.last + } + + (Vectors.dense(rawCoefficients).compressed, interceptValue, scaleValue, arrayBuilder.result()) } if (handlePersistence) instances.unpersist() - val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) + val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept, scale)) // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() @@ -471,6 +622,15 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { /** Set of solvers that LinearRegression supports. */ private[regression] val supportedSolvers = Array(Auto, Normal, LBFGS) + + /** String name for "squaredError". */ + private[regression] val SquaredError = "squaredError" + + /** String name for "huber". */ + private[regression] val Huber = "huber" + + /** Set of loss function names that LinearRegression supports. */ + private[regression] val supportedLosses = Array(SquaredError, Huber) } /** @@ -480,10 +640,14 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { class LinearRegressionModel private[ml] ( @Since("1.4.0") override val uid: String, @Since("2.0.0") val coefficients: Vector, - @Since("1.3.0") val intercept: Double) + @Since("1.3.0") val intercept: Double, + @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams with MLWritable { + def this(uid: String, coefficients: Vector, intercept: Double) = + this(uid, coefficients, intercept, 1.0) + private var trainingSummary: Option[LinearRegressionTrainingSummary] = None override val numFeatures: Int = coefficients.size @@ -570,13 +734,13 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) extends MLWriter with Logging { - private case class Data(intercept: Double, coefficients: Vector) + private case class Data(intercept: Double, coefficients: Vector, scale: Double) override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients - val data = Data(instance.intercept, instance.coefficients) + // Save model data: intercept, coefficients, scale + val data = Data(instance.intercept, instance.coefficients, instance.scale) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -592,11 +756,20 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) - val Row(intercept: Double, coefficients: Vector) = - MLUtils.convertVectorColumnsToML(data, "coefficients") - .select("intercept", "coefficients") - .head() - val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) + val model = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) { + // Spark 2.2 and before + val Row(intercept: Double, coefficients: Vector) = + 
MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("intercept", "coefficients") + .head() + new LinearRegressionModel(metadata.uid, coefficients, intercept) + } else { + // Spark 2.3 and later + val Row(intercept: Double, coefficients: Vector, scale: Double) = + data.select("intercept", "coefficients", "scale").head() + new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) + } DefaultParamsReader.getAndSetParams(model, metadata) model diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala new file mode 100644 index 0000000000000..718ffa230a749 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import DifferentiableLossAggregatorSuite.getRegressionSummarizers + + @transient var instances: Array[Instance] = _ + @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantFeatureFiltered: Array[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + instancesConstantFeature = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), + Instance(2.0, 0.3, Vectors.dense(1.0, 0.5)) + ) + instancesConstantFeatureFiltered = Array( + Instance(0.0, 0.1, Vectors.dense(2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0)), + Instance(2.0, 0.3, Vectors.dense(0.5)) + ) + } + + /** Get summary statistics for some data and create a new HuberAggregator. 
*/ + private def getNewAggregator( + instances: Array[Instance], + parameters: Vector, + fitIntercept: Boolean, + epsilon: Double): HuberAggregator = { + val (featuresSummarizer, _) = getRegressionSummarizers(instances) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd) + val bcParameters = spark.sparkContext.broadcast(parameters) + new HuberAggregator(fitIntercept, epsilon, bcFeaturesStd)(bcParameters) + } + + test("aggregator add method should check input size") { + val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0) + val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon = 1.35) + withClue("HuberAggregator features dimension must match parameters dimension") { + intercept[IllegalArgumentException] { + agg.add(Instance(1.0, 1.0, Vectors.dense(2.0))) + } + } + } + + test("negative weight") { + val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0) + val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon = 1.35) + withClue("HuberAggregator does not support negative instance weights.") { + intercept[IllegalArgumentException] { + agg.add(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0))) + } + } + } + + test("check sizes") { + val paramWithIntercept = Vectors.dense(1.0, 2.0, 3.0, 4.0) + val paramWithoutIntercept = Vectors.dense(1.0, 2.0, 4.0) + val aggIntercept = getNewAggregator(instances, paramWithIntercept, + fitIntercept = true, epsilon = 1.35) + val aggNoIntercept = getNewAggregator(instances, paramWithoutIntercept, + fitIntercept = false, epsilon = 1.35) + instances.foreach(aggIntercept.add) + instances.foreach(aggNoIntercept.add) + + assert(aggIntercept.gradient.size === 4) + assert(aggNoIntercept.gradient.size === 3) + } + + test("check correctness") { + val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0) + val numFeatures = 2 + val (featuresSummarizer, _) = getRegressionSummarizers(instances) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val epsilon = 1.35 + val weightSum = instances.map(_.weight).sum + + val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon) + instances.foreach(agg.add) + + // compute expected loss sum + val coefficients = parameters.toArray.slice(0, 2) + val intercept = parameters(2) + val sigma = parameters(3) + val stdCoef = coefficients.indices.map(i => coefficients(i) / featuresStd(i)).toArray + val lossSum = instances.map { case Instance(label, weight, features) => + val margin = BLAS.dot(Vectors.dense(stdCoef), features) + intercept + val linearLoss = label - margin + if (math.abs(linearLoss) <= sigma * epsilon) { + 0.5 * weight * (sigma + math.pow(linearLoss, 2.0) / sigma) + } else { + 0.5 * weight * (sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon) + } + }.sum + val loss = lossSum / weightSum + + // compute expected gradients + val gradientCoef = new Array[Double](numFeatures + 2) + instances.foreach { case Instance(label, weight, features) => + val margin = BLAS.dot(Vectors.dense(stdCoef), features) + intercept + val linearLoss = label - margin + if (math.abs(linearLoss) <= sigma * epsilon) { + features.toArray.indices.foreach { i => + gradientCoef(i) += + -1.0 * weight * (linearLoss / sigma) * (features(i) / featuresStd(i)) + } + gradientCoef(2) += -1.0 * weight * (linearLoss / sigma) + gradientCoef(3) += 0.5 * weight * (1.0 - math.pow(linearLoss / sigma, 2.0)) + } else { + val sign = if (linearLoss >= 0) -1.0 else 1.0 + features.toArray.indices.foreach { i 
=> + gradientCoef(i) += weight * sign * epsilon * (features(i) / featuresStd(i)) + } + gradientCoef(2) += weight * sign * epsilon + gradientCoef(3) += 0.5 * weight * (1.0 - epsilon * epsilon) + } + } + val gradient = Vectors.dense(gradientCoef.map(_ / weightSum)) + + assert(loss ~== agg.loss relTol 0.01) + assert(gradient ~== agg.gradient relTol 0.01) + } + + test("check with zero standard deviation") { + val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0) + val parametersFiltered = Vectors.dense(2.0, 3.0, 4.0) + val aggConstantFeature = getNewAggregator(instancesConstantFeature, parameters, + fitIntercept = true, epsilon = 1.35) + val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, + parametersFiltered, fitIntercept = true, epsilon = 1.35) + instances.foreach(aggConstantFeature.add) + instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add) + // constant features should not affect gradient + def validateGradient(grad: Vector, gradFiltered: Vector): Unit = { + assert(grad(0) === 0.0) + assert(grad(1) ~== gradFiltered(0) relTol 0.01) + } + + validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index aec5ac0e75896..9bb2895858f33 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -41,6 +41,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { @transient var datasetWithWeight: DataFrame = _ @transient var datasetWithWeightConstantLabel: DataFrame = _ @transient var datasetWithWeightZeroLabel: DataFrame = _ + @transient var datasetWithOutlier: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -107,6 +108,16 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2).toDF() + + datasetWithOutlier = { + val inlierData = LinearDataGenerator.generateLinearInput( + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 900, seed, eps = 0.1) + val outlierData = LinearDataGenerator.generateLinearInput( + intercept = -2.1, weights = Array(0.6, -1.2), xMean = Array(0.9, -1.3), + xVariance = Array(1.5, 0.8), nPoints = 100, seed, eps = 0.1) + sc.parallelize(inlierData ++ outlierData, 2).map(_.asML).toDF() + } } /** @@ -127,6 +138,10 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { datasetWithSparseFeature.rdd.map { case Row(label: Double, features: Vector) => label + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile("target/tmp/LinearRegressionSuite/datasetWithSparseFeature") + + datasetWithOutlier.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile("target/tmp/LinearRegressionSuite/datasetWithOutlier") } test("params") { @@ -144,7 +159,9 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { assert(lir.getElasticNetParam === 0.0) assert(lir.getFitIntercept) assert(lir.getStandardization) - assert(lir.getSolver == "auto") + assert(lir.getSolver === "auto") + assert(lir.getLoss === "squaredError") + assert(lir.getEpsilon === 1.35) val model = lir.fit(datasetWithDenseFeature) 
MLTestingUtils.checkCopyAndUids(lir, model) @@ -160,11 +177,27 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") assert(model.intercept !== 0.0) + assert(model.scale === 1.0) assert(model.hasParent) val numFeatures = datasetWithDenseFeature.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) } + test("linear regression: illegal params") { + withClue("LinearRegression with huber loss only supports L2 regularization") { + intercept[IllegalArgumentException] { + new LinearRegression().setLoss("huber").setElasticNetParam(0.5) + .fit(datasetWithDenseFeature) + } + } + + withClue("LinearRegression with huber loss doesn't support normal solver") { + intercept[IllegalArgumentException] { + new LinearRegression().setLoss("huber").setSolver("normal").fit(datasetWithDenseFeature) + } + } + } + test("linear regression handles singular matrices") { // check for both constant columns with intercept (zero std) and collinear val singularDataConstantColumn = sc.parallelize(Seq( @@ -837,6 +870,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { (1.0, 0.21, true, true) ) + // For squaredError loss for (solver <- Seq("auto", "l-bfgs", "normal"); (elasticNetParam, regParam, fitIntercept, standardization) <- testParams) { val estimator = new LinearRegression() @@ -852,6 +886,22 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed) } + + // For huber loss + for ((_, regParam, fitIntercept, standardization) <- testParams) { + val estimator = new LinearRegression() + .setLoss("huber") + .setFitIntercept(fitIntercept) + .setStandardization(standardization) + .setRegParam(regParam) + MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( + datasetWithOutlier.as[LabeledPoint], estimator, modelEquals) + MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( + datasetWithOutlier.as[LabeledPoint], estimator, numClasses, modelEquals, + outlierRatio = 3) + MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression]( + datasetWithOutlier.as[LabeledPoint], estimator, modelEquals, seed) + } } test("linear regression model with l-bfgs with big feature datasets") { @@ -1004,6 +1054,198 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { } } } + + test("linear regression (huber loss) with intercept without regularization") { + val trainer1 = (new LinearRegression).setLoss("huber") + .setFitIntercept(true).setStandardization(true) + val trainer2 = (new LinearRegression).setLoss("huber") + .setFitIntercept(true).setStandardization(false) + + val model1 = trainer1.fit(datasetWithOutlier) + val model2 = trainer2.fit(datasetWithOutlier) + + /* + Using the following Python code to load the data and train the model using + scikit-learn package. 
+ + import pandas as pd + import numpy as np + from sklearn.linear_model import HuberRegressor + df = pd.read_csv("path", header = None) + X = df[df.columns[1:3]] + y = np.array(df[df.columns[0]]) + huber = HuberRegressor(fit_intercept=True, alpha=0.0, max_iter=100, epsilon=1.35) + huber.fit(X, y) + + >>> huber.coef_ + array([ 4.68998007, 7.19429011]) + >>> huber.intercept_ + 6.3002404351083037 + >>> huber.scale_ + 0.077810159205220747 + */ + val coefficientsPy = Vectors.dense(4.68998007, 7.19429011) + val interceptPy = 6.30024044 + val scalePy = 0.07781016 + + assert(model1.coefficients ~= coefficientsPy relTol 1E-3) + assert(model1.intercept ~== interceptPy relTol 1E-3) + assert(model1.scale ~== scalePy relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.coefficients ~= coefficientsPy relTol 1E-3) + assert(model2.intercept ~== interceptPy relTol 1E-3) + assert(model2.scale ~== scalePy relTol 1E-3) + } + + test("linear regression (huber loss) without intercept without regularization") { + val trainer1 = (new LinearRegression).setLoss("huber") + .setFitIntercept(false).setStandardization(true) + val trainer2 = (new LinearRegression).setLoss("huber") + .setFitIntercept(false).setStandardization(false) + + val model1 = trainer1.fit(datasetWithOutlier) + val model2 = trainer2.fit(datasetWithOutlier) + + /* + huber = HuberRegressor(fit_intercept=False, alpha=0.0, max_iter=100, epsilon=1.35) + huber.fit(X, y) + + >>> huber.coef_ + array([ 6.71756703, 5.08873222]) + >>> huber.intercept_ + 0.0 + >>> huber.scale_ + 2.5560209922722317 + */ + val coefficientsPy = Vectors.dense(6.71756703, 5.08873222) + val interceptPy = 0.0 + val scalePy = 2.55602099 + + assert(model1.coefficients ~= coefficientsPy relTol 1E-3) + assert(model1.intercept === interceptPy) + assert(model1.scale ~== scalePy relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.coefficients ~= coefficientsPy relTol 1E-3) + assert(model2.intercept === interceptPy) + assert(model2.scale ~== scalePy relTol 1E-3) + } + + test("linear regression (huber loss) with intercept with L2 regularization") { + val trainer1 = (new LinearRegression).setLoss("huber") + .setFitIntercept(true).setRegParam(0.21).setStandardization(true) + val trainer2 = (new LinearRegression).setLoss("huber") + .setFitIntercept(true).setRegParam(0.21).setStandardization(false) + + val model1 = trainer1.fit(datasetWithOutlier) + val model2 = trainer2.fit(datasetWithOutlier) + + /* + Since scikit-learn HuberRegressor does not support standardization, + we do it manually out of the estimator. 
+ + xStd = np.std(X, axis=0) + scaledX = X / xStd + huber = HuberRegressor(fit_intercept=True, alpha=210, max_iter=100, epsilon=1.35) + huber.fit(scaledX, y) + + >>> np.array(huber.coef_ / xStd) + array([ 1.97732633, 3.38816722]) + >>> huber.intercept_ + 3.7527581430531227 + >>> huber.scale_ + 3.787363673371801 + */ + val coefficientsPy1 = Vectors.dense(1.97732633, 3.38816722) + val interceptPy1 = 3.75275814 + val scalePy1 = 3.78736367 + + assert(model1.coefficients ~= coefficientsPy1 relTol 1E-2) + assert(model1.intercept ~== interceptPy1 relTol 1E-2) + assert(model1.scale ~== scalePy1 relTol 1E-2) + + /* + huber = HuberRegressor(fit_intercept=True, alpha=210, max_iter=100, epsilon=1.35) + huber.fit(X, y) + + >>> huber.coef_ + array([ 1.73346444, 3.63746999]) + >>> huber.intercept_ + 4.3017134790781739 + >>> huber.scale_ + 3.6472742809286793 + */ + val coefficientsPy2 = Vectors.dense(1.73346444, 3.63746999) + val interceptPy2 = 4.30171347 + val scalePy2 = 3.64727428 + + assert(model2.coefficients ~= coefficientsPy2 relTol 1E-3) + assert(model2.intercept ~== interceptPy2 relTol 1E-3) + assert(model2.scale ~== scalePy2 relTol 1E-3) + } + + test("linear regression (huber loss) without intercept with L2 regularization") { + val trainer1 = (new LinearRegression).setLoss("huber") + .setFitIntercept(false).setRegParam(0.21).setStandardization(true) + val trainer2 = (new LinearRegression).setLoss("huber") + .setFitIntercept(false).setRegParam(0.21).setStandardization(false) + + val model1 = trainer1.fit(datasetWithOutlier) + val model2 = trainer2.fit(datasetWithOutlier) + + /* + Since scikit-learn HuberRegressor does not support standardization, + we do it manually out of the estimator. + + xStd = np.std(X, axis=0) + scaledX = X / xStd + huber = HuberRegressor(fit_intercept=False, alpha=210, max_iter=100, epsilon=1.35) + huber.fit(scaledX, y) + + >>> np.array(huber.coef_ / xStd) + array([ 2.59679008, 2.26973102]) + >>> huber.intercept_ + 0.0 + >>> huber.scale_ + 4.5766311924091791 + */ + val coefficientsPy1 = Vectors.dense(2.59679008, 2.26973102) + val interceptPy1 = 0.0 + val scalePy1 = 4.57663119 + + assert(model1.coefficients ~= coefficientsPy1 relTol 1E-2) + assert(model1.intercept === interceptPy1) + assert(model1.scale ~== scalePy1 relTol 1E-2) + + /* + huber = HuberRegressor(fit_intercept=False, alpha=210, max_iter=100, epsilon=1.35) + huber.fit(X, y) + + >>> huber.coef_ + array([ 2.28423908, 2.25196887]) + >>> huber.intercept_ + 0.0 + >>> huber.scale_ + 4.5979643506051753 + */ + val coefficientsPy2 = Vectors.dense(2.28423908, 2.25196887) + val interceptPy2 = 0.0 + val scalePy2 = 4.59796435 + + assert(model2.coefficients ~= coefficientsPy2 relTol 1E-3) + assert(model2.intercept === interceptPy2) + assert(model2.scale ~== scalePy2 relTol 1E-3) + } + + test("huber loss model match squared error for large epsilon") { + val trainer1 = new LinearRegression().setLoss("huber").setEpsilon(1E5) + val model1 = trainer1.fit(datasetWithOutlier) + val trainer2 = new LinearRegression() + val model2 = trainer2.fit(datasetWithOutlier) + assert(model1.coefficients ~== model2.coefficients relTol 1E-3) + assert(model1.intercept ~== model2.intercept relTol 1E-3) + } } object LinearRegressionSuite { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9be01f617217b..9902fedb65d59 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -1065,6 +1065,11 @@ object MimaExcludes { // [SPARK-21680][ML][MLLIB]optimzie Vector coompress 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize") + ) ++ Seq( + // [SPARK-3181][ML]Implement huber loss for LinearRegression. + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.org$apache$spark$ml$param$shared$HasLoss$_setter_$loss_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.getLoss"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.loss") ) } From f8c7c1f21aa9d1fd38b584ca8c4adf397966e9f7 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 13 Dec 2017 22:31:39 -0800 Subject: [PATCH 425/712] [SPARK-22732] Add Structured Streaming APIs to DataSourceV2 ## What changes were proposed in this pull request? This PR provides DataSourceV2 API support for structured streaming, including new pieces needed to support continuous processing [SPARK-20928]. High level summary: - DataSourceV2 includes new mixins to support micro-batch and continuous reads and writes. For reads, we accept an optional user specified schema rather than using the ReadSupportWithSchema model, because doing so would severely complicate the interface. - DataSourceV2Reader includes new interfaces to read a specific microbatch or read continuously from a given offset. These follow the same setter pattern as the existing Supports* mixins so that they can work with SupportsScanUnsafeRow. - DataReader (the per-partition reader) has a new subinterface ContinuousDataReader only for continuous processing. This reader has a special method to check progress, and next() blocks for new input rather than returning false. - Offset, an abstract representation of position in a streaming query, is ported to the public API. (Each type of reader will define its own Offset implementation.) - DataSourceV2Writer has a new subinterface ContinuousWriter only for continuous processing. Commits to this interface come tagged with an epoch number, as the execution engine will continue to produce new epoch commits as the task continues indefinitely. Note that this PR does not propose to change the existing DataSourceV2 batch API, or deprecate the existing streaming source/sink internal APIs in spark.sql.execution.streaming. ## How was this patch tested? Toy implementations of the new interfaces with unit tests. Author: Jose Torres Closes #19925 from joseph-torres/continuous-api. 
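As a rough illustration of the ported Offset API (not part of this patch; the class names and fields below are hypothetical), a source built against the new interfaces would define its own Offset and, for continuous mode, a per-partition PartitionOffset, carrying whatever position information the source needs to restart:

```scala
import org.apache.spark.sql.sources.v2.reader.{Offset, PartitionOffset}

// Hypothetical offset for a source with a single, linearly increasing position.
// The engine persists json() to the offset log and compares offsets by that string,
// so equivalent offsets must serialize to identical JSON.
case class ExamplePositionOffset(position: Long) extends Offset {
  override def json(): String = position.toString
}

// Hypothetical per-partition offset for continuous mode; a ContinuousReader merges
// an array of these into a single global Offset in mergeOffsets().
case class ExamplePartitionOffset(partition: Int, position: Long) extends PartitionOffset
```

The RateStreamOffset and ContinuousRateStreamPartitionOffset classes added below follow the same pattern.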
--- .../spark/sql/kafka010/KafkaSource.scala | 1 + .../sql/kafka010/KafkaSourceOffset.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 1 + .../sql/sources/v2/ContinuousReadSupport.java | 42 +++++ .../sources/v2/ContinuousWriteSupport.java | 54 ++++++ .../sql/sources/v2/DataSourceV2Options.java | 8 + .../sql/sources/v2/MicroBatchReadSupport.java | 52 +++++ .../sources/v2/MicroBatchWriteSupport.java | 58 ++++++ .../v2/reader/ContinuousDataReader.java | 36 ++++ .../sources/v2/reader/ContinuousReader.java | 68 +++++++ .../sources/v2/reader/MicroBatchReader.java | 64 +++++++ .../spark/sql/sources/v2/reader/Offset.java | 60 ++++++ .../sources/v2/reader/PartitionOffset.java | 30 +++ .../sources/v2/writer/ContinuousWriter.java | 41 ++++ .../streaming/BaseStreamingSink.java | 27 +++ .../streaming/BaseStreamingSource.java | 37 ++++ .../streaming/FileStreamSource.scala | 1 + .../streaming/FileStreamSourceOffset.scala | 3 + .../sql/execution/streaming/LongOffset.scala | 2 + .../sql/execution/streaming/Offset.scala | 34 +--- .../sql/execution/streaming/OffsetSeq.scala | 1 + .../execution/streaming/OffsetSeqLog.scala | 1 + .../streaming/RateSourceProvider.scala | 22 ++- .../streaming/RateStreamOffset.scala | 29 +++ .../streaming/RateStreamSourceV2.scala | 162 ++++++++++++++++ .../sql/execution/streaming/Source.scala | 1 + .../execution/streaming/StreamExecution.scala | 1 + .../execution/streaming/StreamProgress.scala | 2 + .../ContinuousRateStreamSource.scala | 152 +++++++++++++++ .../sql/execution/streaming/memory.scala | 1 + .../sql/execution/streaming/memoryV2.scala | 178 ++++++++++++++++++ .../sql/execution/streaming/socket.scala | 1 + .../streaming/MemorySinkV2Suite.scala | 82 ++++++++ .../streaming/OffsetSeqLogSuite.scala | 1 + .../execution/streaming/RateSourceSuite.scala | 1 + .../streaming/RateSourceV2Suite.scala | 155 +++++++++++++++ .../sql/streaming/FileStreamSourceSuite.scala | 1 + .../spark/sql/streaming/OffsetSuite.scala | 3 +- .../spark/sql/streaming/StreamSuite.scala | 1 + .../spark/sql/streaming/StreamTest.scala | 1 + .../streaming/StreamingAggregationSuite.scala | 1 + .../StreamingQueryListenerSuite.scala | 1 + .../sql/streaming/StreamingQuerySuite.scala | 1 + .../test/DataStreamReaderWriterSuite.scala | 1 + .../sql/streaming/util/BlockingSource.scala | 3 +- 45 files changed, 1390 insertions(+), 35 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java create mode 100644 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamSourceV2.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memoryV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..87f31fcc20ae6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..6e24423df4abc 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition -import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.execution.streaming.SerializedOffset +import org.apache.spark.sql.sources.v2.reader.Offset /** * An [[Offset]] for the [[KafkaSource]]. 
This one tracks all partitions of subscribed topics and diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..9cac0e5ae7117 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java new file mode 100644 index 0000000000000..ae4f85820649f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.sql.sources.v2.reader.ContinuousReader; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for continuous stream processing. + */ +public interface ContinuousReadSupport extends DataSourceV2 { + /** + * Creates a {@link ContinuousReader} to scan the data from this data source. + * + * @param schema the user provided schema, or empty() if none was provided + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. 
+ */
+ ContinuousReader createContinuousReader(Optional<StructType> schema, String checkpointLocation, DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java
new file mode 100644
index 0000000000000..362d5f52b4d00
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
+import org.apache.spark.sql.sources.v2.writer.ContinuousWriter;
+import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer;
+import org.apache.spark.sql.streaming.OutputMode;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide data writing ability for continuous stream processing.
+ */
+@InterfaceStability.Evolving
+public interface ContinuousWriteSupport extends BaseStreamingSink {
+
+ /**
+ * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data
+ * sources can return None if there is no writing needed to be done.
+ *
+ * @param queryId A unique string for the writing query. It's possible that there are many writing
+ * queries running at the same time, and the returned {@link DataSourceV2Writer}
+ * can use this id to distinguish itself from others.
+ * @param schema the schema of the data to be written.
+ * @param mode the output mode which determines what successive epoch output means to this
+ * sink, please refer to {@link OutputMode} for more details.
+ * @param options the options for the returned data source writer, which is an immutable
+ * case-insensitive string-to-string map.
+ */
+ Optional<ContinuousWriter> createContinuousWriter(
+ String queryId,
+ StructType schema,
+ OutputMode mode,
+ DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
index e98c04517c3db..ddc2acca693ac 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java
@@ -36,6 +36,10 @@ private String toLowerCase(String key) {
 return key.toLowerCase(Locale.ROOT);
 }
+ public static DataSourceV2Options empty() {
+ return new DataSourceV2Options(new HashMap<>());
+ }
+
 public DataSourceV2Options(Map<String, String> originalMap) {
 keyLowerCasedMap = new HashMap<>(originalMap.size());
 for (Map.Entry<String, String> entry : originalMap.entrySet()) {
@@ -43,6 +47,10 @@ public DataSourceV2Options(Map<String, String> originalMap) {
 }
 }
+ public Map<String, String> asMap() {
+ return new HashMap<>(keyLowerCasedMap);
+ }
+
 /**
 * Returns the option value to which the specified key is mapped, case-insensitively.
 */
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
new file mode 100644
index 0000000000000..442cad029d211
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.reader.MicroBatchReader;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide streaming micro-batch data reading ability.
+ */
+@InterfaceStability.Evolving
+public interface MicroBatchReadSupport extends DataSourceV2 {
+ /**
+ * Creates a {@link MicroBatchReader} to read batches of data from this data source in a
+ * streaming query.
+ *
+ * The execution engine will create a micro-batch reader at the start of a streaming query,
+ * alternate calls to setOffsetRange and createReadTasks for each batch to process, and then
+ * call stop() when the execution is complete. Note that a single query may have multiple
+ * executions due to restart or failure recovery.
+ *
+ * @param schema the user provided schema, or empty() if none was provided
+ * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
+ * recovery. Readers for the same logical source in the same query
+ * will be given the same checkpointLocation.
+ * @param options the options for the returned data source reader, which is an immutable
+ * case-insensitive string-to-string map.
+ */
+ MicroBatchReader createMicroBatchReader(
+ Optional<StructType> schema,
+ String checkpointLocation,
+ DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java
new file mode 100644
index 0000000000000..63640779b955c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
+import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer;
+import org.apache.spark.sql.streaming.OutputMode;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * provide data writing ability and save the data from a microbatch to the data source.
+ */
+@InterfaceStability.Evolving
+public interface MicroBatchWriteSupport extends BaseStreamingSink {
+
+ /**
+ * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data
+ * sources can return None if there is no writing needed to be done.
+ *
+ * @param queryId A unique string for the writing query. It's possible that there are many writing
+ * queries running at the same time, and the returned {@link DataSourceV2Writer}
+ * can use this id to distinguish itself from others.
+ * @param epochId The unique numeric ID of the batch within this writing query. This is an
+ * incrementing counter representing a consistent set of data; the same batch may
+ * be started multiple times in failure recovery scenarios, but it will always
+ * contain the same records.
+ * @param schema the schema of the data to be written.
+ * @param mode the output mode which determines what successive batch output means to this
+ * sink, please refer to {@link OutputMode} for more details.
+ * @param options the options for the returned data source writer, which is an immutable
+ * case-insensitive string-to-string map.
+ */
+ Optional<DataSourceV2Writer> createMicroBatchWriter(
+ String queryId,
+ long epochId,
+ StructType schema,
+ OutputMode mode,
+ DataSourceV2Options options);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java
new file mode 100644
index 0000000000000..11b99a93f1494
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.sql.sources.v2.reader.PartitionOffset;
+
+import java.io.IOException;
+
+/**
+ * A variation on {@link DataReader} for use with streaming in continuous processing mode.
+ */
+public interface ContinuousDataReader<T> extends DataReader<T> {
+ /**
+ * Get the offset of the current record, or the start offset if no records have been read.
+ *
+ * The execution engine will call this method along with get() to keep track of the current
+ * offset. When an epoch ends, the offset of the previous record in each partition will be saved
+ * as a restart checkpoint.
+ */
+ PartitionOffset getOffset();
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java
new file mode 100644
index 0000000000000..1baf82c2df762
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.sql.sources.v2.reader.PartitionOffset;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
+
+import java.util.Optional;
+
+/**
+ * A mix-in interface for {@link DataSourceV2Reader}.
Data source readers can implement this
+ * interface to allow reading in a continuous processing mode stream.
+ *
+ * Implementations must ensure each read task output is a {@link ContinuousDataReader}.
+ */
+public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader {
+ /**
+ * Merge offsets coming from {@link ContinuousDataReader} instances in each partition to
+ * a single global offset.
+ */
+ Offset mergeOffsets(PartitionOffset[] offsets);
+
+ /**
+ * Deserialize a JSON string into an Offset of the implementation-defined offset type.
+ * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader
+ */
+ Offset deserializeOffset(String json);
+
+ /**
+ * Set the desired start offset for read tasks created from this reader. The scan will start
+ * from the first record after the provided offset, or from an implementation-defined inferred
+ * starting point if no offset is provided.
+ */
+ void setOffset(Optional<Offset> start);
+
+ /**
+ * Return the specified or inferred start offset for this reader.
+ *
+ * @throws IllegalStateException if setOffset has not been called
+ */
+ Offset getStartOffset();
+
+ /**
+ * The execution engine will call this method in every epoch to determine if new read tasks need
+ * to be generated, which may be required if for example the underlying source system has had
+ * partitions added or removed.
+ *
+ * If true, the query will be shut down and restarted with a new reader.
+ */
+ default boolean needsReconfiguration() {
+ return false;
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java
new file mode 100644
index 0000000000000..438e3f55b7bcf
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.sql.sources.v2.reader.Offset;
+import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
+
+import java.util.Optional;
+
+/**
+ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to indicate they allow micro-batch streaming reads.
+ */
+public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource {
+ /**
+ * Set the desired offset range for read tasks created from this reader. Read tasks will
+ * generate only data within (`start`, `end`]; that is, from the first record after `start` to
+ * the record with offset `end`.
+ *
+ * @param start The initial offset to scan from.
If not specified, scan from an
+ * implementation-specified start point, such as the earliest available record.
+ * @param end The last offset to include in the scan. If not specified, scan up to an
+ * implementation-defined endpoint, such as the last available offset
+ * or the start offset plus a target batch size.
+ */
+ void setOffsetRange(Optional<Offset> start, Optional<Offset> end);
+
+ /**
+ * Returns the specified (if explicitly set through setOffsetRange) or inferred start offset
+ * for this reader.
+ *
+ * @throws IllegalStateException if setOffsetRange has not been called
+ */
+ Offset getStartOffset();
+
+ /**
+ * Return the specified (if explicitly set through setOffsetRange) or inferred end offset
+ * for this reader.
+ *
+ * @throws IllegalStateException if setOffsetRange has not been called
+ */
+ Offset getEndOffset();
+
+ /**
+ * Deserialize a JSON string into an Offset of the implementation-defined offset type.
+ * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader
+ */
+ Offset deserializeOffset(String json);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java
new file mode 100644
index 0000000000000..1ebd35356f1a3
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+/**
+ * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]].
+ * During execution, Offsets provided by the data source implementation will be logged and used as
+ * restart checkpoints. Sources should provide an Offset implementation which they can use to
+ * reconstruct the stream position where the offset was taken.
+ */
+public abstract class Offset {
+ /**
+ * A JSON-serialized representation of an Offset that is
+ * used for saving offsets to the offset log.
+ * Note: We assume that equivalent/equal offsets serialize to
+ * identical JSON strings.
+ *
+ * @return JSON string encoding
+ */
+ public abstract String json();
+
+ /**
+ * Equality based on JSON string representation. We leverage the
+ * JSON representation for normalization between the Offset's
+ * in memory and on disk representations.
+ */ + @Override + public boolean equals(Object obj) { + if (obj instanceof Offset) { + return this.json().equals(((Offset) obj).json()); + } else { + return false; + } + } + + @Override + public int hashCode() { + return this.json().hashCode(); + } + + @Override + public String toString() { + return this.json(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java new file mode 100644 index 0000000000000..07826b6688476 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Serializable; + +/** + * Used for per-partition offsets in continuous processing. ContinuousReader implementations will + * provide a method to merge these into a global Offset. + * + * These offsets must be serializable. + */ +public interface PartitionOffset extends Serializable { + +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java new file mode 100644 index 0000000000000..618f47ed79ca5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A {@link DataSourceV2Writer} for use with continuous stream processing. + */ +@InterfaceStability.Evolving +public interface ContinuousWriter extends DataSourceV2Writer { + /** + * Commits this writing job for the specified epoch with a list of commit messages. The commit + * messages are collected from successful data writers and are produced by + * {@link DataWriter#commit()}. 
+ * + * If this method fails (by throwing an exception), this writing job is considered to have been + * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. + */ + void commit(long epochId, WriterCommitMessage[] messages); + + default void commit(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Commit without epoch should not be called with ContinuousWriter"); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java new file mode 100644 index 0000000000000..ac96c2765368f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming; + +/** + * The shared interface between V1 and V2 streaming sinks. + * + * This is a temporary interface for compatibility during migration. It should not be implemented + * directly, and will be removed in future versions. + */ +public interface BaseStreamingSink { +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java new file mode 100644 index 0000000000000..3a02cbfe7afe3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming; + +import org.apache.spark.sql.sources.v2.reader.Offset; + +/** + * The shared interface between V1 streaming sources and V2 streaming readers. + * + * This is a temporary interface for compatibility during migration. It should not be implemented + * directly, and will be removed in future versions. 
+ */ +public interface BaseStreamingSource { + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); + + /** Stop this source and free any resources it has allocated. */ + void stop(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0debd7db84757..a33b785126765 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -27,6 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala index 06d0fe6c18c1e..431e5b99e3e98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala @@ -22,8 +22,11 @@ import scala.util.control.Exception._ import org.json4s.NoTypeHints import org.json4s.jackson.Serialization +import org.apache.spark.sql.sources.v2.reader.Offset + /** * Offset for the [[FileStreamSource]]. + * * @param logOffset Position in the [[FileStreamSourceLog]] */ case class FileStreamSourceOffset(logOffset: Long) extends Offset { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 5f0b195fcfcb8..7ea31462ca7b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.sources.v2.reader.Offset + /** * A simple offset for sources that produce a single linear stream of data. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala index 4efcee0f8f9d6..73f0c6221c5c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -17,38 +17,8 @@ package org.apache.spark.sql.execution.streaming -/** - * An offset is a monotonically increasing metric used to track progress in the computation of a - * stream. Since offsets are retrieved from a [[Source]] by a single thread, we know the global - * ordering of two [[Offset]] instances. We do assume that if two offsets are `equal` then no - * new data has arrived. - */ -abstract class Offset { - - /** - * Equality based on JSON string representation. We leverage the - * JSON representation for normalization between the Offset's - * in memory and on disk representations. 
- */ - override def equals(obj: Any): Boolean = obj match { - case o: Offset => this.json == o.json - case _ => false - } - - override def hashCode(): Int = this.json.hashCode +import org.apache.spark.sql.sources.v2.reader.Offset - override def toString(): String = this.json.toString - - /** - * A JSON-serialized representation of an Offset that is - * used for saving offsets to the offset log. - * Note: We assume that equivalent/equal offsets serialize to - * identical JSON strings. - * - * @return JSON string encoding - */ - def json: String -} /** * Used when loading a JSON serialized offset from external storage. @@ -58,3 +28,5 @@ abstract class Offset { * that accepts a [[SerializedOffset]] for doing the conversion. */ case class SerializedOffset(override val json: String) extends Offset + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 4e0a468b962a2..dcc5935890c8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -23,6 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} +import org.apache.spark.sql.sources.v2.reader.Offset /** * An ordered collection of offsets, used to track the progress of processing data from one or more diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index e3f4abcf9f1dc..bfdbc65296165 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.sources.v2.reader.Offset /** * This class is used to log offsets to persistent files in HDFS. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 077a4778e34a8..50671a46599e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import java.io._ import java.nio.charset.StandardCharsets +import java.util.Optional import java.util.concurrent.TimeUnit import org.apache.commons.io.IOUtils @@ -28,7 +29,10 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, MicroBatchReader, Offset} import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} @@ -46,7 +50,8 @@ import org.apache.spark.util.{ManualClock, SystemClock} * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister + with DataSourceV2 with MicroBatchReadSupport with ContinuousReadSupport{ override def sourceSchema( sqlContext: SQLContext, @@ -100,6 +105,21 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing ) } + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = { + new RateStreamV2Reader(options) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): ContinuousReader = { + new ContinuousRateStreamReader(options) + } + override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala new file mode 100644 index 0000000000000..13679dfbe446b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.json4s.DefaultFormats +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.sources.v2.reader.Offset + +case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, (Long, Long)]) + extends Offset { + implicit val defaultFormats: DefaultFormats = DefaultFormats + override val json = Serialization.write(partitionToValueAndRunTimeMs) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamSourceV2.scala new file mode 100644 index 0000000000000..102551c238bfb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamSourceV2.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.Optional + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.json4s.DefaultFormats +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.util.SystemClock + +class RateStreamV2Reader(options: DataSourceV2Options) + extends MicroBatchReader { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + val clock = new SystemClock + + private val numPartitions = + options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + private val rowsPerSecond = + options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + + // The interval (in milliseconds) between rows in each partition. + // e.g. if there are 4 global rows per second, and 2 partitions, each partition + // should output rows every (1000 * 2 / 4) = 500 ms. 
+ private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond + + override def readSchema(): StructType = { + StructType( + StructField("timestamp", TimestampType, false) :: + StructField("value", LongType, false) :: Nil) + } + + val creationTimeMs = clock.getTimeMillis() + + private var start: RateStreamOffset = _ + private var end: RateStreamOffset = _ + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse( + RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) + .asInstanceOf[RateStreamOffset] + + this.end = end.orElse { + val currentTime = clock.getTimeMillis() + RateStreamOffset( + this.start.partitionToValueAndRunTimeMs.map { + case startOffset @ (part, (currentVal, currentReadTime)) => + // Calculate the number of rows we should advance in this partition (based on the + // current time), and output a corresponding offset. + val readInterval = currentTime - currentReadTime + val numNewRows = readInterval / msPerPartitionBetweenRows + if (numNewRows <= 0) { + startOffset + } else { + (part, + (currentVal + (numNewRows * numPartitions), + currentReadTime + (numNewRows * msPerPartitionBetweenRows))) + } + } + ) + }.asInstanceOf[RateStreamOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + RateStreamOffset(Serialization.read[Map[Int, (Long, Long)]](json)) + } + + override def createReadTasks(): java.util.List[ReadTask[Row]] = { + val startMap = start.partitionToValueAndRunTimeMs + val endMap = end.partitionToValueAndRunTimeMs + endMap.keys.toSeq.map { part => + val (endVal, _) = endMap(part) + val (startVal, startTimeMs) = startMap(part) + + val packedRows = mutable.ListBuffer[(Long, Long)]() + var outVal = startVal + numPartitions + var outTimeMs = startTimeMs + msPerPartitionBetweenRows + while (outVal <= endVal) { + packedRows.append((outTimeMs, outVal)) + outVal += numPartitions + outTimeMs += msPerPartitionBetweenRows + } + + RateStreamBatchTask(packedRows).asInstanceOf[ReadTask[Row]] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} +} + +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] { + override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) +} + +class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { + var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the seq. + currentIndex += 1 + currentIndex < vals.size + } + + override def get(): Row = { + Row( + DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), + vals(currentIndex)._2) + } + + override def close(): Unit = {} +} + +object RateStreamSourceV2 { + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + + private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. 
+ (i, + ((i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 311942f6dbd84..dbb408ffc98d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 406560c260f07..16063c02ce06b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index a3f3662e6f4c9..770db401c9fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming import scala.collection.{immutable, GenTraversableOnce} +import org.apache.spark.sql.sources.v2.reader.Offset + /** * A helper class that looks like a Map[Source, Offset]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala new file mode 100644 index 0000000000000..77fc26730e52c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.collection.JavaConverters._ + +import org.json4s.DefaultFormats +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} + +case class ContinuousRateStreamPartitionOffset( + partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset + +class ContinuousRateStreamReader(options: DataSourceV2Options) + extends ContinuousReader { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + val creationTime = System.currentTimeMillis() + + val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + assert(offsets.length == numPartitions) + val tuples = offsets.map { + case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => (i, (currVal, nextRead)) + } + RateStreamOffset(Map(tuples: _*)) + } + + override def deserializeOffset(json: String): Offset = { + RateStreamOffset(Serialization.read[Map[Int, (Long, Long)]](json)) + } + + override def readSchema(): StructType = RateSourceProvider.SCHEMA + + private var offset: Offset = _ + + override def setOffset(offset: java.util.Optional[Offset]): Unit = { + this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + } + + override def getStartOffset(): Offset = offset + + override def createReadTasks(): java.util.List[ReadTask[Row]] = { + val partitionStartMap = offset match { + case off: RateStreamOffset => off.partitionToValueAndRunTimeMs + case off => + throw new IllegalArgumentException( + s"invalid offset type ${off.getClass()} for ContinuousRateSource") + } + if (partitionStartMap.keySet.size != numPartitions) { + throw new IllegalArgumentException( + s"The previous run contained ${partitionStartMap.keySet.size} partitions, but" + + s" $numPartitions partitions are currently configured. The numPartitions option" + + " cannot be changed.") + } + + Range(0, numPartitions).map { i => + val start = partitionStartMap(i) + // Have each partition advance by numPartitions each row, with starting points staggered + // by their partition index. 
+ RateStreamReadTask( + start._1, // starting row value + start._2, // starting time in ms + i, + numPartitions, + perPartitionRate) + .asInstanceOf[ReadTask[Row]] + }.asJava + } + + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + +} + +case class RateStreamReadTask( + startValue: Long, + startTimeMs: Long, + partitionIndex: Int, + increment: Long, + rowsPerSecond: Double) + extends ReadTask[Row] { + override def createDataReader(): DataReader[Row] = + new RateStreamDataReader(startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) +} + +class RateStreamDataReader( + startValue: Long, + startTimeMs: Long, + partitionIndex: Int, + increment: Long, + rowsPerSecond: Double) + extends ContinuousDataReader[Row] { + private var nextReadTime: Long = startTimeMs + private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong + + private var currentValue = startValue + private var currentRow: Row = null + + override def next(): Boolean = { + currentValue += increment + nextReadTime += readTimeIncrement + + try { + while (System.currentTimeMillis < nextReadTime) { + Thread.sleep(nextReadTime - System.currentTimeMillis) + } + } catch { + case _: InterruptedException => + // Someone's trying to end the task; just let them. + return false + } + + currentRow = Row( + DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(nextReadTime)), + currentValue) + + true + } + + override def get: Row = currentRow + + override def close(): Unit = {} + + override def getOffset(): PartitionOffset = + ContinuousRateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 3041d4d703cb4..db0717510a2cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memoryV2.scala new file mode 100644 index 0000000000000..437040cc12472 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memoryV2.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} +import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit + * tests and does not provide durability. + */ +class MemorySinkV2 extends DataSourceV2 + with MicroBatchWriteSupport with ContinuousWriteSupport with Logging { + + override def createMicroBatchWriter( + queryId: String, + batchId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): java.util.Optional[DataSourceV2Writer] = { + java.util.Optional.of(new MemoryWriter(this, batchId, mode)) + } + + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): java.util.Optional[ContinuousWriter] = { + java.util.Optional.of(new ContinuousMemoryWriter(this, mode)) + } + + private case class AddedData(batchId: Long, data: Array[Row]) + + /** An order list of batches that have been written to this [[Sink]]. */ + @GuardedBy("this") + private val batches = new ArrayBuffer[AddedData]() + + /** Returns all rows that are stored in this [[Sink]]. 
*/ + def allData: Seq[Row] = synchronized { + batches.flatMap(_.data) + } + + def latestBatchId: Option[Long] = synchronized { + batches.lastOption.map(_.batchId) + } + + def latestBatchData: Seq[Row] = synchronized { + batches.lastOption.toSeq.flatten(_.data) + } + + def toDebugString: String = synchronized { + batches.map { case AddedData(batchId, data) => + val dataStr = try data.mkString(" ") catch { + case NonFatal(e) => "[Error converting to string]" + } + s"$batchId: $dataStr" + }.mkString("\n") + } + + def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + val notCommitted = synchronized { + latestBatchId.isEmpty || batchId > latestBatchId.get + } + if (notCommitted) { + logDebug(s"Committing batch $batchId to $this") + outputMode match { + case Append | Update => + val rows = AddedData(batchId, newRows) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, newRows) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } + } else { + logDebug(s"Skipping already committed batch: $batchId") + } + } + + def clear(): Unit = synchronized { + batches.clear() + } + + override def toString(): String = "MemorySink" +} + +case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} + +class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) + extends DataSourceV2Writer with Logging { + + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + + def commit(messages: Array[WriterCommitMessage]): Unit = { + val newRows = messages.flatMap { + case message: MemoryWriterCommitMessage => message.data + } + sink.write(batchId, outputMode, newRows) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + // Don't accept any of the new input. + } +} + +class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) + extends ContinuousWriter { + + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + val newRows = messages.flatMap { + case message: MemoryWriterCommitMessage => message.data + } + sink.write(epochId, outputMode, newRows) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + // Don't accept any of the new input. 
+ } +} + +case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { + def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + new MemoryDataWriter(partitionId, outputMode) + } +} + +class MemoryDataWriter(partition: Int, outputMode: OutputMode) + extends DataWriter[Row] with Logging { + + private val data = mutable.Buffer[Row]() + + override def write(row: Row): Unit = { + data.append(row) + } + + override def commit(): MemoryWriterCommitMessage = { + val msg = MemoryWriterCommitMessage(partition, data.clone()) + data.clear() + msg + } + + override def abort(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index 0b22cbc46e6bf..440cae016a173 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala new file mode 100644 index 0000000000000..be4b490754986 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.streaming.{OutputMode, StreamTest} + +class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { + test("data writer") { + val partition = 1234 + val writer = new MemoryDataWriter(partition, OutputMode.Append()) + writer.write(Row(1)) + writer.write(Row(2)) + writer.write(Row(44)) + val msg = writer.commit() + assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) + assert(msg.partition == partition) + + // Buffer should be cleared, so repeated commits should give empty. 
+ assert(writer.commit().data.isEmpty) + } + + test("continuous writer") { + val sink = new MemorySinkV2 + val writer = new ContinuousMemoryWriter(sink, OutputMode.Append()) + writer.commit(0, + Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) + )) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + writer.commit(19, + Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))) + )) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) + } + + test("microbatch writer") { + val sink = new MemorySinkV2 + new MemoryWriter(sink, 0, OutputMode.Append()).commit( + Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) + )) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + new MemoryWriter(sink, 19, OutputMode.Append()).commit( + Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))) + )) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index e6cdc063c4e9f..4868ba4e68934 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala index 03d0f63fa4d7f..ceba27b26e578 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.functions._ +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala new file mode 100644 index 0000000000000..ef801ceb1310c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.Optional + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} +import org.apache.spark.sql.streaming.StreamTest + +class RateSourceV2Suite extends StreamTest { + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + assert(reader.isInstanceOf[RateStreamV2Reader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("microbatch - numPartitions propagated") { + val reader = new RateStreamV2Reader( + new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + val tasks = reader.createReadTasks() + assert(tasks.size == 11) + } + + test("microbatch - set offset") { + val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) + val startOffset = RateStreamOffset(Map((0, (0, 1000)))) + val endOffset = RateStreamOffset(Map((0, (0, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val reader = new RateStreamV2Reader( + new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: RateStreamOffset => + assert(r.partitionToValueAndRunTimeMs(0)._2 == reader.creationTimeMs) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: RateStreamOffset => + // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted + // longer than 100ms. It should never be early. 
+ assert(r.partitionToValueAndRunTimeMs(0)._1 >= 9) + assert(r.partitionToValueAndRunTimeMs(0)._2 >= reader.creationTimeMs + 100) + + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val reader = new RateStreamV2Reader( + new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + val startOffset = RateStreamOffset(Map((0, (0, 1000)))) + val endOffset = RateStreamOffset(Map((0, (20, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createReadTasks() + assert(tasks.size == 1) + assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) + } + + test("microbatch - data read") { + val reader = new RateStreamV2Reader( + new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) + val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { + case (part, (currentVal, currentReadTime)) => + (part, (currentVal + 33, currentReadTime + 1000)) + }.toMap) + + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createReadTasks() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) + assert(reader.isInstanceOf[ContinuousRateStreamReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("continuous data") { + val reader = new ContinuousRateStreamReader( + new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setOffset(Optional.empty()) + val tasks = reader.createReadTasks() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamReadTask => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + ._2 + val r = t.createDataReader().asInstanceOf[RateStreamDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + ContinuousRateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index b6baaed1927e4..7a2d9e3728208 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ import 
org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala index f208f9bd9b6e3..429748261f1ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, SerializedOffset} +import org.apache.spark.sql.execution.streaming.{LongOffset, SerializedOffset} +import org.apache.spark.sql.sources.v2.reader.Offset trait OffsetSuite extends SparkFunSuite { /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 3d687d2214e90..8163a1f91e1ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreCon import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index dc5b998ad68b5..7a1ff89f2f3b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1b4d8556f6ae5..fa0313592b8e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.sources.v2.reader.Offset import 
org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 9ff02dee288fb..fc9ac2a56c4e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index cc693909270f8..f813b77e3ce6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index aa163d2211c38..952908f21ca60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{ProcessingTime => DeprecatedProcessingTime, _} import org.apache.spark.sql.streaming.Trigger._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala index 19ab2ff13e14e..9a35f097e6e40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.streaming.util import java.util.concurrent.CountDownLatch import org.apache.spark.sql.{SQLContext, _} -import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source} +import org.apache.spark.sql.execution.streaming.{LongOffset, Sink, Source} import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import 
org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, StructField, StructType} From c3dd2a26deaadf508b4e163eab2c0544cd922540 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 13 Dec 2017 22:46:20 -0800 Subject: [PATCH 426/712] [SPARK-22779][SQL] Resolve default values for fallback configs. SQLConf allows some callers to define a custom default value for configs, and that complicates a little bit the handling of fallback config entries, since most of the default value resolution is hidden by the config code. This change peaks into the internals of these fallback configs to figure out the correct default value, and also returns the current human-readable default when showing the default value (e.g. through "set -v"). Author: Marcelo Vanzin Closes #19974 from vanzin/SPARK-22779. --- .../spark/internal/config/ConfigEntry.scala | 8 +++-- .../apache/spark/sql/internal/SQLConf.scala | 16 +++++++--- .../spark/sql/internal/SQLConfSuite.scala | 30 +++++++++++++++++++ 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index f1190289244e9..ede3ace4f9aac 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -139,7 +139,7 @@ private[spark] class OptionalConfigEntry[T]( s => Some(rawValueConverter(s)), v => v.map(rawStringConverter).orNull, doc, isPublic) { - override def defaultValueString: String = "" + override def defaultValueString: String = ConfigEntry.UNDEFINED override def readFrom(reader: ConfigReader): Option[T] = { readString(reader).map(rawValueConverter) @@ -149,12 +149,12 @@ private[spark] class OptionalConfigEntry[T]( /** * A config entry whose default value is defined by another config entry. */ -private class FallbackConfigEntry[T] ( +private[spark] class FallbackConfigEntry[T] ( key: String, alternatives: List[String], doc: String, isPublic: Boolean, - private[config] val fallback: ConfigEntry[T]) + val fallback: ConfigEntry[T]) extends ConfigEntry[T](key, alternatives, fallback.valueConverter, fallback.stringConverter, doc, isPublic) { @@ -167,6 +167,8 @@ private class FallbackConfigEntry[T] ( private[spark] object ConfigEntry { + val UNDEFINED = "" + private val knownConfigs = new java.util.concurrent.ConcurrentHashMap[String, ConfigEntry[_]]() def registerEntry(entry: ConfigEntry[_]): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1121444cc938a..cf7e3ebce7411 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1379,7 +1379,7 @@ class SQLConf extends Serializable with Logging { Option(settings.get(key)). orElse { // Try to use the default value - Option(sqlConfEntries.get(key)).map(_.defaultValueString) + Option(sqlConfEntries.get(key)).map { e => e.stringConverter(e.readFrom(reader)) } }. getOrElse(throw new NoSuchElementException(key)) } @@ -1417,14 +1417,21 @@ class SQLConf extends Serializable with Logging { * not set yet, return `defaultValue`. 
*/ def getConfString(key: String, defaultValue: String): String = { - if (defaultValue != null && defaultValue != "") { + if (defaultValue != null && defaultValue != ConfigEntry.UNDEFINED) { val entry = sqlConfEntries.get(key) if (entry != null) { // Only verify configs in the SQLConf object entry.valueConverter(defaultValue) } } - Option(settings.get(key)).getOrElse(defaultValue) + Option(settings.get(key)).getOrElse { + // If the key is not set, need to check whether the config entry is registered and is + // a fallback conf, so that we can check its parent. + sqlConfEntries.get(key) match { + case e: FallbackConfigEntry[_] => getConfString(e.fallback.key, defaultValue) + case _ => defaultValue + } + } } /** @@ -1440,7 +1447,8 @@ class SQLConf extends Serializable with Logging { */ def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { sqlConfEntries.values.asScala.filter(_.isPublic).map { entry => - (entry.key, getConfString(entry.key, entry.defaultValueString), entry.doc) + val displayValue = Option(getConfString(entry.key, null)).getOrElse(entry.defaultValueString) + (entry.key, displayValue, entry.doc) }.toSeq } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 8b1521bacea49..c9a6975da6be8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -280,4 +280,34 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { spark.sessionState.conf.clear() } + + test("SPARK-22779: correctly compute default value for fallback configs") { + val fallback = SQLConf.buildConf("spark.sql.__test__.spark_22779") + .fallbackConf(SQLConf.PARQUET_COMPRESSION) + + assert(spark.sessionState.conf.getConfString(fallback.key) === + SQLConf.PARQUET_COMPRESSION.defaultValue.get) + assert(spark.sessionState.conf.getConfString(fallback.key, "lzo") === "lzo") + + val displayValue = spark.sessionState.conf.getAllDefinedConfs + .find { case (key, _, _) => key == fallback.key } + .map { case (_, v, _) => v } + .get + assert(displayValue === fallback.defaultValueString) + + spark.sessionState.conf.setConf(SQLConf.PARQUET_COMPRESSION, "gzip") + assert(spark.sessionState.conf.getConfString(fallback.key) === "gzip") + + spark.sessionState.conf.setConf(fallback, "lzo") + assert(spark.sessionState.conf.getConfString(fallback.key) === "lzo") + + val newDisplayValue = spark.sessionState.conf.getAllDefinedConfs + .find { case (key, _, _) => key == fallback.key } + .map { case (_, v, _) => v } + .get + assert(newDisplayValue === "lzo") + + SQLConf.unregister(fallback) + } + } From 7d8e2ca7f8667d809034073aad2ea67b3b082fc2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 Dec 2017 19:33:54 +0800 Subject: [PATCH 427/712] [SPARK-22775][SQL] move dictionary related APIs from ColumnVector to WritableColumnVector ## What changes were proposed in this pull request? These dictionary related APIs are special to `WritableColumnVector` and should not be in `ColumnVector`, which will be public soon. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19970 from cloud-fan/final. 
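As a rough caller-side sketch of the resulting split (illustrative only, not code from this patch; the object and method names below are hypothetical, and it assumes a WritableColumnVector obtained elsewhere, e.g. an OnHeapColumnVector):

import org.apache.spark.sql.execution.vectorized.{ColumnVector, WritableColumnVector}

object DictionaryAccessSketch {
  // Decoders keep the writable type, which now owns setDictionary, hasDictionary,
  // reserveDictionaryIds, getDictionaryIds and getDictId.
  def reserveIds(column: WritableColumnVector, capacity: Int): WritableColumnVector = {
    val ids = column.reserveDictionaryIds(capacity)
    if (column.hasDictionary) {
      // getDictId(rowId) is meaningful once the decoder has written an id for rowId.
      ids.getDictId(0)
    }
    ids
  }

  // Read-only consumers see only ColumnVector, which no longer exposes dictionary state.
  def firstInt(column: ColumnVector): Int = column.getInt(0)
}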
--- .../parquet/VectorizedColumnReader.java | 2 +- .../vectorized/ArrowColumnVector.java | 5 -- .../execution/vectorized/ColumnVector.java | 66 ++++---------- .../vectorized/WritableColumnVector.java | 88 ++++++++++++------- 4 files changed, 73 insertions(+), 88 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index b7646969bcf3d..3ba180860c325 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -239,7 +239,7 @@ private void decodeDictionaryIds( int rowId, int num, WritableColumnVector column, - ColumnVector dictionaryIds) { + WritableColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: if (column.dataType() == DataTypes.IntegerType || diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 1f1347ccd315e..e99201f6372fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -159,11 +159,6 @@ public int[] getInts(int rowId, int count) { return array; } - @Override - public int getDictId(int rowId) { - throw new UnsupportedOperationException(); - } - // // APIs dealing with Longs // diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index e6b87519239dd..fd5caf3bf170b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -22,24 +22,22 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents a column of values and provides the main APIs to access the data - * values. It supports all the types and contains get APIs as well as their batched versions. - * The batched versions are preferable whenever possible. + * This class represents in-memory values of a column and provides the main APIs to access the data. + * It supports all the types and contains get APIs as well as their batched versions. The batched + * versions are considered to be faster and preferable whenever possible. * * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data is stored in the child columns and the parent column - * contains nullability, and in the case of Arrays, the lengths and offsets into the child column. - * Lengths and offsets are encoded identically to INTs. + * columns have child columns. All of the data are stored in the child columns and the parent column + * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child + * column and are encoded identically to INTs. + * * Maps are just a special case of a two field struct. * * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current RowBatch. - * - * A ColumnVector should be considered immutable once originally created. 
- * - * ColumnVectors are intended to be reused. + * in the current batch. */ public abstract class ColumnVector implements AutoCloseable { + /** * Returns the data type of this column. */ @@ -47,7 +45,6 @@ public abstract class ColumnVector implements AutoCloseable { /** * Cleans up memory for this column. The column is not usable after this. - * TODO: this should probably have ref-counted semantics. */ public abstract void close(); @@ -107,13 +104,6 @@ public abstract class ColumnVector implements AutoCloseable { */ public abstract int[] getInts(int rowId, int count); - /** - * Returns the dictionary Id for rowId. - * This should only be called when the ColumnVector is dictionaryIds. - * We have this separate method for dictionaryIds as per SPARK-16928. - */ - public abstract int getDictId(int rowId); - /** * Returns the value for rowId. */ @@ -145,39 +135,39 @@ public abstract class ColumnVector implements AutoCloseable { public abstract double[] getDoubles(int rowId, int count); /** - * Returns the length of the array at rowid. + * Returns the length of the array for rowId. */ public abstract int getArrayLength(int rowId); /** - * Returns the offset of the array at rowid. + * Returns the offset of the array for rowId. */ public abstract int getArrayOffset(int rowId); /** - * Returns a utility object to get structs. + * Returns the struct for rowId. */ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } /** - * Returns a utility object to get structs. - * provided to keep API compatibility with InternalRow for code generation + * A special version of {@link #getStruct(int)}, which is only used as an adapter for Spark + * codegen framework, the second parameter is totally ignored. */ public final ColumnarRow getStruct(int rowId, int size) { return getStruct(rowId); } /** - * Returns the array at rowid. + * Returns the array for rowId. */ public final ColumnarArray getArray(int rowId) { return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } /** - * Returns the value for rowId. + * Returns the map for rowId. */ public MapData getMap(int ordinal) { throw new UnsupportedOperationException(); @@ -214,30 +204,6 @@ public MapData getMap(int ordinal) { */ protected DataType type; - /** - * The Dictionary for this column. - * - * If it's not null, will be used to decode the value in getXXX(). - */ - protected Dictionary dictionary; - - /** - * Reusable column for ids of dictionary. - */ - protected ColumnVector dictionaryIds; - - /** - * Returns true if this column has a dictionary. - */ - public boolean hasDictionary() { return this.dictionary != null; } - - /** - * Returns the underlying integer column for ids of dictionary. - */ - public ColumnVector getDictionaryIds() { - return dictionaryIds; - } - /** * Sets up the common state and also handles creating the child columns if this is a nested * type. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 7c053b579442c..63cf60818a855 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -36,8 +36,10 @@ * elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas), * the lengths are known up front. 
* - * A ColumnVector should be considered immutable once originally created. In other words, it is not - * valid to call put APIs after reads until reset() is called. + * A WritableColumnVector should be considered immutable once originally created. In other words, + * it is not valid to call put APIs after reads until reset() is called. + * + * WritableColumnVector are intended to be reused. */ public abstract class WritableColumnVector extends ColumnVector { @@ -105,6 +107,58 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { @Override public boolean anyNullsSet() { return anyNullsSet; } + /** + * Returns the dictionary Id for rowId. + * + * This should only be called when this `WritableColumnVector` represents dictionaryIds. + * We have this separate method for dictionaryIds as per SPARK-16928. + */ + public abstract int getDictId(int rowId); + + /** + * The Dictionary for this column. + * + * If it's not null, will be used to decode the value in getXXX(). + */ + protected Dictionary dictionary; + + /** + * Reusable column for ids of dictionary. + */ + protected WritableColumnVector dictionaryIds; + + /** + * Returns true if this column has a dictionary. + */ + public boolean hasDictionary() { return this.dictionary != null; } + + /** + * Returns the underlying integer column for ids of dictionary. + */ + public WritableColumnVector getDictionaryIds() { + return dictionaryIds; + } + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserve a integer column for ids of dictionary. + */ + public WritableColumnVector reserveDictionaryIds(int capacity) { + if (dictionaryIds == null) { + dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + /** * Ensures that there is enough storage to store capacity elements. That is, the put() APIs * must work for all rowIds < capacity. @@ -613,36 +667,6 @@ public final int appendStruct(boolean isNull) { */ protected WritableColumnVector[] childColumns; - /** - * Update the dictionary. - */ - public void setDictionary(Dictionary dictionary) { - this.dictionary = dictionary; - } - - /** - * Reserve a integer column for ids of dictionary. - */ - public WritableColumnVector reserveDictionaryIds(int capacity) { - WritableColumnVector dictionaryIds = (WritableColumnVector) this.dictionaryIds; - if (dictionaryIds == null) { - dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType); - this.dictionaryIds = dictionaryIds; - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(capacity); - } - return dictionaryIds; - } - - /** - * Returns the underlying integer column for ids of dictionary. - */ - @Override - public WritableColumnVector getDictionaryIds() { - return (WritableColumnVector) dictionaryIds; - } - /** * Reserve a new column. */ From d0957954398658b025fdac3b9a07b477e79520b9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 15 Dec 2017 00:29:44 +0800 Subject: [PATCH 428/712] [SPARK-22785][SQL] remove ColumnVector.anyNullsSet ## What changes were proposed in this pull request? `ColumnVector.anyNullsSet` is not called anywhere except tests, and we can easily replace it with `ColumnVector.numNulls > 0` ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19980 from cloud-fan/minor. 
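For callers of the removed method the migration is mechanical; a minimal sketch (hypothetical helper name, not part of this patch):

import org.apache.spark.sql.execution.vectorized.ColumnVector

object NullCheckSketch {
  // Replaces the old vector.anyNullsSet() check: a column has nulls iff numNulls() > 0,
  // and individual slots are still probed with isNullAt(rowId).
  def hasAnyNulls(vector: ColumnVector): Boolean = vector.numNulls() > 0
}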
--- .../execution/vectorized/ArrowColumnVector.java | 5 ----- .../sql/execution/vectorized/ColumnVector.java | 6 ------ .../execution/vectorized/OffHeapColumnVector.java | 4 +--- .../execution/vectorized/OnHeapColumnVector.java | 4 +--- .../execution/vectorized/WritableColumnVector.java | 14 ++------------ .../vectorized/ArrowColumnVectorSuite.scala | 11 ----------- .../execution/vectorized/ColumnarBatchSuite.scala | 9 --------- 7 files changed, 4 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index e99201f6372fe..d53e1fcab0c5a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -54,11 +54,6 @@ public int numNulls() { return accessor.getNullCount(); } - @Override - public boolean anyNullsSet() { - return numNulls() > 0; - } - @Override public void close() { if (childColumns != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index fd5caf3bf170b..dc7c1269bedd9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -53,12 +53,6 @@ public abstract class ColumnVector implements AutoCloseable { */ public abstract int numNulls(); - /** - * Returns true if any of the nulls indicator are set for this column. This can be used - * as an optimization to prevent setting nulls. - */ - public abstract boolean anyNullsSet(); - /** * Returns whether the value at rowId is NULL. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 5f1b9885334b7..1c45b846790b6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -110,7 +110,6 @@ public void putNotNull(int rowId) { public void putNull(int rowId) { Platform.putByte(null, nulls + rowId, (byte) 1); ++numNulls; - anyNullsSet = true; } @Override @@ -119,13 +118,12 @@ public void putNulls(int rowId, int count) { for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 1); } - anyNullsSet = true; numNulls += count; } @Override public void putNotNulls(int rowId, int count) { - if (!anyNullsSet) return; + if (numNulls == 0) return; long offset = nulls + rowId; for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 0); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index f12772ede575d..1d538fe4181b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -107,7 +107,6 @@ public void putNotNull(int rowId) { public void putNull(int rowId) { nulls[rowId] = (byte)1; ++numNulls; - anyNullsSet = true; } @Override @@ -115,13 +114,12 @@ public void putNulls(int rowId, int count) { for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)1; } - anyNullsSet = true; numNulls += count; } @Override public void putNotNulls(int rowId, int count) { - if (!anyNullsSet) return; + if (numNulls == 0) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 63cf60818a855..5f6f125976e12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -54,12 +54,11 @@ public void reset() { ((WritableColumnVector) c).reset(); } } - numNulls = 0; elementsAppended = 0; - if (anyNullsSet) { + if (numNulls > 0) { putNotNulls(0, capacity); - anyNullsSet = false; } + numNulls = 0; } @Override @@ -104,9 +103,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { @Override public int numNulls() { return numNulls; } - @Override - public boolean anyNullsSet() { return anyNullsSet; } - /** * Returns the dictionary Id for rowId. * @@ -640,12 +636,6 @@ public final int appendStruct(boolean isNull) { */ protected int numNulls; - /** - * True if there is at least one NULL byte set. This is an optimization for the writer, to skip - * having to clear NULL bits. - */ - protected boolean anyNullsSet; - /** * True if this column's values are fixed. This means the column values never change, even * across resets. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 068a17bf772e1..e460d0721e7bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -42,7 +42,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -71,7 +70,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -100,7 +98,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -129,7 +126,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -158,7 +154,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -187,7 +182,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -216,7 +210,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -246,7 +239,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === StringType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -274,7 +266,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -319,7 +310,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) val array0 = columnVector.getArray(0) @@ -383,7 +373,6 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) - assert(columnVector.anyNullsSet) assert(columnVector.numNulls === 1) val row0 = columnVector.getStruct(0, 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index c9c6bee513b53..d3ed8276b8f10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -65,27 +65,22 @@ class ColumnarBatchSuite extends SparkFunSuite { column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 - assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNull() reference += false - assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) - assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNull() reference += true - assert(column.anyNullsSet()) assert(column.numNulls() == 1) column.appendNulls(3) (1 to 3).foreach(_ => reference += true) - assert(column.anyNullsSet()) assert(column.numNulls() == 4) idx = column.elementsAppended @@ -93,13 +88,11 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putNotNull(idx) reference += false idx += 1 - assert(column.anyNullsSet()) assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 - assert(column.anyNullsSet()) assert(column.numNulls() == 5) column.putNulls(idx, 3) @@ -107,7 +100,6 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += true reference += true idx += 3 - assert(column.anyNullsSet()) assert(column.numNulls() == 8) column.putNotNulls(idx, 4) @@ -116,7 +108,6 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 - assert(column.anyNullsSet()) assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => From 606ae491e41017c117301aeccfdf7221adec7f23 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 15 Dec 2017 02:14:08 +0800 Subject: [PATCH 429/712] [SPARK-22774][SQL][TEST] Add compilation check into TPCDSQuerySuite ## What changes were proposed in this pull request? This PR adds check whether Java code generated by Catalyst can be compiled by `janino` correctly or not into `TPCDSQuerySuite`. Before this PR, this suite only checks whether analysis can be performed correctly or not. This check will be able to avoid unexpected performance degrade by interpreter execution due to a Java compilation error. ## How was this patch tested? Existing a test case, but updated it. Author: Kazuaki Ishizaki Closes #19971 from kiszk/SPARK-22774. 
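The same pattern can be lifted out of TPCDSQuerySuite for ad-hoc checks on any query; a sketch only, assuming it is run from Spark's own test sources where these internal codegen APIs are accessible:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.execution.WholeStageCodegenExec

object CompileCheckSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("codegen-compile-check").getOrCreate()
    val plan = spark.range(100).selectExpr("id % 7 AS k").groupBy("k").count()
      .queryExecution.executedPlan
    // Compile every whole-stage-codegen subtree so janino errors surface immediately
    // instead of silently degrading to interpreted execution.
    plan.foreach {
      case s: WholeStageCodegenExec => CodeGenerator.compile(s.doCodeGen()._2)
      case _ =>
    }
    spark.stop()
  }
}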
--- .../apache/spark/sql/TPCDSQuerySuite.scala | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index a58000da1543d..dd427a5c52edf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -348,13 +350,41 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + private def checkGeneratedCode(plan: SparkPlan): Unit = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() + plan foreach { + case s: WholeStageCodegenExec => + codegenSubtrees += s + case s => s + } + codegenSubtrees.toSeq.foreach { subtree => + val code = subtree.doCodeGen()._2 + try { + // Just check the generated code can be properly compiled + CodeGenerator.compile(code) + } catch { + case e: Exception => + val msg = + s""" + |failed to compile: + |Subtree: + |$subtree + |Generated code: + |${CodeFormatter.format(code)} + """.stripMargin + throw new Exception(msg, e) + } + } + } + tpcdsQueries.foreach { name => val queryString = resourceToString(s"tpcds/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) test(name) { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - // Just check the plans can be properly generated - sql(queryString).queryExecution.executedPlan + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) } } } @@ -368,8 +398,9 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) test(s"modified-$name") { - // Just check the plans can be properly generated - sql(queryString).queryExecution.executedPlan + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) } } } From 40de176c93c5aa05bcbb1328721118b6b46ba51d Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 14 Dec 2017 11:19:34 -0800 Subject: [PATCH 430/712] [SPARK-16496][SQL] Add wholetext as option for reading text in SQL. ## What changes were proposed in this pull request? In multiple text analysis problems, it is not often desirable for the rows to be split by "\n". There exists a wholeText reader for RDD API, and this JIRA just adds the same support for Dataset API. ## How was this patch tested? Added relevant new tests for both scala and Java APIs Author: Prashant Sharma Author: Prashant Sharma Closes #14151 from ScrapCodes/SPARK-16496/wholetext. 
--- python/pyspark/sql/readwriter.py | 7 +- .../apache/spark/sql/DataFrameReader.scala | 16 ++- .../HadoopFileWholeTextReader.scala | 57 +++++++++ .../datasources/text/TextFileFormat.scala | 31 ++++- .../datasources/text/TextOptions.scala | 7 ++ .../datasources/text/TextSuite.scala | 5 +- .../datasources/text/WholeTextFileSuite.scala | 108 ++++++++++++++++++ 7 files changed, 221 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 1ad974e9aa4c7..4e58bfb843644 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -304,7 +304,7 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, paths): + def text(self, paths, wholetext=False): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there @@ -313,11 +313,16 @@ def text(self, paths): Each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). + :param wholetext: if true, read each file from input path(s) as a single row. >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] + >>> df = spark.read.text('python/test_support/sql/text-test.txt', wholetext=True) + >>> df.collect() + [Row(value=u'hello\\nthis')] """ + self._set_opts(wholetext=wholetext) if isinstance(paths, basestring): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ea1cf66775235..39fec8f983b65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -646,7 +646,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * Each line in the text files is a new row in the resulting DataFrame. For example: + * You can set the following text-specific option(s) for reading text files: + *
<ul>
+ * <li>`wholetext` ( default `false`): If true, read a file as a single row and not split by "\n".
+ * </li>
+ * </ul>
    + * By default, each line in the text files is a new row in the resulting DataFrame. + * + * Usage example: * {{{ * // Scala: * spark.read.text("/path/to/spark/README.md") @@ -678,7 +685,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * Each line in the text files is a new element in the resulting Dataset. For example: + * You can set the following textFile-specific option(s) for reading text files: + *
<ul>
+ * <li>`wholetext` ( default `false`): If true, read a file as a single row and not split by "\n".
+ * </li>
+ * </ul>
    + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.read.textFile("/path/to/spark/README.md") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala new file mode 100644 index 0000000000000..c61a89e6e8c3f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.Closeable +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.input.WholeTextFileRecordReader + +/** + * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which is all of the lines + * in that file. 
+ */ +class HadoopFileWholeTextReader(file: PartitionedFile, conf: Configuration) + extends Iterator[Text] with Closeable { + private val iterator = { + val fileSplit = new CombineFileSplit( + Array(new Path(new URI(file.filePath))), + Array(file.start), + Array(file.length), + // TODO: Implement Locality + Array.empty[String]) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val reader = new WholeTextFileRecordReader(fileSplit, hadoopAttemptContext, 0) + reader.initialize(fileSplit, hadoopAttemptContext) + new RecordReaderIterator(reader) + } + + override def hasNext: Boolean = iterator.hasNext + + override def next(): Text = iterator.next() + + override def close(): Unit = iterator.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index d0690445d7672..8a6ab303fc0f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.execution.datasources.text +import java.io.Closeable + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} @@ -53,6 +57,14 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val textOptions = new TextOptions(options) + super.isSplitable(sparkSession, options, path) && !textOptions.wholeText + } + override def inferSchema( sparkSession: SparkSession, options: Map[String, String], @@ -97,14 +109,25 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { assert( requiredSchema.length <= 1, "Text data source only produces a single data column named \"value\".") - + val textOptions = new TextOptions(options) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText) + } + + private def readToUnsafeMem(conf: Broadcast[SerializableConfiguration], + requiredSchema: StructType, wholeTextMode: Boolean): + (PartitionedFile) => Iterator[UnsafeRow] = { + (file: PartitionedFile) => { - val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + val confValue = conf.value.value + val reader = if (!wholeTextMode) { + new HadoopFileLinesReader(file, confValue) + } else { + new HadoopFileWholeTextReader(file, confValue) + } Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) - if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 49bd7382f9cf3..2a661561ab51e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -33,8 +33,15 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti * Compression codec to use. */ val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName) + + /** + * wholetext - If true, read a file as a single row and not split by "\n". + */ + val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean + } private[text] object TextOptions { val COMPRESSION = "compression" + val WHOLETEXT = "wholetext" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index cb7393cdd2b9d..33287044f279e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -185,10 +185,9 @@ class TextSuite extends QueryTest with SharedSQLContext { val data = df.collect() assert(data(0) == Row("This is a test file for the text data source")) assert(data(1) == Row("1+1")) - // non ascii characters are not allowed in the code, so we disable the scalastyle here. - // scalastyle:off + // scalastyle:off nonascii assert(data(2) == Row("数据砖头")) - // scalastyle:on + // scalastyle:on nonascii assert(data(3) == Row("\"doh\"")) assert(data.length == 4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala new file mode 100644 index 0000000000000..8bd736bee69de --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.text + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructType} + +class WholeTextFileSuite extends QueryTest with SharedSQLContext { + + // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which + // can cause Filesystem.get(Configuration) to return a cached instance created with a different + // configuration than the one passed to get() (see HADOOP-8490 for more details). This caused + // hard-to-reproduce test failures, since any suites that were run after this one would inherit + // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, + // we disable FileSystem caching in this suite. + protected override def sparkConf = + super.sparkConf.set("spark.hadoop.fs.file.impl.disable.cache", "true") + + private def testFile: String = { + Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString + } + + test("reading text file with option wholetext=true") { + val df = spark.read.option("wholetext", "true") + .format("text").load(testFile) + // schema + assert(df.schema == new StructType().add("value", StringType)) + + // verify content + val data = df.collect() + assert(data(0) == + Row( + // scalastyle:off nonascii + """This is a test file for the text data source + |1+1 + |数据砖头 + |"doh" + |""".stripMargin)) + // scalastyle:on nonascii + assert(data.length == 1) + } + + test("correctness of wholetext option") { + import org.apache.spark.sql.catalyst.util._ + withTempDir { dir => + val file1 = new File(dir, "text1.txt") + stringToFile(file1, + """text file 1 contents. + |From: None to: ?? + """.stripMargin) + val file2 = new File(dir, "text2.txt") + stringToFile(file2, "text file 2 contents.") + val file3 = new File(dir, "text3.txt") + stringToFile(file3, "text file 3 contents.") + val df = spark.read.option("wholetext", "true").text(dir.getAbsolutePath) + // Since wholetext option reads each file into a single row, df.length should be no. of files. + val data = df.sort("value").collect() + assert(data.length == 3) + // Each files should represent a single Row/element in Dataframe/Dataset + assert(data(0) == Row( + """text file 1 contents. + |From: None to: ?? + """.stripMargin)) + assert(data(1) == Row( + """text file 2 contents.""".stripMargin)) + assert(data(2) == Row( + """text file 3 contents.""".stripMargin)) + } + } + + + test("Correctness of wholetext option with gzip compression mode.") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df1 = spark.range(0, 1000).selectExpr("CAST(id AS STRING) AS s").repartition(1) + df1.write.option("compression", "gzip").mode("overwrite").text(path) + // On reading through wholetext mode, one file will be read as a single row, i.e. not + // delimited by "next line" character. 
+ val expected = Row(Range(0, 1000).mkString("", "\n", "\n")) + Seq(10, 100, 1000).foreach { bytes => + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) { + val df2 = spark.read.option("wholetext", "true").format("text").load(path) + val result = df2.collect().head + assert(result === expected) + } + } + } + } +} From 6d99940397136e4ed0764f83f442bcffcb20d6e7 Mon Sep 17 00:00:00 2001 From: kellyzly Date: Thu, 14 Dec 2017 13:39:19 -0600 Subject: [PATCH 431/712] [SPARK-22660][BUILD] Use position() and limit() to fix ambiguity issue in scala-2.12 ## What changes were proposed in this pull request? Missing some changes about limit in TaskSetManager.scala ## How was this patch tested? running tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: kellyzly Closes #19976 from kellyzly/SPARK-22660.2. --- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index de4711f461df2..c3ed11bfe352a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -488,11 +488,11 @@ private[spark] class TaskSetManager( abort(s"$msg Exception during serialization: $e") throw new TaskNotSerializableException(e) } - if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && + if (serializedTask.limit() > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && !emittedTaskSizeWarning) { emittedTaskSizeWarning = true logWarning(s"Stage ${task.stageId} contains a task of very large size " + - s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + + s"(${serializedTask.limit() / 1024} KB). The maximum recommended task size is " + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") } addRunningTask(taskId) @@ -502,7 +502,7 @@ private[spark] class TaskSetManager( // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " + - s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)") + s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit()} bytes)") sched.dagScheduler.taskStarted(task, info) new TaskDescription( From 2fe16333d59cd8558afca3916821e1ea7e98d1bc Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 14 Dec 2017 14:03:08 -0800 Subject: [PATCH 432/712] [SPARK-22778][KUBERNETES] Added the missing service metadata for KubernetesClusterManager ## What changes were proposed in this pull request? This PR added the missing service metadata for `KubernetesClusterManager`. Without the metadata, the service loader couldn't load `KubernetesClusterManager`, and caused the driver to fail to create a `ExternalClusterManager`, as being reported in SPARK-22778. The PR also changed the `k8s:` prefix used to `k8s://`, which is what existing Spark on k8s users are familiar and used to. ## How was this patch tested? Manual testing verified that the fix resolved the issue in SPARK-22778. /cc vanzin felixcheung jiangxb1987 Author: Yinan Li Closes #19972 from liyinan926/fix-22778. 
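Concretely, the resolved master URL now keeps the `k8s://` prefix that `KubernetesClusterManager.canCreate` matches on. The values the updated `UtilsSuite` expects illustrate the behaviour; a small sketch (the helper is Spark-internal, shown here only for illustration):

```scala
import org.apache.spark.util.Utils

// Explicit scheme is preserved behind the k8s:// prefix.
Utils.checkAndGetK8sMasterUrl("k8s://https://host:port")  // "k8s://https://host:port"
Utils.checkAndGetK8sMasterUrl("k8s://http://host:port")   // "k8s://http://host:port"

// No scheme given: https is assumed.
Utils.checkAndGetK8sMasterUrl("k8s://127.0.0.1:8443")     // "k8s://https://127.0.0.1:8443"

// Missing "//" after "k8s:" is rejected.
Utils.checkAndGetK8sMasterUrl("k8s:https://host:port")    // throws IllegalArgumentException
```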
--- core/src/main/scala/org/apache/spark/util/Utils.scala | 6 +++--- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 +- .../src/test/scala/org/apache/spark/util/UtilsSuite.scala | 8 ++++---- .../org.apache.spark.scheduler.ExternalClusterManager | 1 + .../deploy/k8s/submit/KubernetesClientApplication.scala | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fe5b4ea24440b..8871870eb8681 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2757,7 +2757,7 @@ private[spark] object Utils extends Logging { /** * Check the validity of the given Kubernetes master URL and return the resolved URL. Prefix - * "k8s:" is appended to the resolved URL as the prefix is used by KubernetesClusterManager + * "k8s://" is appended to the resolved URL as the prefix is used by KubernetesClusterManager * in canCreate to determine if the KubernetesClusterManager should be used. */ def checkAndGetK8sMasterUrl(rawMasterURL: String): String = { @@ -2770,7 +2770,7 @@ private[spark] object Utils extends Logging { val resolvedURL = s"https://$masterWithoutK8sPrefix" logInfo("No scheme specified for kubernetes master URL, so defaulting to https. Resolved " + s"URL is $resolvedURL.") - return s"k8s:$resolvedURL" + return s"k8s://$resolvedURL" } val masterScheme = new URI(masterWithoutK8sPrefix).getScheme @@ -2789,7 +2789,7 @@ private[spark] object Utils extends Logging { throw new IllegalArgumentException("Invalid Kubernetes master scheme: " + masterScheme) } - return s"k8s:$resolvedURL" + s"k8s://$resolvedURL" } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 35594ec47c941..2eb8a1fee104c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -408,7 +408,7 @@ class SparkSubmitSuite childArgsMap.get("--arg") should be (Some("arg1")) mainClass should be (KUBERNETES_CLUSTER_SUBMIT_CLASS) classpath should have length (0) - conf.get("spark.master") should be ("k8s:https://host:port") + conf.get("spark.master") should be ("k8s://https://host:port") conf.get("spark.executor.memory") should be ("5g") conf.get("spark.driver.memory") should be ("4g") conf.get("spark.kubernetes.namespace") should be ("spark") diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 5c4e4ca0cded6..eaea6b030c154 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1148,16 +1148,16 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("check Kubernetes master URL") { val k8sMasterURLHttps = Utils.checkAndGetK8sMasterUrl("k8s://https://host:port") - assert(k8sMasterURLHttps === "k8s:https://host:port") + assert(k8sMasterURLHttps === "k8s://https://host:port") val k8sMasterURLHttp = Utils.checkAndGetK8sMasterUrl("k8s://http://host:port") - assert(k8sMasterURLHttp === "k8s:http://host:port") + assert(k8sMasterURLHttp === "k8s://http://host:port") val k8sMasterURLWithoutScheme = 
Utils.checkAndGetK8sMasterUrl("k8s://127.0.0.1:8443") - assert(k8sMasterURLWithoutScheme === "k8s:https://127.0.0.1:8443") + assert(k8sMasterURLWithoutScheme === "k8s://https://127.0.0.1:8443") val k8sMasterURLWithoutScheme2 = Utils.checkAndGetK8sMasterUrl("k8s://127.0.0.1") - assert(k8sMasterURLWithoutScheme2 === "k8s:https://127.0.0.1") + assert(k8sMasterURLWithoutScheme2 === "k8s://https://127.0.0.1") intercept[IllegalArgumentException] { Utils.checkAndGetK8sMasterUrl("k8s:https://host:port") diff --git a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager new file mode 100644 index 0000000000000..81d14766ffb8d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -0,0 +1 @@ +org.apache.spark.scheduler.cluster.k8s.KubernetesClusterManager diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 4d17608c602d8..240a1144577b0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -203,8 +203,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") // The master URL has been checked for validity already in SparkSubmit. - // We just need to get rid of the "k8s:" prefix here. - val master = sparkConf.get("spark.master").substring("k8s:".length) + // We just need to get rid of the "k8s://" prefix here. + val master = sparkConf.get("spark.master").substring("k8s://".length) val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None val loggingPodStatusWatcher = new LoggingPodStatusWatcherImpl( From 59daf91b7cfb50b1c20eb41959921fc03103b739 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 14 Dec 2017 14:31:21 -0800 Subject: [PATCH 433/712] [SPARK-22733] Split StreamExecution into MicroBatchExecution and StreamExecution. ## What changes were proposed in this pull request? StreamExecution is now an abstract base class, which MicroBatchExecution (the current StreamExecution) inherits. When continuous processing is implemented, we'll have a new ContinuousExecution implementation of StreamExecution. A few fields are also renamed to make them less microbatch-specific. ## How was this patch tested? refactoring only Author: Jose Torres Closes #19926 from joseph-torres/continuous-refactor. 
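The shape of the refactor, heavily condensed (constructor parameters and most members omitted; see the diff below for the real signatures):

```scala
import org.apache.spark.sql.SparkSession

// Common machinery (checkpointing, offset/commit logs, lifecycle, error handling)
// stays in the abstract base class.
abstract class StreamExecution {
  // Subclasses supply the actual processing loop.
  protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit
}

// The existing micro-batch trigger loop (formerly runBatches()) moves here;
// a future ContinuousExecution can plug in its own loop the same way.
class MicroBatchExecution extends StreamExecution {
  override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
    // populateStartOffsets / constructNextBatch / runBatch per trigger (see the diff below)
  }
}
```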
--- .../{BatchCommitLog.scala => CommitLog.scala} | 8 +- .../streaming/MicroBatchExecution.scala | 407 ++++++++++++++++ .../execution/streaming/StreamExecution.scala | 457 ++---------------- .../sql/streaming/StreamingQueryManager.scala | 2 +- .../streaming/EventTimeWatermarkSuite.scala | 4 +- .../sql/streaming/FileStreamSourceSuite.scala | 5 +- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 20 +- .../streaming/StreamingAggregationSuite.scala | 4 +- .../sql/streaming/StreamingQuerySuite.scala | 4 +- 10 files changed, 484 insertions(+), 429 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{BatchCommitLog.scala => CommitLog.scala} (93%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala index 5e24e8fc4e3cc..5b114242558dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala @@ -42,10 +42,10 @@ import org.apache.spark.sql.SparkSession * line 1: version * line 2: metadata (optional json string) */ -class BatchCommitLog(sparkSession: SparkSession, path: String) +class CommitLog(sparkSession: SparkSession, path: String) extends HDFSMetadataLog[String](sparkSession, path) { - import BatchCommitLog._ + import CommitLog._ def add(batchId: Long): Unit = { super.add(batchId, EMPTY_JSON) @@ -53,7 +53,7 @@ class BatchCommitLog(sparkSession: SparkSession, path: String) override def add(batchId: Long, metadata: String): Boolean = { throw new UnsupportedOperationException( - "BatchCommitLog does not take any metadata, use 'add(batchId)' instead") + "CommitLog does not take any metadata, use 'add(batchId)' instead") } override protected def deserialize(in: InputStream): String = { @@ -76,7 +76,7 @@ class BatchCommitLog(sparkSession: SparkSession, path: String) } } -object BatchCommitLog { +object CommitLog { private val VERSION = 1 private val EMPTY_JSON = "{}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala new file mode 100644 index 0000000000000..a67dda99dc01b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} + +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.reader.Offset +import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} +import org.apache.spark.util.{Clock, Utils} + +class MicroBatchExecution( + sparkSession: SparkSession, + name: String, + checkpointRoot: String, + analyzedPlan: LogicalPlan, + sink: Sink, + trigger: Trigger, + triggerClock: Clock, + outputMode: OutputMode, + deleteCheckpointOnStop: Boolean) + extends StreamExecution( + sparkSession, name, checkpointRoot, analyzedPlan, sink, + trigger, triggerClock, outputMode, deleteCheckpointOnStop) { + + private val triggerExecutor = trigger match { + case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) + case OneTimeTrigger => OneTimeExecutor() + case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") + } + + override lazy val logicalPlan: LogicalPlan = { + assert(queryExecutionThread eq Thread.currentThread, + "logicalPlan must be initialized in QueryExecutionThread " + + s"but the current thread was ${Thread.currentThread}") + var nextSourceId = 0L + val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() + val _logicalPlan = analyzedPlan.transform { + case streamingRelation@StreamingRelation(dataSource, _, output) => + toExecutionRelationMap.getOrElseUpdate(streamingRelation, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. + StreamingExecutionRelation(source, output)(sparkSession) + }) + } + sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } + uniqueSources = sources.distinct + _logicalPlan + } + + /** + * Repeatedly attempts to run batches as data arrives. + */ + protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { + triggerExecutor.execute(() => { + startTrigger() + + if (isActive) { + reportTimeTaken("triggerExecution") { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets(sparkSessionForStream) + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() + } + if (dataAvailable) { + currentStatus = currentStatus.copy(isDataAvailable = true) + updateStatusMessage("Processing new data") + runBatch(sparkSessionForStream) + } + } + // Report trigger as finished and construct progress object. + finishTrigger(dataAvailable) + if (dataAvailable) { + // Update committed offsets. 
+ commitLog.add(currentBatchId) + committedOffsets ++= availableOffsets + logDebug(s"batch ${currentBatchId} committed") + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + } else { + currentStatus = currentStatus.copy(isDataAvailable = false) + updateStatusMessage("Waiting for data to arrive") + Thread.sleep(pollingDelayMs) + } + } + updateStatusMessage("Waiting for next trigger") + isActive + }) + } + + /** + * Populate the start offsets to start the execution at the current offsets stored in the sink + * (i.e. avoid reprocessing data that we have already processed). This function must be called + * before any processing occurs and will populate the following fields: + * - currentBatchId + * - committedOffsets + * - availableOffsets + * The basic structure of this method is as follows: + * + * Identify (from the offset log) the offsets used to run the last batch + * IF last batch exists THEN + * Set the next batch to be executed as the last recovered batch + * Check the commit log to see which batch was committed last + * IF the last batch was committed THEN + * Call getBatch using the last batch start and end offsets + * // ^^^^ above line is needed since some sources assume last batch always re-executes + * Setup for a new batch i.e., start = last batch end, and identify new end + * DONE + * ELSE + * Identify a brand new batch + * DONE + */ + private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { + offsetLog.getLatest() match { + case Some((latestBatchId, nextOffsets)) => + /* First assume that we are re-executing the latest known batch + * in the offset log */ + currentBatchId = latestBatchId + availableOffsets = nextOffsets.toStreamProgress(sources) + /* Initialize committed offsets to a committed batch, which at this + * is the second latest batch id in the offset log. */ + if (latestBatchId != 0) { + val secondLatestBatchId = offsetLog.get(latestBatchId - 1).getOrElse { + throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist") + } + committedOffsets = secondLatestBatchId.toStreamProgress(sources) + } + + // update offset metadata + nextOffsets.metadata.foreach { metadata => + OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) + offsetSeqMetadata = OffsetSeqMetadata( + metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + } + + /* identify the current batch id: if commit log indicates we successfully processed the + * latest batch id in the offset log, then we can safely move to the next batch + * i.e., committedBatchId + 1 */ + commitLog.getLatest() match { + case Some((latestCommittedBatchId, _)) => + if (latestBatchId == latestCommittedBatchId) { + /* The last batch was successfully committed, so we can safely process a + * new next batch but first: + * Make a call to getBatch using the offsets from previous batch. + * because certain sources (e.g., KafkaSource) assume on restart the last + * batch will be executed before getOffset is called again. 
*/ + availableOffsets.foreach { ao: (Source, Offset) => + val (source, end) = ao + if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { + val start = committedOffsets.get(source) + source.getBatch(start, end) + } + } + currentBatchId = latestCommittedBatchId + 1 + committedOffsets ++= availableOffsets + // Construct a new batch be recomputing availableOffsets + constructNextBatch() + } else if (latestCommittedBatchId < latestBatchId - 1) { + logWarning(s"Batch completion log latest batch id is " + + s"${latestCommittedBatchId}, which is not trailing " + + s"batchid $latestBatchId by one") + } + case None => logInfo("no commit log present") + } + logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + s"$committedOffsets and available offsets $availableOffsets") + case None => // We are starting this stream for the first time. + logInfo(s"Starting new streaming query.") + currentBatchId = 0 + constructNextBatch() + } + } + + /** + * Returns true if there is any new data available to be processed. + */ + private def dataAvailable: Boolean = { + availableOffsets.exists { + case (source, available) => + committedOffsets + .get(source) + .map(committed => committed != available) + .getOrElse(true) + } + } + + /** + * Queries all of the sources to see if any new data is available. When there is new data the + * batchId counter is incremented and a new log entry is written with the newest offsets. + */ + private def constructNextBatch(): Unit = { + // Check to see what new data is available. + val hasNewData = { + awaitProgressLock.lock() + try { + val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) + } + }.toMap + availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + + if (dataAvailable) { + true + } else { + noNewData = true + false + } + } finally { + awaitProgressLock.unlock() + } + } + if (hasNewData) { + var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs + // Update the eventTime watermarks if we find any in the plan. + if (lastExecution != null) { + lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec => e + }.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = watermarkMsMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + watermarkMsMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!watermarkMsMap.isDefinedAt(index)) { + watermarkMsMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. 
+ if(!watermarkMsMap.isEmpty) { + val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 + if (newWatermarkMs > batchWatermarkMs) { + logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") + batchWatermarkMs = newWatermarkMs + } else { + logDebug( + s"Event time didn't move: $newWatermarkMs < " + + s"$batchWatermarkMs") + } + } + } + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = batchWatermarkMs, + batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds + + updateStatusMessage("Writing offsets to log") + reportTimeTaken("walCommit") { + assert(offsetLog.add( + currentBatchId, + availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), + s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") + logInfo(s"Committed offsets for batch $currentBatchId. " + + s"Metadata ${offsetSeqMetadata.toString}") + + // NOTE: The following code is correct because runStream() processes exactly one + // batch at a time. If we add pipeline parallelism (multiple batches in flight at + // the same time), this cleanup logic will need to change. + + // Now that we've updated the scheduler's persistent checkpoint, it is safe for the + // sources to discard data from the previous batch. + if (currentBatchId != 0) { + val prevBatchOff = offsetLog.get(currentBatchId - 1) + if (prevBatchOff.isDefined) { + prevBatchOff.get.toStreamProgress(sources).foreach { + case (src, off) => src.commit(off) + } + } else { + throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + } + } + + // It is now safe to discard the metadata beyond the minimum number to retain. + // Note that purge is exclusive, i.e. it purges everything before the target ID. + if (minLogEntriesToMaintain < currentBatchId) { + offsetLog.purge(currentBatchId - minLogEntriesToMaintain) + commitLog.purge(currentBatchId - minLogEntriesToMaintain) + } + } + } else { + awaitProgressLock.lock() + try { + // Wake up any threads that are waiting for the stream to progress. + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() + } + } + } + + /** + * Processes any data available between `availableOffsets` and `committedOffsets`. + * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. + */ + private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + // Request unprocessed data from all sources. + newData = reportTimeTaken("getBatch") { + availableOffsets.flatMap { + case (source, available) + if committedOffsets.get(source).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(source) + val batch = source.getBatch(current, available) + assert(batch.isStreaming, + s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + + s"${batch.queryExecution.logical}") + logDebug(s"Retrieving data from $source: $current -> $available") + Some(source -> batch) + case _ => None + } + } + + // A list of attributes that will need to be updated. + val replacements = new ArrayBuffer[(Attribute, Attribute)] + // Replace sources in the logical plan with data that has arrived since the last batch. 
+ val withNewSources = logicalPlan transform { + case StreamingExecutionRelation(source, output) => + newData.get(source).map { data => + val newPlan = data.logicalPlan + assert(output.size == newPlan.output.size, + s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + + s"${Utils.truncatedString(newPlan.output, ",")}") + replacements ++= output.zip(newPlan.output) + newPlan + }.getOrElse { + LocalRelation(output, isStreaming = true) + } + } + + // Rewire the plan to use the new attributes that were returned by the source. + val replacementMap = AttributeMap(replacements) + val triggerLogicalPlan = withNewSources transformAllExpressions { + case a: Attribute if replacementMap.contains(a) => + replacementMap(a).withMetadata(a.metadata) + case ct: CurrentTimestamp => + CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, + ct.dataType) + case cd: CurrentDate => + CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, + cd.dataType, cd.timeZoneId) + } + + reportTimeTaken("queryPlanning") { + lastExecution = new IncrementalExecution( + sparkSessionToRunBatch, + triggerLogicalPlan, + outputMode, + checkpointFile("state"), + runId, + currentBatchId, + offsetSeqMetadata) + lastExecution.executedPlan // Force the lazy generation of execution plan + } + + val nextBatch = + new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + + reportTimeTaken("addBatch") { + SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { + sink.addBatch(currentBatchId, nextBatch) + } + } + + awaitProgressLock.lock() + try { + // Wake up any threads that are waiting for the stream to progress. + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 16063c02ce06b..7946889e85e37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -22,10 +22,9 @@ import java.nio.channels.ClosedByInterruptException import java.util.UUID import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit} import java.util.concurrent.atomic.AtomicReference -import java.util.concurrent.locks.ReentrantLock +import java.util.concurrent.locks.{Condition, ReentrantLock} import scala.collection.mutable.{Map => MutableMap} -import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.util.concurrent.UncheckedExecutionException @@ -33,10 +32,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.Offset @@ -58,7 +55,7 @@ case object TERMINATED extends State * @param deleteCheckpointOnStop whether to 
delete the checkpoint if the query is stopped without * errors */ -class StreamExecution( +abstract class StreamExecution( override val sparkSession: SparkSession, override val name: String, private val checkpointRoot: String, @@ -72,16 +69,16 @@ class StreamExecution( import org.apache.spark.sql.streaming.StreamingQueryListener._ - private val pollingDelayMs = sparkSession.sessionState.conf.streamingPollingDelay + protected val pollingDelayMs: Long = sparkSession.sessionState.conf.streamingPollingDelay - private val minBatchesToRetain = sparkSession.sessionState.conf.minBatchesToRetain - require(minBatchesToRetain > 0, "minBatchesToRetain has to be positive") + protected val minLogEntriesToMaintain: Int = sparkSession.sessionState.conf.minBatchesToRetain + require(minLogEntriesToMaintain > 0, "minBatchesToRetain has to be positive") /** * A lock used to wait/notify when batches complete. Use a fair lock to avoid thread starvation. */ - private val awaitBatchLock = new ReentrantLock(true) - private val awaitBatchLockCondition = awaitBatchLock.newCondition() + protected val awaitProgressLock = new ReentrantLock(true) + protected val awaitProgressLockCondition = awaitProgressLock.newCondition() private val initializationLatch = new CountDownLatch(1) private val startLatch = new CountDownLatch(1) @@ -90,9 +87,11 @@ class StreamExecution( val resolvedCheckpointRoot = { val checkpointPath = new Path(checkpointRoot) val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) - checkpointPath.makeQualified(fs.getUri(), fs.getWorkingDirectory()).toUri.toString + checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString } + def logicalPlan: LogicalPlan + /** * Tracks how much data we have processed and committed to the sink or state store from each * input source. @@ -160,36 +159,7 @@ class StreamExecution( /** * A list of unique sources in the query plan. This will be set when generating logical plan. */ - @volatile private var uniqueSources: Seq[Source] = Seq.empty - - override lazy val logicalPlan: LogicalPlan = { - assert(microBatchThread eq Thread.currentThread, - "logicalPlan must be initialized in StreamExecutionThread " + - s"but the current thread was ${Thread.currentThread}") - var nextSourceId = 0L - val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() - val _logicalPlan = analyzedPlan.transform { - case streamingRelation@StreamingRelation(dataSource, _, output) => - toExecutionRelationMap.getOrElseUpdate(streamingRelation, { - // Materialize source to avoid creating it in every batch - val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) - nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. 
- StreamingExecutionRelation(source, output)(sparkSession) - }) - } - sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } - uniqueSources = sources.distinct - _logicalPlan - } - - private val triggerExecutor = trigger match { - case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) - case OneTimeTrigger => OneTimeExecutor() - case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") - } + @volatile protected var uniqueSources: Seq[Source] = Seq.empty /** Defines the internal state of execution */ private val state = new AtomicReference[State](INITIALIZING) @@ -215,13 +185,13 @@ class StreamExecution( * [[org.apache.spark.util.UninterruptibleThread]] to workaround KAFKA-1894: interrupting a * running `KafkaConsumer` may cause endless loop. */ - val microBatchThread = - new StreamExecutionThread(s"stream execution thread for $prettyIdString") { + val queryExecutionThread: QueryExecutionThread = + new QueryExecutionThread(s"stream execution thread for $prettyIdString") { override def run(): Unit = { // To fix call site like "run at :0", we bridge the call site from the caller // thread to this micro batch thread sparkSession.sparkContext.setCallSite(callSite) - runBatches() + runStream() } } @@ -238,7 +208,7 @@ class StreamExecution( * fully processed, and its output was committed to the sink, hence no need to process it again. * This is used (for instance) during restart, to help identify which batch to run next. */ - val batchCommitLog = new BatchCommitLog(sparkSession, checkpointFile("commits")) + val commitLog = new CommitLog(sparkSession, checkpointFile("commits")) /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING @@ -250,7 +220,7 @@ class StreamExecution( override def exception: Option[StreamingQueryException] = Option(streamDeathCause) /** Returns the path of a file with `name` in the checkpoint directory. */ - private def checkpointFile(name: String): String = + protected def checkpointFile(name: String): String = new Path(new Path(resolvedCheckpointRoot), name).toUri.toString /** @@ -259,20 +229,25 @@ class StreamExecution( */ def start(): Unit = { logInfo(s"Starting $prettyIdString. Use $resolvedCheckpointRoot to store the query checkpoint.") - microBatchThread.setDaemon(true) - microBatchThread.start() + queryExecutionThread.setDaemon(true) + queryExecutionThread.start() startLatch.await() // Wait until thread started and QueryStart event has been posted } /** - * Repeatedly attempts to run batches as data arrives. + * Run the activated stream until stopped. + */ + protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit + + /** + * Activate the stream and then wrap a callout to runActivatedStream, handling start and stop. * * Note that this method ensures that [[QueryStartedEvent]] and [[QueryTerminatedEvent]] are * posted such that listeners are guaranteed to get a start event before a termination. * Furthermore, this method also ensures that [[QueryStartedEvent]] event is posted before the * `start()` method returns. */ - private def runBatches(): Unit = { + private def runStream(): Unit = { try { sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString, interruptOnCancel = true) @@ -295,56 +270,18 @@ class StreamExecution( logicalPlan // Isolated spark session to run the batches with. 
- val sparkSessionToRunBatches = sparkSession.cloneSession() + val sparkSessionForStream = sparkSession.cloneSession() // Adaptive execution can change num shuffle partitions, disallow - sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") + sparkSessionForStream.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") // Disable cost-based join optimization as we do not want stateful operations to be rearranged - sparkSessionToRunBatches.conf.set(SQLConf.CBO_ENABLED.key, "false") + sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false") offsetSeqMetadata = OffsetSeqMetadata( - batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionToRunBatches.conf) + batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf) if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` initializationLatch.countDown() - - triggerExecutor.execute(() => { - startTrigger() - - if (isActive) { - reportTimeTaken("triggerExecution") { - if (currentBatchId < 0) { - // We'll do this initialization only once - populateStartOffsets(sparkSessionToRunBatches) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() - } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") - runBatch(sparkSessionToRunBatches) - } - } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. - batchCommitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } - } - updateStatusMessage("Waiting for next trigger") - isActive - }) + runActivatedStream(sparkSessionForStream) updateStatusMessage("Stopped") } else { // `stop()` is already called. Let `finally` finish the cleanup. @@ -373,7 +310,7 @@ class StreamExecution( if (!NonFatal(e)) { throw e } - } finally microBatchThread.runUninterruptibly { + } finally queryExecutionThread.runUninterruptibly { // The whole `finally` block must run inside `runUninterruptibly` to avoid being interrupted // when a query is stopped by the user. We need to make sure the following codes finish // otherwise it may throw `InterruptedException` to `UncaughtExceptionHandler` (SPARK-21248). @@ -410,12 +347,12 @@ class StreamExecution( } } } finally { - awaitBatchLock.lock() + awaitProgressLock.lock() try { // Wake up any threads that are waiting for the stream to progress. - awaitBatchLockCondition.signalAll() + awaitProgressLockCondition.signalAll() } finally { - awaitBatchLock.unlock() + awaitProgressLock.unlock() } terminationLatch.countDown() } @@ -448,296 +385,6 @@ class StreamExecution( } } - /** - * Populate the start offsets to start the execution at the current offsets stored in the sink - * (i.e. avoid reprocessing data that we have already processed). 
This function must be called - * before any processing occurs and will populate the following fields: - * - currentBatchId - * - committedOffsets - * - availableOffsets - * The basic structure of this method is as follows: - * - * Identify (from the offset log) the offsets used to run the last batch - * IF last batch exists THEN - * Set the next batch to be executed as the last recovered batch - * Check the commit log to see which batch was committed last - * IF the last batch was committed THEN - * Call getBatch using the last batch start and end offsets - * // ^^^^ above line is needed since some sources assume last batch always re-executes - * Setup for a new batch i.e., start = last batch end, and identify new end - * DONE - * ELSE - * Identify a brand new batch - * DONE - */ - private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { - offsetLog.getLatest() match { - case Some((latestBatchId, nextOffsets)) => - /* First assume that we are re-executing the latest known batch - * in the offset log */ - currentBatchId = latestBatchId - availableOffsets = nextOffsets.toStreamProgress(sources) - /* Initialize committed offsets to a committed batch, which at this - * is the second latest batch id in the offset log. */ - if (latestBatchId != 0) { - val secondLatestBatchId = offsetLog.get(latestBatchId - 1).getOrElse { - throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist") - } - committedOffsets = secondLatestBatchId.toStreamProgress(sources) - } - - // update offset metadata - nextOffsets.metadata.foreach { metadata => - OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) - offsetSeqMetadata = OffsetSeqMetadata( - metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) - } - - /* identify the current batch id: if commit log indicates we successfully processed the - * latest batch id in the offset log, then we can safely move to the next batch - * i.e., committedBatchId + 1 */ - batchCommitLog.getLatest() match { - case Some((latestCommittedBatchId, _)) => - if (latestBatchId == latestCommittedBatchId) { - /* The last batch was successfully committed, so we can safely process a - * new next batch but first: - * Make a call to getBatch using the offsets from previous batch. - * because certain sources (e.g., KafkaSource) assume on restart the last - * batch will be executed before getOffset is called again. */ - availableOffsets.foreach { ao: (Source, Offset) => - val (source, end) = ao - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } - } - currentBatchId = latestCommittedBatchId + 1 - committedOffsets ++= availableOffsets - // Construct a new batch be recomputing availableOffsets - constructNextBatch() - } else if (latestCommittedBatchId < latestBatchId - 1) { - logWarning(s"Batch completion log latest batch id is " + - s"${latestCommittedBatchId}, which is not trailing " + - s"batchid $latestBatchId by one") - } - case None => logInfo("no commit log present") - } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + - s"$committedOffsets and available offsets $availableOffsets") - case None => // We are starting this stream for the first time. - logInfo(s"Starting new streaming query.") - currentBatchId = 0 - constructNextBatch() - } - } - - /** - * Returns true if there is any new data available to be processed. 
- */ - private def dataAvailable: Boolean = { - availableOffsets.exists { - case (source, available) => - committedOffsets - .get(source) - .map(committed => committed != available) - .getOrElse(true) - } - } - - /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. - */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitBatchLock.lock() - try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - }.toMap - availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) - - if (dataAvailable) { - true - } else { - noNewData = true - false - } - } finally { - awaitBatchLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } - } - - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. - if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } - } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds - - updateStatusMessage("Writing offsets to log") - reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, - availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), - s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") - logInfo(s"Committed offsets for batch $currentBatchId. " + - s"Metadata ${offsetSeqMetadata.toString}") - - // NOTE: The following code is correct because runBatches() processes exactly one - // batch at a time. If we add pipeline parallelism (multiple batches in flight at - // the same time), this cleanup logic will need to change. - - // Now that we've updated the scheduler's persistent checkpoint, it is safe for the - // sources to discard data from the previous batch. 
- if (currentBatchId != 0) { - val prevBatchOff = offsetLog.get(currentBatchId - 1) - if (prevBatchOff.isDefined) { - prevBatchOff.get.toStreamProgress(sources).foreach { - case (src, off) => src.commit(off) - } - } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") - } - } - - // It is now safe to discard the metadata beyond the minimum number to retain. - // Note that purge is exclusive, i.e. it purges everything before the target ID. - if (minBatchesToRetain < currentBatchId) { - offsetLog.purge(currentBatchId - minBatchesToRetain) - batchCommitLog.purge(currentBatchId - minBatchesToRetain) - } - } - } else { - awaitBatchLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitBatchLockCondition.signalAll() - } finally { - awaitBatchLock.unlock() - } - } - } - - /** - * Processes any data available between `availableOffsets` and `committedOffsets`. - * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. - */ - private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { - // Request unprocessed data from all sources. - newData = reportTimeTaken("getBatch") { - availableOffsets.flatMap { - case (source, available) - if committedOffsets.get(source).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(source) - val batch = source.getBatch(current, available) - assert(batch.isStreaming, - s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + - s"${batch.queryExecution.logical}") - logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) - case _ => None - } - } - - // A list of attributes that will need to be updated. - val replacements = new ArrayBuffer[(Attribute, Attribute)] - // Replace sources in the logical plan with data that has arrived since the last batch. - val withNewSources = logicalPlan transform { - case StreamingExecutionRelation(source, output) => - newData.get(source).map { data => - val newPlan = data.logicalPlan - assert(output.size == newPlan.output.size, - s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newPlan.output, ",")}") - replacements ++= output.zip(newPlan.output) - newPlan - }.getOrElse { - LocalRelation(output, isStreaming = true) - } - } - - // Rewire the plan to use the new attributes that were returned by the source. - val replacementMap = AttributeMap(replacements) - val triggerLogicalPlan = withNewSources transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => - replacementMap(a).withMetadata(a.metadata) - case ct: CurrentTimestamp => - CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, - ct.dataType) - case cd: CurrentDate => - CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, - cd.dataType, cd.timeZoneId) - } - - reportTimeTaken("queryPlanning") { - lastExecution = new IncrementalExecution( - sparkSessionToRunBatch, - triggerLogicalPlan, - outputMode, - checkpointFile("state"), - runId, - currentBatchId, - offsetSeqMetadata) - lastExecution.executedPlan // Force the lazy generation of execution plan - } - - val nextBatch = - new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) - - reportTimeTaken("addBatch") { - SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { - sink.addBatch(currentBatchId, nextBatch) - } - } - - awaitBatchLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. 
- awaitBatchLockCondition.signalAll() - } finally { - awaitBatchLock.unlock() - } - } - override protected def postEvent(event: StreamingQueryListener.Event): Unit = { sparkSession.streams.postListenerEvent(event) } @@ -762,10 +409,10 @@ class StreamExecution( // Set the state to TERMINATED so that the batching thread knows that it was interrupted // intentionally state.set(TERMINATED) - if (microBatchThread.isAlive) { + if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) - microBatchThread.interrupt() - microBatchThread.join() + queryExecutionThread.interrupt() + queryExecutionThread.join() // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak sparkSession.sparkContext.cancelJobGroup(runId.toString) } @@ -784,21 +431,21 @@ class StreamExecution( } while (notDone) { - awaitBatchLock.lock() + awaitProgressLock.lock() try { - awaitBatchLockCondition.await(100, TimeUnit.MILLISECONDS) + awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS) if (streamDeathCause != null) { throw streamDeathCause } } finally { - awaitBatchLock.unlock() + awaitProgressLock.unlock() } } logDebug(s"Unblocked at $newOffset for $source") } /** A flag to indicate that a batch has completed with no new data available. */ - @volatile private var noNewData = false + @volatile protected var noNewData = false /** * Assert that the await APIs should not be called in the stream thread. Otherwise, it may cause @@ -806,7 +453,7 @@ class StreamExecution( * the stream thread forever. */ private def assertAwaitThread(): Unit = { - if (microBatchThread eq Thread.currentThread) { + if (queryExecutionThread eq Thread.currentThread) { throw new IllegalStateException( "Cannot wait for a query state from the same thread that is running the query") } @@ -833,11 +480,11 @@ class StreamExecution( throw streamDeathCause } if (!isActive) return - awaitBatchLock.lock() + awaitProgressLock.lock() try { noNewData = false while (true) { - awaitBatchLockCondition.await(10000, TimeUnit.MILLISECONDS) + awaitProgressLockCondition.await(10000, TimeUnit.MILLISECONDS) if (streamDeathCause != null) { throw streamDeathCause } @@ -846,7 +493,7 @@ class StreamExecution( } } } finally { - awaitBatchLock.unlock() + awaitProgressLock.unlock() } } @@ -900,7 +547,7 @@ class StreamExecution( |Current Available Offsets: $availableOffsets | |Current State: $state - |Thread State: ${microBatchThread.getState}""".stripMargin + |Thread State: ${queryExecutionThread.getState}""".stripMargin if (includeLogicalPlan) { debugString + s"\n\nLogical Plan:\n$logicalPlan" } else { @@ -908,7 +555,7 @@ class StreamExecution( } } - private def getBatchDescriptionString: String = { + protected def getBatchDescriptionString: String = { val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString Option(name).map(_ + "
    ").getOrElse("") + s"id = $id
    runId = $runId
    batch = $batchDescription" @@ -920,7 +567,7 @@ object StreamExecution { } /** - * A special thread to run the stream query. Some codes require to run in the StreamExecutionThread - * and will use `classOf[StreamExecutionThread]` to check. + * A special thread to run the stream query. Some codes require to run in the QueryExecutionThread + * and will use `classOf[QueryxecutionThread]` to check. */ -abstract class StreamExecutionThread(name: String) extends UninterruptibleThread(name) +abstract class QueryExecutionThread(name: String) extends UninterruptibleThread(name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 48b0ea20e5da1..555d6e23f9385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -237,7 +237,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - new StreamingQueryWrapper(new StreamExecution( + new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 47bc452bda0d4..d6bef9ce07379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -260,8 +260,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche CheckLastBatch((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) + 1L - q.batchCommitLog.purge(commit) + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 7a2d9e3728208..c5b57bca18313 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1024,7 +1024,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { expectedCompactInterval: Int): Boolean = { import CompactibleFileStreamLog._ - val fileSource = (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] + val fileSource = getSourcesFromStreamingQuery(execution).head val metadataLog = fileSource invokePrivate _metadataLog() if (isCompactionBatch(batchId, expectedCompactInterval)) { @@ -1100,8 +1100,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { CheckAnswer("keep1", "keep2", "keep3"), AssertOnQuery("check getBatch") { execution: StreamExecution => val _sources = PrivateMethod[Seq[Source]]('sources) - val fileSource = - (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] + val fileSource = getSourcesFromStreamingQuery(execution).head def verify(startId: Option[Int], endId: Int, expected: String*): Unit = { val start = startId.map(new FileStreamSourceOffset(_)) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 8163a1f91e1ca..9e696b2236b68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -276,7 +276,7 @@ class StreamSuite extends StreamTest { // Check the latest batchid in the commit log def CheckCommitLogLatestBatchId(expectedId: Int): AssertOnQuery = - AssertOnQuery(_.batchCommitLog.getLatest().get._1 == expectedId, + AssertOnQuery(_.commitLog.getLatest().get._1 == expectedId, s"commitLog's latest should be $expectedId") // Ensure that there has not been an incremental execution after restart diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 7a1ff89f2f3b2..fb88c5d327043 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -300,12 +300,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be if (currentStream != null) currentStream.committedOffsets.toString else "not started" def threadState = - if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" - def threadStackTrace = if (currentStream != null && currentStream.microBatchThread.isAlive) { - s"Thread stack trace: ${currentStream.microBatchThread.getStackTrace.mkString("\n")}" - } else { - "" - } + if (currentStream != null && currentStream.queryExecutionThread.isAlive) "alive" else "dead" + + def threadStackTrace = + if (currentStream != null && currentStream.queryExecutionThread.isAlive) { + s"Thread stack trace: ${currentStream.queryExecutionThread.getStackTrace.mkString("\n")}" + } else { + "" + } def testState = s""" @@ -460,7 +462,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be verify(currentStream != null, "can not stop a stream that is not running") try failAfter(streamingTimeout) { currentStream.stop() - verify(!currentStream.microBatchThread.isAlive, + verify(!currentStream.queryExecutionThread.isAlive, s"microbatch thread not stopped") verify(!currentStream.isActive, "query.isActive() is false even after stopping") @@ -486,7 +488,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitTermination() } eventually("microbatch thread not stopped after termination with failure") { - assert(!currentStream.microBatchThread.isAlive) + assert(!currentStream.queryExecutionThread.isAlive) } verify(currentStream.exception === Some(thrownException), s"incorrect exception returned by query.exception()") @@ -614,7 +616,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => failTest("Timed out waiting for stream", e) } finally { - if (currentStream != null && currentStream.microBatchThread.isAlive) { + if (currentStream != null && currentStream.queryExecutionThread.isAlive) { currentStream.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index fa0313592b8e7..38aa5171314f2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -300,7 +300,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest StopStream, AssertOnQuery { q => // clear the sink q.sink.asInstanceOf[MemorySink].clear() - q.batchCommitLog.purge(3) + q.commitLog.purge(3) // advance by a minute i.e., 90 seconds total clock.advance(60 * 1000L) true @@ -352,7 +352,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest StopStream, AssertOnQuery { q => // clear the sink q.sink.asInstanceOf[MemorySink].clear() - q.batchCommitLog.purge(3) + q.commitLog.purge(3) // advance by 60 days i.e., 90 days total clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60) true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index f813b77e3ce6b..ad4d3abd01aa5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -173,12 +173,12 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi StopStream, // clears out StreamTest state AssertOnQuery { q => // both commit log and offset log contain the same (latest) batch id - q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) == + q.commitLog.getLatest().map(_._1).getOrElse(-1L) == q.offsetLog.getLatest().map(_._1).getOrElse(-2L) }, AssertOnQuery { q => // blow away commit log and sink result - q.batchCommitLog.purge(1) + q.commitLog.purge(1) q.sink.asInstanceOf[MemorySink].clear() true }, From 0ea2d8c12e49e30df6bbfa57d74134b25f96a196 Mon Sep 17 00:00:00 2001 From: zouchenjun Date: Thu, 14 Dec 2017 15:37:26 -0800 Subject: [PATCH 434/712] [SPARK-22496][SQL] thrift server adds operation logs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? since hive 2.0+ upgrades log4j to log4j2,a lot of [changes](https://issues.apache.org/jira/browse/HIVE-11304) are made working on it. as spark is not to ready to update its inner hive version(1.2.1) , so I manage to make little changes. the function registerCurrentOperationLog is moved from SQLOperstion to its parent class ExecuteStatementOperation so spark can use it. ## How was this patch tested? manual test Closes #19721 from ChenjunZou/operation-log. Author: zouchenjun Closes #19961 from ChenjunZou/spark-22496. 
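A note on why the registration has to move: Hive tracks the current `OperationLog` per thread, and the statement may execute on a different thread (inside the `PrivilegedExceptionAction` shown further down) than the one that created the operation, so the log must be registered from inside that execution path. The sketch below is a minimal, self-contained illustration of that thread-local registration pattern; `OperationLogHolder` and `AsyncStatement` are hypothetical stand-ins for this explanation only. The real change simply exposes `registerCurrentOperationLog()` (which calls `OperationLog.setCurrentOperationLog(operationLog)`) on the parent class and invokes it at the start of the background `run()` in `SparkExecuteStatementOperation`, as the diff below shows.

```scala
// Minimal sketch of the thread-local registration pattern described above.
// OperationLogHolder and AsyncStatement are invented names for illustration,
// not the actual Hive/Spark classes.
object OperationLogHolder {
  private val current = new ThreadLocal[String]()      // stands in for OperationLog
  def register(log: String): Unit = current.set(log)   // ~ registerCurrentOperationLog()
  def get: Option[String] = Option(current.get())
}

class AsyncStatement(operationLog: String) {
  def run(): Unit = {
    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        // The ThreadLocal is per thread: registering on the submitting thread
        // would not be visible here, so registration happens in this run() body.
        OperationLogHolder.register(operationLog)
        execute()
      }
    })
    worker.start()
    worker.join()
  }

  private def execute(): Unit =
    // Anything logged during execution can now find the per-operation log.
    println(s"executing, operation log = ${OperationLogHolder.get.getOrElse("<missing>")}")
}

object OperationLogDemo extends App {
  new AsyncStatement("operation-log-for-query-42").run()
}
```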
--- .../cli/operation/ExecuteStatementOperation.java | 13 +++++++++++++ .../hive/service/cli/operation/SQLOperation.java | 12 ------------ .../SparkExecuteStatementOperation.scala | 1 + 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java index 3f2de108f069a..6740d3bb59dc3 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessor; import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory; +import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hive.service.cli.HiveSQLException; import org.apache.hive.service.cli.OperationType; import org.apache.hive.service.cli.session.HiveSession; @@ -67,4 +68,16 @@ protected void setConfOverlay(Map confOverlay) { this.confOverlay = confOverlay; } } + + protected void registerCurrentOperationLog() { + if (isOperationLogEnabled) { + if (operationLog == null) { + LOG.warn("Failed to get current OperationLog object of Operation: " + + getHandle().getHandleIdentifier()); + isOperationLogEnabled = false; + return; + } + OperationLog.setCurrentOperationLog(operationLog); + } + } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index 5014cedd870b6..fd9108eb53ca9 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -274,18 +274,6 @@ private Hive getSessionHive() throws HiveSQLException { } } - private void registerCurrentOperationLog() { - if (isOperationLogEnabled) { - if (operationLog == null) { - LOG.warn("Failed to get current OperationLog object of Operation: " + - getHandle().getHandleIdentifier()); - isOperationLogEnabled = false; - return; - } - OperationLog.setCurrentOperationLog(operationLog); - } - } - private void cleanup(OperationState state) throws HiveSQLException { setState(state); if (shouldRunAsync()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index f5191fa9132bd..664bc20601eaa 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -170,6 +170,7 @@ private[hive] class SparkExecuteStatementOperation( override def run(): Unit = { val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { + registerCurrentOperationLog() try { execute() } catch { From 3fea5c4f19cb5369ff8bbeca80768a8aadb463f5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 14 Dec 2017 22:56:57 -0800 Subject: [PATCH 435/712] [SPARK-22787][TEST][SQL] Add a TPC-H query suite ## What changes were proposed in this pull request? 
Add a test suite to ensure all the TPC-H queries can be successfully analyzed, optimized and compiled without hitting the max iteration threshold. ## How was this patch tested? N/A Author: gatorsmile Closes #19982 from gatorsmile/testTPCH. --- .../datasources/text/TextFileFormat.scala | 7 +- sql/core/src/test/resources/tpch/q1.sql | 23 ++++ sql/core/src/test/resources/tpch/q10.sql | 34 ++++++ sql/core/src/test/resources/tpch/q11.sql | 29 +++++ sql/core/src/test/resources/tpch/q12.sql | 30 +++++ sql/core/src/test/resources/tpch/q13.sql | 22 ++++ sql/core/src/test/resources/tpch/q14.sql | 15 +++ sql/core/src/test/resources/tpch/q15.sql | 35 ++++++ sql/core/src/test/resources/tpch/q16.sql | 32 +++++ sql/core/src/test/resources/tpch/q17.sql | 19 +++ sql/core/src/test/resources/tpch/q18.sql | 35 ++++++ sql/core/src/test/resources/tpch/q19.sql | 37 ++++++ sql/core/src/test/resources/tpch/q2.sql | 46 ++++++++ sql/core/src/test/resources/tpch/q20.sql | 39 +++++++ sql/core/src/test/resources/tpch/q21.sql | 42 +++++++ sql/core/src/test/resources/tpch/q22.sql | 39 +++++++ sql/core/src/test/resources/tpch/q3.sql | 25 ++++ sql/core/src/test/resources/tpch/q4.sql | 23 ++++ sql/core/src/test/resources/tpch/q5.sql | 26 +++++ sql/core/src/test/resources/tpch/q6.sql | 11 ++ sql/core/src/test/resources/tpch/q7.sql | 41 +++++++ sql/core/src/test/resources/tpch/q8.sql | 39 +++++++ sql/core/src/test/resources/tpch/q9.sql | 34 ++++++ .../apache/spark/sql/BenchmarkQueryTest.scala | 78 +++++++++++++ .../apache/spark/sql/TPCDSQuerySuite.scala | 58 +-------- .../org/apache/spark/sql/TPCHQuerySuite.scala | 110 ++++++++++++++++++ 26 files changed, 872 insertions(+), 57 deletions(-) create mode 100644 sql/core/src/test/resources/tpch/q1.sql create mode 100644 sql/core/src/test/resources/tpch/q10.sql create mode 100644 sql/core/src/test/resources/tpch/q11.sql create mode 100644 sql/core/src/test/resources/tpch/q12.sql create mode 100644 sql/core/src/test/resources/tpch/q13.sql create mode 100644 sql/core/src/test/resources/tpch/q14.sql create mode 100644 sql/core/src/test/resources/tpch/q15.sql create mode 100644 sql/core/src/test/resources/tpch/q16.sql create mode 100644 sql/core/src/test/resources/tpch/q17.sql create mode 100644 sql/core/src/test/resources/tpch/q18.sql create mode 100644 sql/core/src/test/resources/tpch/q19.sql create mode 100644 sql/core/src/test/resources/tpch/q2.sql create mode 100644 sql/core/src/test/resources/tpch/q20.sql create mode 100644 sql/core/src/test/resources/tpch/q21.sql create mode 100644 sql/core/src/test/resources/tpch/q22.sql create mode 100644 sql/core/src/test/resources/tpch/q3.sql create mode 100644 sql/core/src/test/resources/tpch/q4.sql create mode 100644 sql/core/src/test/resources/tpch/q5.sql create mode 100644 sql/core/src/test/resources/tpch/q6.sql create mode 100644 sql/core/src/test/resources/tpch/q7.sql create mode 100644 sql/core/src/test/resources/tpch/q8.sql create mode 100644 sql/core/src/test/resources/tpch/q9.sql create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 8a6ab303fc0f0..c661e9bd3b94c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -116,9 +116,10 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText) } - private def readToUnsafeMem(conf: Broadcast[SerializableConfiguration], - requiredSchema: StructType, wholeTextMode: Boolean): - (PartitionedFile) => Iterator[UnsafeRow] = { + private def readToUnsafeMem( + conf: Broadcast[SerializableConfiguration], + requiredSchema: StructType, + wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = { (file: PartitionedFile) => { val confValue = conf.value.value diff --git a/sql/core/src/test/resources/tpch/q1.sql b/sql/core/src/test/resources/tpch/q1.sql new file mode 100644 index 0000000000000..73eb8d8417155 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q1.sql @@ -0,0 +1,23 @@ +-- using default substitutions + +select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +from + lineitem +where + l_shipdate <= date '1998-12-01' - interval '90' day +group by + l_returnflag, + l_linestatus +order by + l_returnflag, + l_linestatus diff --git a/sql/core/src/test/resources/tpch/q10.sql b/sql/core/src/test/resources/tpch/q10.sql new file mode 100644 index 0000000000000..3b2ae588dee63 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q10.sql @@ -0,0 +1,34 @@ +-- using default substitutions + +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit 20 diff --git a/sql/core/src/test/resources/tpch/q11.sql b/sql/core/src/test/resources/tpch/q11.sql new file mode 100644 index 0000000000000..531e78c21b047 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q11.sql @@ -0,0 +1,29 @@ +-- using default substitutions + +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc diff --git a/sql/core/src/test/resources/tpch/q12.sql b/sql/core/src/test/resources/tpch/q12.sql new file mode 100644 index 0000000000000..d3e70eb481a9d --- /dev/null +++ b/sql/core/src/test/resources/tpch/q12.sql @@ -0,0 +1,30 @@ +-- using default substitutions + +select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as 
low_line_count +from + orders, + lineitem +where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year +group by + l_shipmode +order by + l_shipmode diff --git a/sql/core/src/test/resources/tpch/q13.sql b/sql/core/src/test/resources/tpch/q13.sql new file mode 100644 index 0000000000000..3375002c5fd44 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q13.sql @@ -0,0 +1,22 @@ +-- using default substitutions + +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) as c_count + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders +group by + c_count +order by + custdist desc, + c_count desc diff --git a/sql/core/src/test/resources/tpch/q14.sql b/sql/core/src/test/resources/tpch/q14.sql new file mode 100644 index 0000000000000..753ea568914cb --- /dev/null +++ b/sql/core/src/test/resources/tpch/q14.sql @@ -0,0 +1,15 @@ +-- using default substitutions + +select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month diff --git a/sql/core/src/test/resources/tpch/q15.sql b/sql/core/src/test/resources/tpch/q15.sql new file mode 100644 index 0000000000000..64d0b48ec09ae --- /dev/null +++ b/sql/core/src/test/resources/tpch/q15.sql @@ -0,0 +1,35 @@ +-- using default substitutions + +with revenue0 as + (select + l_suppkey as supplier_no, + sum(l_extendedprice * (1 - l_discount)) as total_revenue + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey) + + +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey + diff --git a/sql/core/src/test/resources/tpch/q16.sql b/sql/core/src/test/resources/tpch/q16.sql new file mode 100644 index 0000000000000..a6ac68898ec8d --- /dev/null +++ b/sql/core/src/test/resources/tpch/q16.sql @@ -0,0 +1,32 @@ +-- using default substitutions + +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size diff --git a/sql/core/src/test/resources/tpch/q17.sql b/sql/core/src/test/resources/tpch/q17.sql new file mode 100644 index 0000000000000..74fb1f653a945 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q17.sql @@ -0,0 +1,19 @@ +-- using default substitutions + +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ) diff --git 
a/sql/core/src/test/resources/tpch/q18.sql b/sql/core/src/test/resources/tpch/q18.sql new file mode 100644 index 0000000000000..210fba19ec7f7 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q18.sql @@ -0,0 +1,35 @@ +-- using default substitutions + +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit 100 \ No newline at end of file diff --git a/sql/core/src/test/resources/tpch/q19.sql b/sql/core/src/test/resources/tpch/q19.sql new file mode 100644 index 0000000000000..c07327da3ac29 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q19.sql @@ -0,0 +1,37 @@ +-- using default substitutions + +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) diff --git a/sql/core/src/test/resources/tpch/q2.sql b/sql/core/src/test/resources/tpch/q2.sql new file mode 100644 index 0000000000000..d0e3b7e13e301 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q2.sql @@ -0,0 +1,46 @@ +-- using default substitutions + +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit 100 diff --git a/sql/core/src/test/resources/tpch/q20.sql b/sql/core/src/test/resources/tpch/q20.sql new file mode 100644 index 0000000000000..e161d340b9b6b --- /dev/null +++ b/sql/core/src/test/resources/tpch/q20.sql @@ -0,0 +1,39 @@ +-- using default substitutions + +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and 
l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name diff --git a/sql/core/src/test/resources/tpch/q21.sql b/sql/core/src/test/resources/tpch/q21.sql new file mode 100644 index 0000000000000..fdcdfbcf7976e --- /dev/null +++ b/sql/core/src/test/resources/tpch/q21.sql @@ -0,0 +1,42 @@ +-- using default substitutions + +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100 \ No newline at end of file diff --git a/sql/core/src/test/resources/tpch/q22.sql b/sql/core/src/test/resources/tpch/q22.sql new file mode 100644 index 0000000000000..1d7706e9a0bf4 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q22.sql @@ -0,0 +1,39 @@ +-- using default substitutions + +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode diff --git a/sql/core/src/test/resources/tpch/q3.sql b/sql/core/src/test/resources/tpch/q3.sql new file mode 100644 index 0000000000000..948d6bcf12f68 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q3.sql @@ -0,0 +1,25 @@ +-- using default substitutions + +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit 10 diff --git a/sql/core/src/test/resources/tpch/q4.sql b/sql/core/src/test/resources/tpch/q4.sql new file mode 100644 index 0000000000000..67330e36a066d --- /dev/null +++ b/sql/core/src/test/resources/tpch/q4.sql @@ -0,0 +1,23 @@ +-- using default substitutions + +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority diff --git a/sql/core/src/test/resources/tpch/q5.sql b/sql/core/src/test/resources/tpch/q5.sql new file mode 100644 index 0000000000000..b973e9f0a0956 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q5.sql @@ -0,0 +1,26 @@ +-- using default substitutions + +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + customer, + 
orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year +group by + n_name +order by + revenue desc diff --git a/sql/core/src/test/resources/tpch/q6.sql b/sql/core/src/test/resources/tpch/q6.sql new file mode 100644 index 0000000000000..22294579ee043 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q6.sql @@ -0,0 +1,11 @@ +-- using default substitutions + +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between .06 - 0.01 and .06 + 0.01 + and l_quantity < 24 diff --git a/sql/core/src/test/resources/tpch/q7.sql b/sql/core/src/test/resources/tpch/q7.sql new file mode 100644 index 0000000000000..21105c0519e1b --- /dev/null +++ b/sql/core/src/test/resources/tpch/q7.sql @@ -0,0 +1,41 @@ +-- using default substitutions + +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + year(l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year diff --git a/sql/core/src/test/resources/tpch/q8.sql b/sql/core/src/test/resources/tpch/q8.sql new file mode 100644 index 0000000000000..81d81871c4920 --- /dev/null +++ b/sql/core/src/test/resources/tpch/q8.sql @@ -0,0 +1,39 @@ +-- using default substitutions + +select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share +from + ( + select + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year diff --git a/sql/core/src/test/resources/tpch/q9.sql b/sql/core/src/test/resources/tpch/q9.sql new file mode 100644 index 0000000000000..a4e8e8382be6a --- /dev/null +++ b/sql/core/src/test/resources/tpch/q9.sql @@ -0,0 +1,34 @@ +-- using default substitutions + +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + year(o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and 
p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala new file mode 100644 index 0000000000000..7037749f14478 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + + // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting + // the max iteration of analyzer/optimizer batches. 
+ assert(Utils.isTesting, "spark.testing is not set to true") + + /** + * Drop all the tables + */ + protected override def afterAll(): Unit = { + try { + // For debugging dump some statistics about how much time was spent in various optimizer rules + logWarning(RuleExecutor.dumpTimeSpent()) + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + override def beforeAll() { + super.beforeAll() + RuleExecutor.resetTime() + } + + protected def checkGeneratedCode(plan: SparkPlan): Unit = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() + plan foreach { + case s: WholeStageCodegenExec => + codegenSubtrees += s + case s => s + } + codegenSubtrees.toSeq.foreach { subtree => + val code = subtree.doCodeGen()._2 + try { + // Just check the generated code can be properly compiled + CodeGenerator.compile(code) + } catch { + case e: Exception => + val msg = + s""" + |failed to compile: + |Subtree: + |$subtree + |Generated code: + |${CodeFormatter.format(code)} + """.stripMargin + throw new Exception(msg, e) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index dd427a5c52edf..1a584187a06e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -17,41 +17,18 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} -import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.resourceToString -import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils /** - * This test suite ensures all the TPC-DS queries can be successfully analyzed and optimized - * without hitting the max iteration threshold. + * This test suite ensures all the TPC-DS queries can be successfully analyzed, optimized + * and compiled without hitting the max iteration threshold. */ -class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { - - // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting - // the max iteration of analyzer/optimizer batches. 
- assert(Utils.isTesting, "spark.testing is not set to true") - - /** - * Drop all the tables - */ - protected override def afterAll(): Unit = { - try { - // For debugging dump some statistics about how much time was spent in various optimizer rules - logWarning(RuleExecutor.dumpTimeSpent()) - spark.sessionState.catalog.reset() - } finally { - super.afterAll() - } - } +class TPCDSQuerySuite extends BenchmarkQueryTest { override def beforeAll() { super.beforeAll() + sql( """ |CREATE TABLE `catalog_page` ( @@ -350,33 +327,6 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") - private def checkGeneratedCode(plan: SparkPlan): Unit = { - val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() - plan foreach { - case s: WholeStageCodegenExec => - codegenSubtrees += s - case s => s - } - codegenSubtrees.toSeq.foreach { subtree => - val code = subtree.doCodeGen()._2 - try { - // Just check the generated code can be properly compiled - CodeGenerator.compile(code) - } catch { - case e: Exception => - val msg = - s""" - |failed to compile: - |Subtree: - |$subtree - |Generated code: - |${CodeFormatter.format(code)} - """.stripMargin - throw new Exception(msg, e) - } - } - } - tpcdsQueries.foreach { name => val queryString = resourceToString(s"tpcds/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala new file mode 100644 index 0000000000000..69ac92eaed7dc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.util.Utils + +/** + * This test suite ensures all the TPC-H queries can be successfully analyzed, optimized + * and compiled without hitting the max iteration threshold. 
+ */ +class TPCHQuerySuite extends BenchmarkQueryTest { + + override def beforeAll() { + super.beforeAll() + + sql( + """ + |CREATE TABLE `orders` ( + |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, + |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, + |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `nation` ( + |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `region` ( + |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, + |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, + |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, + |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, + |`c_nationkey` STRING, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), + |`c_mktsegment` STRING, `c_comment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, + |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, + |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), + |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, + |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, + |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING) + |USING parquet + """.stripMargin) + } + + val tpchQueries = Seq( + "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "q16", "q17", "q18", "q19", "q20", "q21", "q22") + + tpchQueries.foreach { name => + val queryString = resourceToString(s"tpch/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } +} From 3775dd31ee86c32b6161ca99d8fd5cfd7c1a758e Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 14 Dec 2017 23:11:13 -0800 Subject: [PATCH 436/712] [SPARK-22753][SQL] Get rid of dataSource.writeAndRead ## What changes were proposed in this pull request? As the discussion in https://github.com/apache/spark/pull/16481 and https://github.com/apache/spark/pull/18975#discussion_r155454606 Currently the BaseRelation returned by `dataSource.writeAndRead` only used in `CreateDataSourceTableAsSelect`, planForWriting and writeAndRead has some common code paths. In this patch I removed the writeAndRead function and added the getRelation function which only use in `CreateDataSourceTableAsSelectCommand` while saving data to non-existing table. ## How was this patch tested? Existing UT Author: Yuanjian Li Closes #19941 from xuanyuanking/SPARK-22753. 
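One detail worth calling out: `executePlan(...)` only builds the query execution, so in the old code the write itself still happened inside the separate `writeAndRead` call; the fix keeps a single code path by forcing the planned write with `.toRdd`, as the diff below shows. The following is a small, hypothetical sketch of that plan-then-force pattern (the class and method names are invented for illustration and are not Spark APIs):

```scala
// Hypothetical sketch of the plan-then-force pattern. SimpleExecution and
// planWrite are invented names; in Spark the forcing step is the .toRdd call
// appended to executePlan(...) in the patch below.
final class SimpleExecution(work: () => Unit) {
  // Nothing runs until this lazy val is forced.
  lazy val toRdd: Unit = work()
}

object PlanThenForceDemo extends App {
  var writes = 0
  def planWrite(): SimpleExecution = new SimpleExecution(() => writes += 1)

  val exec = planWrite()
  println(s"after planning: writes = $writes")   // 0 - planning alone writes nothing
  exec.toRdd                                     // forcing runs the write exactly once
  exec.toRdd                                     // lazy val: not re-run
  println(s"after forcing:  writes = $writes")   // 1
}
```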
--- .../sql/execution/command/InsertIntoDataSourceDirCommand.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala index 9e3519073303c..1dc24b3d221cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala @@ -67,8 +67,7 @@ case class InsertIntoDataSourceDirCommand( val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists try { - sparkSession.sessionState.executePlan(dataSource.planForWriting(saveMode, query)) - dataSource.writeAndRead(saveMode, query) + sparkSession.sessionState.executePlan(dataSource.planForWriting(saveMode, query)).toRdd } catch { case ex: AnalysisException => logError(s"Failed to write to directory " + storage.locationUri.toString, ex) From e58f275678fb4f904124a4a2a1762f04c835eb0e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 15 Dec 2017 09:46:15 -0800 Subject: [PATCH 437/712] Revert "[SPARK-22496][SQL] thrift server adds operation logs" This reverts commit 0ea2d8c12e49e30df6bbfa57d74134b25f96a196. --- .../cli/operation/ExecuteStatementOperation.java | 13 ------------- .../hive/service/cli/operation/SQLOperation.java | 12 ++++++++++++ .../SparkExecuteStatementOperation.scala | 1 - 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java index 6740d3bb59dc3..3f2de108f069a 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java @@ -23,7 +23,6 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessor; import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory; -import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hive.service.cli.HiveSQLException; import org.apache.hive.service.cli.OperationType; import org.apache.hive.service.cli.session.HiveSession; @@ -68,16 +67,4 @@ protected void setConfOverlay(Map confOverlay) { this.confOverlay = confOverlay; } } - - protected void registerCurrentOperationLog() { - if (isOperationLogEnabled) { - if (operationLog == null) { - LOG.warn("Failed to get current OperationLog object of Operation: " + - getHandle().getHandleIdentifier()); - isOperationLogEnabled = false; - return; - } - OperationLog.setCurrentOperationLog(operationLog); - } - } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index fd9108eb53ca9..5014cedd870b6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -274,6 +274,18 @@ private Hive getSessionHive() throws HiveSQLException { } } + private void registerCurrentOperationLog() { + if (isOperationLogEnabled) { + if (operationLog == null) { + LOG.warn("Failed to get current OperationLog object of 
Operation: " + + getHandle().getHandleIdentifier()); + isOperationLogEnabled = false; + return; + } + OperationLog.setCurrentOperationLog(operationLog); + } + } + private void cleanup(OperationState state) throws HiveSQLException { setState(state); if (shouldRunAsync()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 664bc20601eaa..f5191fa9132bd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -170,7 +170,6 @@ private[hive] class SparkExecuteStatementOperation( override def run(): Unit = { val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - registerCurrentOperationLog() try { execute() } catch { From 9fafa8209c51adc2a22b89aedf9af7b5e29e0059 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 15 Dec 2017 09:56:22 -0800 Subject: [PATCH 438/712] [SPARK-22800][TEST][SQL] Add a SSB query suite ## What changes were proposed in this pull request? Add a test suite to ensure all the [SSB (Star Schema Benchmark)](https://www.cs.umb.edu/~poneil/StarSchemaB.PDF) queries can be successfully analyzed, optimized and compiled without hitting the max iteration threshold. ## How was this patch tested? Added `SSBQuerySuite`. Author: Takeshi Yamamuro Closes #19990 from maropu/SPARK-22800. --- sql/core/src/test/resources/ssb/1.1.sql | 6 ++ sql/core/src/test/resources/ssb/1.2.sql | 6 ++ sql/core/src/test/resources/ssb/1.3.sql | 6 ++ sql/core/src/test/resources/ssb/2.1.sql | 9 ++ sql/core/src/test/resources/ssb/2.2.sql | 9 ++ sql/core/src/test/resources/ssb/2.3.sql | 9 ++ sql/core/src/test/resources/ssb/3.1.sql | 10 +++ sql/core/src/test/resources/ssb/3.2.sql | 10 +++ sql/core/src/test/resources/ssb/3.3.sql | 12 +++ sql/core/src/test/resources/ssb/3.4.sql | 12 +++ sql/core/src/test/resources/ssb/4.1.sql | 11 +++ sql/core/src/test/resources/ssb/4.2.sql | 12 +++ sql/core/src/test/resources/ssb/4.3.sql | 12 +++ .../org/apache/spark/sql/SSBQuerySuite.scala | 87 +++++++++++++++++++ .../org/apache/spark/sql/TPCHQuerySuite.scala | 2 - 15 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/resources/ssb/1.1.sql create mode 100644 sql/core/src/test/resources/ssb/1.2.sql create mode 100644 sql/core/src/test/resources/ssb/1.3.sql create mode 100644 sql/core/src/test/resources/ssb/2.1.sql create mode 100644 sql/core/src/test/resources/ssb/2.2.sql create mode 100644 sql/core/src/test/resources/ssb/2.3.sql create mode 100644 sql/core/src/test/resources/ssb/3.1.sql create mode 100644 sql/core/src/test/resources/ssb/3.2.sql create mode 100644 sql/core/src/test/resources/ssb/3.3.sql create mode 100644 sql/core/src/test/resources/ssb/3.4.sql create mode 100644 sql/core/src/test/resources/ssb/4.1.sql create mode 100644 sql/core/src/test/resources/ssb/4.2.sql create mode 100644 sql/core/src/test/resources/ssb/4.3.sql create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SSBQuerySuite.scala diff --git a/sql/core/src/test/resources/ssb/1.1.sql b/sql/core/src/test/resources/ssb/1.1.sql new file mode 100644 index 0000000000000..62da3029469e4 --- /dev/null +++ b/sql/core/src/test/resources/ssb/1.1.sql @@ -0,0 +1,6 @@ +select sum(lo_extendedprice*lo_discount) 
as revenue + from lineorder, date + where lo_orderdate = d_datekey + and d_year = 1993 + and lo_discount between 1 and 3 + and lo_quantity < 25 diff --git a/sql/core/src/test/resources/ssb/1.2.sql b/sql/core/src/test/resources/ssb/1.2.sql new file mode 100644 index 0000000000000..1657bfd4b2a32 --- /dev/null +++ b/sql/core/src/test/resources/ssb/1.2.sql @@ -0,0 +1,6 @@ +select sum(lo_extendedprice*lo_discount) as revenue + from lineorder, date + where lo_orderdate = d_datekey + and d_yearmonthnum = 199401 + and lo_discount between 4 and 6 + and lo_quantity between 26 and 35 diff --git a/sql/core/src/test/resources/ssb/1.3.sql b/sql/core/src/test/resources/ssb/1.3.sql new file mode 100644 index 0000000000000..e9bbf51f051dd --- /dev/null +++ b/sql/core/src/test/resources/ssb/1.3.sql @@ -0,0 +1,6 @@ +select sum(lo_extendedprice*lo_discount) as revenue + from lineorder, date + where lo_orderdate = d_datekey + and d_weeknuminyear = 6 and d_year = 1994 + and lo_discount between 5 and 7 + and lo_quantity between 36 and 40 diff --git a/sql/core/src/test/resources/ssb/2.1.sql b/sql/core/src/test/resources/ssb/2.1.sql new file mode 100644 index 0000000000000..00d402785320b --- /dev/null +++ b/sql/core/src/test/resources/ssb/2.1.sql @@ -0,0 +1,9 @@ +select sum(lo_revenue), d_year, p_brand1 + from lineorder, date, part, supplier + where lo_orderdate = d_datekey + and lo_partkey = p_partkey + and lo_suppkey = s_suppkey + and p_category = 'MFGR#12' + and s_region = 'AMERICA' + group by d_year, p_brand1 + order by d_year, p_brand1 diff --git a/sql/core/src/test/resources/ssb/2.2.sql b/sql/core/src/test/resources/ssb/2.2.sql new file mode 100644 index 0000000000000..c53a925f75ddb --- /dev/null +++ b/sql/core/src/test/resources/ssb/2.2.sql @@ -0,0 +1,9 @@ +select sum(lo_revenue), d_year, p_brand1 + from lineorder, date, part, supplier + where lo_orderdate = d_datekey + and lo_partkey = p_partkey + and lo_suppkey = s_suppkey + and p_brand1 between 'MFGR#2221' and 'MFGR#2228' + and s_region = 'ASIA' + group by d_year, p_brand1 + order by d_year, p_brand1 diff --git a/sql/core/src/test/resources/ssb/2.3.sql b/sql/core/src/test/resources/ssb/2.3.sql new file mode 100644 index 0000000000000..e36530c65e428 --- /dev/null +++ b/sql/core/src/test/resources/ssb/2.3.sql @@ -0,0 +1,9 @@ +select sum(lo_revenue), d_year, p_brand1 + from lineorder, date, part, supplier + where lo_orderdate = d_datekey + and lo_partkey = p_partkey + and lo_suppkey = s_suppkey + and p_brand1 = 'MFGR#2221' + and s_region = 'EUROPE' + group by d_year, p_brand1 + order by d_year, p_brand1 diff --git a/sql/core/src/test/resources/ssb/3.1.sql b/sql/core/src/test/resources/ssb/3.1.sql new file mode 100644 index 0000000000000..663ec3fb3d5fa --- /dev/null +++ b/sql/core/src/test/resources/ssb/3.1.sql @@ -0,0 +1,10 @@ +select c_nation, s_nation, d_year, sum(lo_revenue) as revenue + from customer, lineorder, supplier, date + where lo_custkey = c_custkey + and lo_suppkey = s_suppkey + and lo_orderdate = d_datekey + and c_region = 'ASIA' + and s_region = 'ASIA' + and d_year >= 1992 and d_year <= 1997 + group by c_nation, s_nation, d_year + order by d_year asc, revenue desc diff --git a/sql/core/src/test/resources/ssb/3.2.sql b/sql/core/src/test/resources/ssb/3.2.sql new file mode 100644 index 0000000000000..e1eaf10cded62 --- /dev/null +++ b/sql/core/src/test/resources/ssb/3.2.sql @@ -0,0 +1,10 @@ +select c_city, s_city, d_year, sum(lo_revenue) as revenue + from customer, lineorder, supplier, date + where lo_custkey = c_custkey + and lo_suppkey = 
s_suppkey + and lo_orderdate = d_datekey + and c_nation = 'UNITED STATES' + and s_nation = 'UNITED STATES' + and d_year >= 1992 and d_year <= 1997 + group by c_city, s_city, d_year + order by d_year asc, revenue desc diff --git a/sql/core/src/test/resources/ssb/3.3.sql b/sql/core/src/test/resources/ssb/3.3.sql new file mode 100644 index 0000000000000..f2cb44f15a761 --- /dev/null +++ b/sql/core/src/test/resources/ssb/3.3.sql @@ -0,0 +1,12 @@ +select c_city, s_city, d_year, sum(lo_revenue) as revenue + from customer, lineorder, supplier, date + where lo_custkey = c_custkey + and lo_suppkey = s_suppkey + and lo_orderdate = d_datekey + and c_nation = 'UNITED KINGDOM' + and (c_city='UNITED KI1' or c_city='UNITED KI5') + and (s_city='UNITED KI1' or s_city='UNITED KI5') + and s_nation = 'UNITED KINGDOM' + and d_year >= 1992 and d_year <= 1997 + group by c_city, s_city, d_year + order by d_year asc, revenue desc diff --git a/sql/core/src/test/resources/ssb/3.4.sql b/sql/core/src/test/resources/ssb/3.4.sql new file mode 100644 index 0000000000000..936f2464a55f1 --- /dev/null +++ b/sql/core/src/test/resources/ssb/3.4.sql @@ -0,0 +1,12 @@ +select c_city, s_city, d_year, sum(lo_revenue) as revenue + from customer, lineorder, supplier, date + where lo_custkey = c_custkey + and lo_suppkey = s_suppkey + and lo_orderdate = d_datekey + and c_nation = 'UNITED KINGDOM' + and (c_city='UNITED KI1' or c_city='UNITED KI5') + and (s_city='UNITED KI1' or s_city='UNITED KI5') + and s_nation = 'UNITED KINGDOM' + and d_yearmonth = 'Dec1997' + group by c_city, s_city, d_year + order by d_year asc, revenue desc diff --git a/sql/core/src/test/resources/ssb/4.1.sql b/sql/core/src/test/resources/ssb/4.1.sql new file mode 100644 index 0000000000000..af19ecdf58c1f --- /dev/null +++ b/sql/core/src/test/resources/ssb/4.1.sql @@ -0,0 +1,11 @@ +select d_year, c_nation, sum(lo_revenue-lo_supplycost) as profit1 + from date, customer, supplier, part, lineorder + where lo_custkey = c_custkey + and lo_suppkey = s_suppkey + and lo_partkey = p_partkey + and lo_orderdate = d_datekey + and c_region = 'AMERICA' + and s_region = 'AMERICA' + and (p_mfgr = 'MFGR#1' or p_mfgr = 'MFGR#2') + group by d_year, c_nation + order by d_year, c_nation diff --git a/sql/core/src/test/resources/ssb/4.2.sql b/sql/core/src/test/resources/ssb/4.2.sql new file mode 100644 index 0000000000000..2ffe6e2897522 --- /dev/null +++ b/sql/core/src/test/resources/ssb/4.2.sql @@ -0,0 +1,12 @@ +select d_year, s_nation, p_category, sum(lo_revenue-lo_supplycost) as profit1 + from date, customer, supplier, part, lineorder + where lo_custkey = c_custkey + and lo_suppkey = s_suppkey + and lo_partkey = p_partkey + and lo_orderdate = d_datekey + and c_region = 'AMERICA' + and s_region = 'AMERICA' + and (d_year = 1997 or d_year = 1998) + and (p_mfgr = 'MFGR#1' or p_mfgr = 'MFGR#2') + group by d_year, s_nation, p_category + order by d_year, s_nation, p_category diff --git a/sql/core/src/test/resources/ssb/4.3.sql b/sql/core/src/test/resources/ssb/4.3.sql new file mode 100644 index 0000000000000..e24e2cc78c914 --- /dev/null +++ b/sql/core/src/test/resources/ssb/4.3.sql @@ -0,0 +1,12 @@ +select d_year, s_city, p_brand1, sum(lo_revenue-lo_supplycost) as profit1 + from date, customer, supplier, part, lineorder + where lo_custkey = c_custkey + and lo_suppkey = s_suppkey + and lo_partkey = p_partkey + and lo_orderdate = d_datekey + and c_region = 'AMERICA' + and s_nation = 'UNITED STATES' + and (d_year = 1997 or d_year = 1998) + and p_category = 'MFGR#14' + group by d_year, s_city, 
p_brand1 + order by d_year, s_city, p_brand1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SSBQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SSBQuerySuite.scala new file mode 100644 index 0000000000000..9a0c61b3304c5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SSBQuerySuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.util.resourceToString + +/** + * This test suite ensures all the Star Schema Benchmark queries can be successfully analyzed, + * optimized and compiled without hitting the max iteration threshold. + */ +class SSBQuerySuite extends BenchmarkQueryTest { + + override def beforeAll { + super.beforeAll + + sql( + """ + |CREATE TABLE `part` (`p_partkey` INT, `p_name` STRING, `p_mfgr` STRING, + |`p_category` STRING, `p_brand1` STRING, `p_color` STRING, `p_type` STRING, `p_size` INT, + |`p_container` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `supplier` (`s_suppkey` INT, `s_name` STRING, `s_address` STRING, + |`s_city` STRING, `s_nation` STRING, `s_region` STRING, `s_phone` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer` (`c_custkey` INT, `c_name` STRING, `c_address` STRING, + |`c_city` STRING, `c_nation` STRING, `c_region` STRING, `c_phone` STRING, + |`c_mktsegment` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `date` (`d_datekey` INT, `d_date` STRING, `d_dayofweek` STRING, + |`d_month` STRING, `d_year` INT, `d_yearmonthnum` INT, `d_yearmonth` STRING, + |`d_daynuminweek` INT, `d_daynuminmonth` INT, `d_daynuminyear` INT, `d_monthnuminyear` INT, + |`d_weeknuminyear` INT, `d_sellingseason` STRING, `d_lastdayinweekfl` STRING, + |`d_lastdayinmonthfl` STRING, `d_holidayfl` STRING, `d_weekdayfl` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `lineorder` (`lo_orderkey` INT, `lo_linenumber` INT, `lo_custkey` INT, + |`lo_partkey` INT, `lo_suppkey` INT, `lo_orderdate` INT, `lo_orderpriority` STRING, + |`lo_shippriority` STRING, `lo_quantity` INT, `lo_extendedprice` INT, + |`lo_ordertotalprice` INT, `lo_discount` INT, `lo_revenue` INT, `lo_supplycost` INT, + |`lo_tax` INT, `lo_commitdate` INT, `lo_shipmode` STRING) + |USING parquet + """.stripMargin) + } + + val ssbQueries = Seq( + "1.1", "1.2", "1.3", "2.1", "2.2", "2.3", "3.1", "3.2", "3.3", "3.4", "4.1", "4.2", "4.3") + + ssbQueries.foreach { name => + val queryString = resourceToString(s"ssb/$name.sql", + classLoader = Thread.currentThread.getContextClassLoader) + test(name) { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } +} diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala index 69ac92eaed7dc..e3e700529bba7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.resourceToString -import org.apache.spark.util.Utils /** * This test suite ensures all the TPC-H queries can be successfully analyzed, optimized From 46776234a49742e94c64897322500582d7393d35 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 15 Dec 2017 09:58:31 -0800 Subject: [PATCH 439/712] [SPARK-22762][TEST] Basic tests for IfCoercion and CaseWhenCoercion ## What changes were proposed in this pull request? Basic tests for IfCoercion and CaseWhenCoercion ## How was this patch tested? N/A Author: Yuming Wang Closes #19949 from wangyum/SPARK-22762. --- .../typeCoercion/native/caseWhenCoercion.sql | 174 +++ .../inputs/typeCoercion/native/ifCoercion.sql | 174 +++ .../native/caseWhenCoercion.sql.out | 1232 +++++++++++++++++ .../typeCoercion/native/ifCoercion.sql.out | 1232 +++++++++++++++++ 4 files changed, 2812 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/caseWhenCoercion.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/ifCoercion.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/caseWhenCoercion.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/ifCoercion.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/caseWhenCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/caseWhenCoercion.sql new file mode 100644 index 0000000000000..a780529fdca8c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/caseWhenCoercion.sql @@ -0,0 +1,174 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true 
THEN cast(1 as bigint) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as boolean) END FROM t; +SELECT CASE 
WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as smallint) END FROM t; +SELECT CASE 
WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; + +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as tinyint) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as smallint) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as int) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as bigint) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as float) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as double) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as decimal(10, 0)) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as string) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2' as binary) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as boolean) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t; +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00' as date) END FROM t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/ifCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/ifCoercion.sql new file mode 100644 index 0000000000000..42597f169daec --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/ifCoercion.sql @@ -0,0 +1,174 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT IF(true, cast(1 as tinyint), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as smallint), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as smallint), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as smallint), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as int), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as int), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as int), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as bigint), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as bigint), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as bigint), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as float), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as float), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as float), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as float), cast(2 as bigint)) FROM t; +SELECT IF(true, 
cast(1 as float), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as float), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as float), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as float), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as float), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as float), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as double), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as double), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as double), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as string), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as string), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as string), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast('1' as binary), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast('1' as binary), cast(2 as smallint)) FROM t; +SELECT IF(true, cast('1' as binary), cast(2 as int)) FROM t; +SELECT IF(true, cast('1' as binary), cast(2 as bigint)) FROM t; +SELECT IF(true, cast('1' as binary), cast(2 as float)) FROM t; +SELECT IF(true, cast('1' as binary), cast(2 as double)) FROM t; +SELECT IF(true, cast('1' as binary), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast('1' as 
binary), cast(2 as string)) FROM t; +SELECT IF(true, cast('1' as binary), cast('2' as binary)) FROM t; +SELECT IF(true, cast('1' as binary), cast(2 as boolean)) FROM t; +SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast(1 as boolean), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as smallint)) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as int)) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as bigint)) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as float)) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as double)) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as string)) FROM t; +SELECT IF(true, cast(1 as boolean), cast('2' as binary)) FROM t; +SELECT IF(true, cast(1 as boolean), cast(2 as boolean)) FROM t; +SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as smallint)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as int)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as bigint)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as float)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as double)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as string)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2' as binary)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as boolean)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as tinyint)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as smallint)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as int)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as bigint)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as float)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as double)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as decimal(10, 0))) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as string)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2' as binary)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as boolean)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00' as date)) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/caseWhenCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/caseWhenCoercion.sql.out new file mode 100644 index 
0000000000000..a739f8d73181c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/caseWhenCoercion.sql.out @@ -0,0 +1,1232 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 145 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as tinyint) END FROM t +-- !query 1 schema +struct +-- !query 1 output +1 + + +-- !query 2 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as smallint) END FROM t +-- !query 2 schema +struct +-- !query 2 output +1 + + +-- !query 3 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as int) END FROM t +-- !query 3 schema +struct +-- !query 3 output +1 + + +-- !query 4 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as bigint) END FROM t +-- !query 4 schema +struct +-- !query 4 output +1 + + +-- !query 5 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as float) END FROM t +-- !query 5 schema +struct +-- !query 5 output +1.0 + + +-- !query 6 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as double) END FROM t +-- !query 6 schema +struct +-- !query 6 output +1.0 + + +-- !query 7 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 7 schema +struct +-- !query 7 output +1 + + +-- !query 8 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as string) END FROM t +-- !query 8 schema +struct +-- !query 8 output +1 + + +-- !query 9 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2' as binary) END FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 10 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast(2 as boolean) END FROM t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 11 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 12 +SELECT CASE WHEN true THEN cast(1 as tinyint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS TINYINT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 13 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as tinyint) END FROM t +-- !query 13 schema +struct +-- !query 13 output +1 + + +-- !query 14 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as smallint) END FROM t +-- !query 14 schema +struct +-- !query 14 output +1 + + +-- !query 15 +SELECT 
CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as int) END FROM t +-- !query 15 schema +struct +-- !query 15 output +1 + + +-- !query 16 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as bigint) END FROM t +-- !query 16 schema +struct +-- !query 16 output +1 + + +-- !query 17 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as float) END FROM t +-- !query 17 schema +struct +-- !query 17 output +1.0 + + +-- !query 18 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as double) END FROM t +-- !query 18 schema +struct +-- !query 18 output +1.0 + + +-- !query 19 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 19 schema +struct +-- !query 19 output +1 + + +-- !query 20 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as string) END FROM t +-- !query 20 schema +struct +-- !query 20 output +1 + + +-- !query 21 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2' as binary) END FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 22 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast(2 as boolean) END FROM t +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 23 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 24 +SELECT CASE WHEN true THEN cast(1 as smallint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS SMALLINT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 25 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as tinyint) END FROM t +-- !query 25 schema +struct +-- !query 25 output +1 + + +-- !query 26 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as smallint) END FROM t +-- !query 26 schema +struct +-- !query 26 output +1 + + +-- !query 27 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as int) END FROM t +-- !query 27 schema +struct +-- !query 27 output +1 + + +-- !query 28 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as bigint) END FROM t +-- !query 28 schema +struct +-- !query 28 output +1 + + +-- !query 29 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as float) END FROM t +-- !query 29 schema +struct +-- !query 29 output +1.0 + + +-- !query 30 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as double) END FROM t +-- !query 30 schema +struct +-- !query 30 output +1.0 + + +-- !query 31 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as 
decimal(10, 0)) END FROM t +-- !query 31 schema +struct +-- !query 31 output +1 + + +-- !query 32 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as string) END FROM t +-- !query 32 schema +struct +-- !query 32 output +1 + + +-- !query 33 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2' as binary) END FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 34 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast(2 as boolean) END FROM t +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 35 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 36 +SELECT CASE WHEN true THEN cast(1 as int) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS INT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 37 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as tinyint) END FROM t +-- !query 37 schema +struct +-- !query 37 output +1 + + +-- !query 38 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as smallint) END FROM t +-- !query 38 schema +struct +-- !query 38 output +1 + + +-- !query 39 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as int) END FROM t +-- !query 39 schema +struct +-- !query 39 output +1 + + +-- !query 40 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as bigint) END FROM t +-- !query 40 schema +struct +-- !query 40 output +1 + + +-- !query 41 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as float) END FROM t +-- !query 41 schema +struct +-- !query 41 output +1.0 + + +-- !query 42 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as double) END FROM t +-- !query 42 schema +struct +-- !query 42 output +1.0 + + +-- !query 43 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 43 schema +struct +-- !query 43 output +1 + + +-- !query 44 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast(2 as string) END FROM t +-- !query 44 schema +struct +-- !query 44 output +1 + + +-- !query 45 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2' as binary) END FROM t +-- !query 45 schema +struct<> +-- !query 45 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 46 +SELECT CASE WHEN true THEN cast(1 as bigint) 
ELSE cast(2 as boolean) END FROM t +-- !query 46 schema +struct<> +-- !query 46 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 47 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 47 schema +struct<> +-- !query 47 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 48 +SELECT CASE WHEN true THEN cast(1 as bigint) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BIGINT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 49 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as tinyint) END FROM t +-- !query 49 schema +struct +-- !query 49 output +1.0 + + +-- !query 50 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as smallint) END FROM t +-- !query 50 schema +struct +-- !query 50 output +1.0 + + +-- !query 51 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as int) END FROM t +-- !query 51 schema +struct +-- !query 51 output +1.0 + + +-- !query 52 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as bigint) END FROM t +-- !query 52 schema +struct +-- !query 52 output +1.0 + + +-- !query 53 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as float) END FROM t +-- !query 53 schema +struct +-- !query 53 output +1.0 + + +-- !query 54 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as double) END FROM t +-- !query 54 schema +struct +-- !query 54 output +1.0 + + +-- !query 55 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 55 schema +struct +-- !query 55 output +1.0 + + +-- !query 56 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as string) END FROM t +-- !query 56 schema +struct +-- !query 56 output +1.0 + + +-- !query 57 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2' as binary) END FROM t +-- !query 57 schema +struct<> +-- !query 57 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 58 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast(2 as boolean) END FROM t +-- !query 58 schema +struct<> +-- !query 58 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 59 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data 
type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 60 +SELECT CASE WHEN true THEN cast(1 as float) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS FLOAT) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 61 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as tinyint) END FROM t +-- !query 61 schema +struct +-- !query 61 output +1.0 + + +-- !query 62 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as smallint) END FROM t +-- !query 62 schema +struct +-- !query 62 output +1.0 + + +-- !query 63 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as int) END FROM t +-- !query 63 schema +struct +-- !query 63 output +1.0 + + +-- !query 64 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as bigint) END FROM t +-- !query 64 schema +struct +-- !query 64 output +1.0 + + +-- !query 65 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as float) END FROM t +-- !query 65 schema +struct +-- !query 65 output +1.0 + + +-- !query 66 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as double) END FROM t +-- !query 66 schema +struct +-- !query 66 output +1.0 + + +-- !query 67 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 67 schema +struct +-- !query 67 output +1.0 + + +-- !query 68 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as string) END FROM t +-- !query 68 schema +struct +-- !query 68 output +1.0 + + +-- !query 69 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2' as binary) END FROM t +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 70 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast(2 as boolean) END FROM t +-- !query 70 schema +struct<> +-- !query 70 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 71 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 72 +SELECT CASE WHEN true THEN cast(1 as double) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 72 schema +struct<> +-- !query 72 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DOUBLE) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 73 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as tinyint) END FROM t +-- !query 
73 schema +struct +-- !query 73 output +1 + + +-- !query 74 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as smallint) END FROM t +-- !query 74 schema +struct +-- !query 74 output +1 + + +-- !query 75 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as int) END FROM t +-- !query 75 schema +struct +-- !query 75 output +1 + + +-- !query 76 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as bigint) END FROM t +-- !query 76 schema +struct +-- !query 76 output +1 + + +-- !query 77 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as float) END FROM t +-- !query 77 schema +struct +-- !query 77 output +1.0 + + +-- !query 78 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as double) END FROM t +-- !query 78 schema +struct +-- !query 78 output +1.0 + + +-- !query 79 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 79 schema +struct +-- !query 79 output +1 + + +-- !query 80 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as string) END FROM t +-- !query 80 schema +struct +-- !query 80 output +1 + + +-- !query 81 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2' as binary) END FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 82 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast(2 as boolean) END FROM t +-- !query 82 schema +struct<> +-- !query 82 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 83 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 84 +SELECT CASE WHEN true THEN cast(1 as decimal(10, 0)) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 84 schema +struct<> +-- !query 84 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS DECIMAL(10,0)) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 85 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as tinyint) END FROM t +-- !query 85 schema +struct +-- !query 85 output +1 + + +-- !query 86 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as smallint) END FROM t +-- !query 86 schema +struct +-- !query 86 output +1 + + +-- !query 87 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as int) END FROM t +-- !query 87 schema +struct +-- !query 87 output +1 + + +-- !query 88 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as bigint) END FROM t +-- !query 88 schema +struct +-- !query 88 output +1 + + +-- !query 89 +SELECT CASE WHEN true THEN cast(1 as 
string) ELSE cast(2 as float) END FROM t +-- !query 89 schema +struct +-- !query 89 output +1 + + +-- !query 90 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as double) END FROM t +-- !query 90 schema +struct +-- !query 90 output +1 + + +-- !query 91 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 91 schema +struct +-- !query 91 output +1 + + +-- !query 92 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as string) END FROM t +-- !query 92 schema +struct +-- !query 92 output +1 + + +-- !query 93 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2' as binary) END FROM t +-- !query 93 schema +struct<> +-- !query 93 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS STRING) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 94 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast(2 as boolean) END FROM t +-- !query 94 schema +struct<> +-- !query 94 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS STRING) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 95 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 95 schema +struct +-- !query 95 output +1 + + +-- !query 96 +SELECT CASE WHEN true THEN cast(1 as string) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 96 schema +struct +-- !query 96 output +1 + + +-- !query 97 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as tinyint) END FROM t +-- !query 97 schema +struct<> +-- !query 97 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 98 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as smallint) END FROM t +-- !query 98 schema +struct<> +-- !query 98 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 99 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as int) END FROM t +-- !query 99 schema +struct<> +-- !query 99 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 100 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as bigint) END FROM t +-- !query 100 schema +struct<> +-- !query 100 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 101 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as float) END FROM t +-- !query 101 schema +struct<> +-- !query 101 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS 
FLOAT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 102 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as double) END FROM t +-- !query 102 schema +struct<> +-- !query 102 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 103 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 103 schema +struct<> +-- !query 103 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 104 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as string) END FROM t +-- !query 104 schema +struct<> +-- !query 104 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS STRING) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 105 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2' as binary) END FROM t +-- !query 105 schema +struct +-- !query 105 output +1 + + +-- !query 106 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast(2 as boolean) END FROM t +-- !query 106 schema +struct<> +-- !query 106 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 107 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 107 schema +struct<> +-- !query 107 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 108 +SELECT CASE WHEN true THEN cast('1' as binary) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 108 schema +struct<> +-- !query 108 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('1' AS BINARY) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 109 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as tinyint) END FROM t +-- !query 109 schema +struct<> +-- !query 109 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 110 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as smallint) END FROM t +-- !query 110 schema +struct<> +-- !query 110 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same 
type or coercible to a common type; line 1 pos 7 + + +-- !query 111 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as int) END FROM t +-- !query 111 schema +struct<> +-- !query 111 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 112 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as bigint) END FROM t +-- !query 112 schema +struct<> +-- !query 112 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 113 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as float) END FROM t +-- !query 113 schema +struct<> +-- !query 113 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS FLOAT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 114 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as double) END FROM t +-- !query 114 schema +struct<> +-- !query 114 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 115 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 115 schema +struct<> +-- !query 115 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 116 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as string) END FROM t +-- !query 116 schema +struct<> +-- !query 116 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST(2 AS STRING) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 117 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2' as binary) END FROM t +-- !query 117 schema +struct<> +-- !query 117 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 118 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast(2 as boolean) END FROM t +-- !query 118 schema +struct +-- !query 118 output +true + + +-- !query 119 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 119 schema +struct<> +-- !query 119 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 120 +SELECT CASE WHEN true THEN cast(1 as boolean) ELSE cast('2017-12-11 
09:30:00' as date) END FROM t +-- !query 120 schema +struct<> +-- !query 120 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST(1 AS BOOLEAN) ELSE CAST('2017-12-11 09:30:00' AS DATE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 121 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as tinyint) END FROM t +-- !query 121 schema +struct<> +-- !query 121 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 122 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as smallint) END FROM t +-- !query 122 schema +struct<> +-- !query 122 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 123 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as int) END FROM t +-- !query 123 schema +struct<> +-- !query 123 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 124 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as bigint) END FROM t +-- !query 124 schema +struct<> +-- !query 124 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 125 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as float) END FROM t +-- !query 125 schema +struct<> +-- !query 125 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS FLOAT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 126 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as double) END FROM t +-- !query 126 schema +struct<> +-- !query 126 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 127 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 127 schema +struct<> +-- !query 127 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 128 +SELECT CASE WHEN true THEN cast('2017-12-12 
09:30:00.0' as timestamp) ELSE cast(2 as string) END FROM t +-- !query 128 schema +struct +-- !query 128 output +2017-12-12 09:30:00 + + +-- !query 129 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2' as binary) END FROM t +-- !query 129 schema +struct<> +-- !query 129 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 130 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast(2 as boolean) END FROM t +-- !query 130 schema +struct<> +-- !query 130 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 131 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 131 schema +struct +-- !query 131 output +2017-12-12 09:30:00 + + +-- !query 132 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00.0' as timestamp) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 132 schema +struct +-- !query 132 output +2017-12-12 09:30:00 + + +-- !query 133 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as tinyint) END FROM t +-- !query 133 schema +struct<> +-- !query 133 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS TINYINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 134 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as smallint) END FROM t +-- !query 134 schema +struct<> +-- !query 134 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS SMALLINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 135 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as int) END FROM t +-- !query 135 schema +struct<> +-- !query 135 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS INT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 136 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as bigint) END FROM t +-- !query 136 schema +struct<> +-- !query 136 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS BIGINT) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 137 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as float) END FROM t +-- !query 137 schema +struct<> +-- !query 137 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS FLOAT) END' due to data type mismatch: THEN and ELSE 
expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 138 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as double) END FROM t +-- !query 138 schema +struct<> +-- !query 138 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS DOUBLE) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 139 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as decimal(10, 0)) END FROM t +-- !query 139 schema +struct<> +-- !query 139 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS DECIMAL(10,0)) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 140 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as string) END FROM t +-- !query 140 schema +struct +-- !query 140 output +2017-12-12 + + +-- !query 141 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2' as binary) END FROM t +-- !query 141 schema +struct<> +-- !query 141 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST('2' AS BINARY) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 142 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast(2 as boolean) END FROM t +-- !query 142 schema +struct<> +-- !query 142 output +org.apache.spark.sql.AnalysisException +cannot resolve 'CASE WHEN true THEN CAST('2017-12-12 09:30:00' AS DATE) ELSE CAST(2 AS BOOLEAN) END' due to data type mismatch: THEN and ELSE expressions should all be same type or coercible to a common type; line 1 pos 7 + + +-- !query 143 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00.0' as timestamp) END FROM t +-- !query 143 schema +struct +-- !query 143 output +2017-12-12 00:00:00 + + +-- !query 144 +SELECT CASE WHEN true THEN cast('2017-12-12 09:30:00' as date) ELSE cast('2017-12-11 09:30:00' as date) END FROM t +-- !query 144 schema +struct +-- !query 144 output +2017-12-12 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/ifCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/ifCoercion.sql.out new file mode 100644 index 0000000000000..7097027872707 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/ifCoercion.sql.out @@ -0,0 +1,1232 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 145 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT IF(true, cast(1 as tinyint), cast(2 as tinyint)) FROM t +-- !query 1 schema +struct<(IF(true, CAST(1 AS TINYINT), CAST(2 AS TINYINT))):tinyint> +-- !query 1 output +1 + + +-- !query 2 +SELECT IF(true, cast(1 as tinyint), cast(2 as smallint)) FROM t +-- !query 2 schema +struct<(IF(true, CAST(CAST(1 AS TINYINT) AS SMALLINT), CAST(2 AS SMALLINT))):smallint> +-- !query 2 output +1 + + +-- !query 3 +SELECT IF(true, cast(1 as tinyint), cast(2 as int)) FROM t +-- !query 3 schema +struct<(IF(true, CAST(CAST(1 AS TINYINT) AS INT), CAST(2 AS INT))):int> +-- 
!query 3 output +1 + + +-- !query 4 +SELECT IF(true, cast(1 as tinyint), cast(2 as bigint)) FROM t +-- !query 4 schema +struct<(IF(true, CAST(CAST(1 AS TINYINT) AS BIGINT), CAST(2 AS BIGINT))):bigint> +-- !query 4 output +1 + + +-- !query 5 +SELECT IF(true, cast(1 as tinyint), cast(2 as float)) FROM t +-- !query 5 schema +struct<(IF(true, CAST(CAST(1 AS TINYINT) AS FLOAT), CAST(2 AS FLOAT))):float> +-- !query 5 output +1.0 + + +-- !query 6 +SELECT IF(true, cast(1 as tinyint), cast(2 as double)) FROM t +-- !query 6 schema +struct<(IF(true, CAST(CAST(1 AS TINYINT) AS DOUBLE), CAST(2 AS DOUBLE))):double> +-- !query 6 output +1.0 + + +-- !query 7 +SELECT IF(true, cast(1 as tinyint), cast(2 as decimal(10, 0))) FROM t +-- !query 7 schema +struct<(IF(true, CAST(CAST(1 AS TINYINT) AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)> +-- !query 7 output +1 + + +-- !query 8 +SELECT IF(true, cast(1 as tinyint), cast(2 as string)) FROM t +-- !query 8 schema +struct<(IF(true, CAST(CAST(1 AS TINYINT) AS STRING), CAST(2 AS STRING))):string> +-- !query 8 output +1 + + +-- !query 9 +SELECT IF(true, cast(1 as tinyint), cast('2' as binary)) FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST('2' AS BINARY)))' (tinyint and binary).; line 1 pos 7 + + +-- !query 10 +SELECT IF(true, cast(1 as tinyint), cast(2 as boolean)) FROM t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST(2 AS BOOLEAN)))' (tinyint and boolean).; line 1 pos 7 + + +-- !query 11 +SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (tinyint and timestamp).; line 1 pos 7 + + +-- !query 12 +SELECT IF(true, cast(1 as tinyint), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' (tinyint and date).; line 1 pos 7 + + +-- !query 13 +SELECT IF(true, cast(1 as smallint), cast(2 as tinyint)) FROM t +-- !query 13 schema +struct<(IF(true, CAST(1 AS SMALLINT), CAST(CAST(2 AS TINYINT) AS SMALLINT))):smallint> +-- !query 13 output +1 + + +-- !query 14 +SELECT IF(true, cast(1 as smallint), cast(2 as smallint)) FROM t +-- !query 14 schema +struct<(IF(true, CAST(1 AS SMALLINT), CAST(2 AS SMALLINT))):smallint> +-- !query 14 output +1 + + +-- !query 15 +SELECT IF(true, cast(1 as smallint), cast(2 as int)) FROM t +-- !query 15 schema +struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS INT), CAST(2 AS INT))):int> +-- !query 15 output +1 + + +-- !query 16 +SELECT IF(true, cast(1 as smallint), cast(2 as bigint)) FROM t +-- !query 16 schema +struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS BIGINT), CAST(2 AS BIGINT))):bigint> +-- !query 16 output +1 + + +-- !query 17 
+SELECT IF(true, cast(1 as smallint), cast(2 as float)) FROM t +-- !query 17 schema +struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS FLOAT), CAST(2 AS FLOAT))):float> +-- !query 17 output +1.0 + + +-- !query 18 +SELECT IF(true, cast(1 as smallint), cast(2 as double)) FROM t +-- !query 18 schema +struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS DOUBLE), CAST(2 AS DOUBLE))):double> +-- !query 18 output +1.0 + + +-- !query 19 +SELECT IF(true, cast(1 as smallint), cast(2 as decimal(10, 0))) FROM t +-- !query 19 schema +struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)> +-- !query 19 output +1 + + +-- !query 20 +SELECT IF(true, cast(1 as smallint), cast(2 as string)) FROM t +-- !query 20 schema +struct<(IF(true, CAST(CAST(1 AS SMALLINT) AS STRING), CAST(2 AS STRING))):string> +-- !query 20 output +1 + + +-- !query 21 +SELECT IF(true, cast(1 as smallint), cast('2' as binary)) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST('2' AS BINARY)))' (smallint and binary).; line 1 pos 7 + + +-- !query 22 +SELECT IF(true, cast(1 as smallint), cast(2 as boolean)) FROM t +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST(2 AS BOOLEAN)))' (smallint and boolean).; line 1 pos 7 + + +-- !query 23 +SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (smallint and timestamp).; line 1 pos 7 + + +-- !query 24 +SELECT IF(true, cast(1 as smallint), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' (smallint and date).; line 1 pos 7 + + +-- !query 25 +SELECT IF(true, cast(1 as int), cast(2 as tinyint)) FROM t +-- !query 25 schema +struct<(IF(true, CAST(1 AS INT), CAST(CAST(2 AS TINYINT) AS INT))):int> +-- !query 25 output +1 + + +-- !query 26 +SELECT IF(true, cast(1 as int), cast(2 as smallint)) FROM t +-- !query 26 schema +struct<(IF(true, CAST(1 AS INT), CAST(CAST(2 AS SMALLINT) AS INT))):int> +-- !query 26 output +1 + + +-- !query 27 +SELECT IF(true, cast(1 as int), cast(2 as int)) FROM t +-- !query 27 schema +struct<(IF(true, CAST(1 AS INT), CAST(2 AS INT))):int> +-- !query 27 output +1 + + +-- !query 28 +SELECT IF(true, cast(1 as int), cast(2 as bigint)) FROM t +-- !query 28 schema +struct<(IF(true, CAST(CAST(1 AS INT) AS BIGINT), CAST(2 AS BIGINT))):bigint> +-- !query 28 output +1 + + +-- !query 29 +SELECT IF(true, cast(1 as int), cast(2 as float)) FROM t +-- !query 29 schema +struct<(IF(true, CAST(CAST(1 AS INT) AS FLOAT), CAST(2 AS FLOAT))):float> +-- !query 29 output +1.0 + + +-- !query 30 +SELECT IF(true, cast(1 as int), cast(2 as double)) FROM t +-- 
!query 30 schema +struct<(IF(true, CAST(CAST(1 AS INT) AS DOUBLE), CAST(2 AS DOUBLE))):double> +-- !query 30 output +1.0 + + +-- !query 31 +SELECT IF(true, cast(1 as int), cast(2 as decimal(10, 0))) FROM t +-- !query 31 schema +struct<(IF(true, CAST(CAST(1 AS INT) AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)> +-- !query 31 output +1 + + +-- !query 32 +SELECT IF(true, cast(1 as int), cast(2 as string)) FROM t +-- !query 32 schema +struct<(IF(true, CAST(CAST(1 AS INT) AS STRING), CAST(2 AS STRING))):string> +-- !query 32 output +1 + + +-- !query 33 +SELECT IF(true, cast(1 as int), cast('2' as binary)) FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS INT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST('2' AS BINARY)))' (int and binary).; line 1 pos 7 + + +-- !query 34 +SELECT IF(true, cast(1 as int), cast(2 as boolean)) FROM t +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS INT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST(2 AS BOOLEAN)))' (int and boolean).; line 1 pos 7 + + +-- !query 35 +SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (int and timestamp).; line 1 pos 7 + + +-- !query 36 +SELECT IF(true, cast(1 as int), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' (int and date).; line 1 pos 7 + + +-- !query 37 +SELECT IF(true, cast(1 as bigint), cast(2 as tinyint)) FROM t +-- !query 37 schema +struct<(IF(true, CAST(1 AS BIGINT), CAST(CAST(2 AS TINYINT) AS BIGINT))):bigint> +-- !query 37 output +1 + + +-- !query 38 +SELECT IF(true, cast(1 as bigint), cast(2 as smallint)) FROM t +-- !query 38 schema +struct<(IF(true, CAST(1 AS BIGINT), CAST(CAST(2 AS SMALLINT) AS BIGINT))):bigint> +-- !query 38 output +1 + + +-- !query 39 +SELECT IF(true, cast(1 as bigint), cast(2 as int)) FROM t +-- !query 39 schema +struct<(IF(true, CAST(1 AS BIGINT), CAST(CAST(2 AS INT) AS BIGINT))):bigint> +-- !query 39 output +1 + + +-- !query 40 +SELECT IF(true, cast(1 as bigint), cast(2 as bigint)) FROM t +-- !query 40 schema +struct<(IF(true, CAST(1 AS BIGINT), CAST(2 AS BIGINT))):bigint> +-- !query 40 output +1 + + +-- !query 41 +SELECT IF(true, cast(1 as bigint), cast(2 as float)) FROM t +-- !query 41 schema +struct<(IF(true, CAST(CAST(1 AS BIGINT) AS FLOAT), CAST(2 AS FLOAT))):float> +-- !query 41 output +1.0 + + +-- !query 42 +SELECT IF(true, cast(1 as bigint), cast(2 as double)) FROM t +-- !query 42 schema +struct<(IF(true, CAST(CAST(1 AS BIGINT) AS DOUBLE), CAST(2 AS DOUBLE))):double> +-- !query 42 output +1.0 + + +-- !query 43 +SELECT IF(true, cast(1 as bigint), cast(2 as decimal(10, 0))) FROM t +-- !query 43 schema +struct<(IF(true, CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)), CAST(CAST(2 AS DECIMAL(10,0)) AS 
DECIMAL(20,0)))):decimal(20,0)> +-- !query 43 output +1 + + +-- !query 44 +SELECT IF(true, cast(1 as bigint), cast(2 as string)) FROM t +-- !query 44 schema +struct<(IF(true, CAST(CAST(1 AS BIGINT) AS STRING), CAST(2 AS STRING))):string> +-- !query 44 output +1 + + +-- !query 45 +SELECT IF(true, cast(1 as bigint), cast('2' as binary)) FROM t +-- !query 45 schema +struct<> +-- !query 45 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST('2' AS BINARY)))' (bigint and binary).; line 1 pos 7 + + +-- !query 46 +SELECT IF(true, cast(1 as bigint), cast(2 as boolean)) FROM t +-- !query 46 schema +struct<> +-- !query 46 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST(2 AS BOOLEAN)))' (bigint and boolean).; line 1 pos 7 + + +-- !query 47 +SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 47 schema +struct<> +-- !query 47 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (bigint and timestamp).; line 1 pos 7 + + +-- !query 48 +SELECT IF(true, cast(1 as bigint), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' (bigint and date).; line 1 pos 7 + + +-- !query 49 +SELECT IF(true, cast(1 as float), cast(2 as tinyint)) FROM t +-- !query 49 schema +struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS TINYINT) AS FLOAT))):float> +-- !query 49 output +1.0 + + +-- !query 50 +SELECT IF(true, cast(1 as float), cast(2 as smallint)) FROM t +-- !query 50 schema +struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS SMALLINT) AS FLOAT))):float> +-- !query 50 output +1.0 + + +-- !query 51 +SELECT IF(true, cast(1 as float), cast(2 as int)) FROM t +-- !query 51 schema +struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS INT) AS FLOAT))):float> +-- !query 51 output +1.0 + + +-- !query 52 +SELECT IF(true, cast(1 as float), cast(2 as bigint)) FROM t +-- !query 52 schema +struct<(IF(true, CAST(1 AS FLOAT), CAST(CAST(2 AS BIGINT) AS FLOAT))):float> +-- !query 52 output +1.0 + + +-- !query 53 +SELECT IF(true, cast(1 as float), cast(2 as float)) FROM t +-- !query 53 schema +struct<(IF(true, CAST(1 AS FLOAT), CAST(2 AS FLOAT))):float> +-- !query 53 output +1.0 + + +-- !query 54 +SELECT IF(true, cast(1 as float), cast(2 as double)) FROM t +-- !query 54 schema +struct<(IF(true, CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(2 AS DOUBLE))):double> +-- !query 54 output +1.0 + + +-- !query 55 +SELECT IF(true, cast(1 as float), cast(2 as decimal(10, 0))) FROM t +-- !query 55 schema +struct<(IF(true, CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(2 AS DECIMAL(10,0)) AS DOUBLE))):double> +-- !query 55 output +1.0 + + +-- !query 56 +SELECT IF(true, cast(1 as float), cast(2 as string)) FROM t +-- !query 56 schema +struct<(IF(true, CAST(CAST(1 AS FLOAT) AS STRING), CAST(2 AS STRING))):string> +-- !query 56 output +1.0 + + +-- !query 
57 +SELECT IF(true, cast(1 as float), cast('2' as binary)) FROM t +-- !query 57 schema +struct<> +-- !query 57 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST('2' AS BINARY)))' (float and binary).; line 1 pos 7 + + +-- !query 58 +SELECT IF(true, cast(1 as float), cast(2 as boolean)) FROM t +-- !query 58 schema +struct<> +-- !query 58 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST(2 AS BOOLEAN)))' (float and boolean).; line 1 pos 7 + + +-- !query 59 +SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (float and timestamp).; line 1 pos 7 + + +-- !query 60 +SELECT IF(true, cast(1 as float), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' (float and date).; line 1 pos 7 + + +-- !query 61 +SELECT IF(true, cast(1 as double), cast(2 as tinyint)) FROM t +-- !query 61 schema +struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS TINYINT) AS DOUBLE))):double> +-- !query 61 output +1.0 + + +-- !query 62 +SELECT IF(true, cast(1 as double), cast(2 as smallint)) FROM t +-- !query 62 schema +struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS SMALLINT) AS DOUBLE))):double> +-- !query 62 output +1.0 + + +-- !query 63 +SELECT IF(true, cast(1 as double), cast(2 as int)) FROM t +-- !query 63 schema +struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS INT) AS DOUBLE))):double> +-- !query 63 output +1.0 + + +-- !query 64 +SELECT IF(true, cast(1 as double), cast(2 as bigint)) FROM t +-- !query 64 schema +struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS BIGINT) AS DOUBLE))):double> +-- !query 64 output +1.0 + + +-- !query 65 +SELECT IF(true, cast(1 as double), cast(2 as float)) FROM t +-- !query 65 schema +struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS FLOAT) AS DOUBLE))):double> +-- !query 65 output +1.0 + + +-- !query 66 +SELECT IF(true, cast(1 as double), cast(2 as double)) FROM t +-- !query 66 schema +struct<(IF(true, CAST(1 AS DOUBLE), CAST(2 AS DOUBLE))):double> +-- !query 66 output +1.0 + + +-- !query 67 +SELECT IF(true, cast(1 as double), cast(2 as decimal(10, 0))) FROM t +-- !query 67 schema +struct<(IF(true, CAST(1 AS DOUBLE), CAST(CAST(2 AS DECIMAL(10,0)) AS DOUBLE))):double> +-- !query 67 output +1.0 + + +-- !query 68 +SELECT IF(true, cast(1 as double), cast(2 as string)) FROM t +-- !query 68 schema +struct<(IF(true, CAST(CAST(1 AS DOUBLE) AS STRING), CAST(2 AS STRING))):string> +-- !query 68 output +1.0 + + +-- !query 69 +SELECT IF(true, cast(1 as double), cast('2' as binary)) FROM t +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, 
CAST(1 AS DOUBLE), CAST('2' AS BINARY)))' (double and binary).; line 1 pos 7 + + +-- !query 70 +SELECT IF(true, cast(1 as double), cast(2 as boolean)) FROM t +-- !query 70 schema +struct<> +-- !query 70 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DOUBLE), CAST(2 AS BOOLEAN)))' (double and boolean).; line 1 pos 7 + + +-- !query 71 +SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (double and timestamp).; line 1 pos 7 + + +-- !query 72 +SELECT IF(true, cast(1 as double), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 72 schema +struct<> +-- !query 72 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' (double and date).; line 1 pos 7 + + +-- !query 73 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as tinyint)) FROM t +-- !query 73 schema +struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(CAST(2 AS TINYINT) AS DECIMAL(10,0)))):decimal(10,0)> +-- !query 73 output +1 + + +-- !query 74 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as smallint)) FROM t +-- !query 74 schema +struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(CAST(2 AS SMALLINT) AS DECIMAL(10,0)))):decimal(10,0)> +-- !query 74 output +1 + + +-- !query 75 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as int)) FROM t +-- !query 75 schema +struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(CAST(2 AS INT) AS DECIMAL(10,0)))):decimal(10,0)> +-- !query 75 output +1 + + +-- !query 76 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as bigint)) FROM t +-- !query 76 schema +struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)), CAST(CAST(2 AS BIGINT) AS DECIMAL(20,0)))):decimal(20,0)> +-- !query 76 output +1 + + +-- !query 77 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as float)) FROM t +-- !query 77 schema +struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(CAST(2 AS FLOAT) AS DOUBLE))):double> +-- !query 77 output +1.0 + + +-- !query 78 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as double)) FROM t +-- !query 78 schema +struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(2 AS DOUBLE))):double> +-- !query 78 output +1.0 + + +-- !query 79 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as decimal(10, 0))) FROM t +-- !query 79 schema +struct<(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(2 AS DECIMAL(10,0)))):decimal(10,0)> +-- !query 79 output +1 + + +-- !query 80 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as string)) FROM t +-- !query 80 schema +struct<(IF(true, CAST(CAST(1 AS DECIMAL(10,0)) AS STRING), CAST(2 AS STRING))):string> +-- !query 80 output +1 + + +-- !query 81 +SELECT IF(true, cast(1 as decimal(10, 0)), cast('2' as binary)) FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2' AS 
BINARY)))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 82 +SELECT IF(true, cast(1 as decimal(10, 0)), cast(2 as boolean)) FROM t +-- !query 82 schema +struct<> +-- !query 82 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST(2 AS BOOLEAN)))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 83 +SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 84 +SELECT IF(true, cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 84 schema +struct<> +-- !query 84 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 85 +SELECT IF(true, cast(1 as string), cast(2 as tinyint)) FROM t +-- !query 85 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS TINYINT) AS STRING))):string> +-- !query 85 output +1 + + +-- !query 86 +SELECT IF(true, cast(1 as string), cast(2 as smallint)) FROM t +-- !query 86 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS SMALLINT) AS STRING))):string> +-- !query 86 output +1 + + +-- !query 87 +SELECT IF(true, cast(1 as string), cast(2 as int)) FROM t +-- !query 87 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS INT) AS STRING))):string> +-- !query 87 output +1 + + +-- !query 88 +SELECT IF(true, cast(1 as string), cast(2 as bigint)) FROM t +-- !query 88 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS BIGINT) AS STRING))):string> +-- !query 88 output +1 + + +-- !query 89 +SELECT IF(true, cast(1 as string), cast(2 as float)) FROM t +-- !query 89 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS FLOAT) AS STRING))):string> +-- !query 89 output +1 + + +-- !query 90 +SELECT IF(true, cast(1 as string), cast(2 as double)) FROM t +-- !query 90 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS DOUBLE) AS STRING))):string> +-- !query 90 output +1 + + +-- !query 91 +SELECT IF(true, cast(1 as string), cast(2 as decimal(10, 0))) FROM t +-- !query 91 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2 AS DECIMAL(10,0)) AS STRING))):string> +-- !query 91 output +1 + + +-- !query 92 +SELECT IF(true, cast(1 as string), cast(2 as string)) FROM t +-- !query 92 schema +struct<(IF(true, CAST(1 AS STRING), CAST(2 AS STRING))):string> +-- !query 92 output +1 + + +-- !query 93 +SELECT IF(true, cast(1 as string), cast('2' as binary)) FROM t +-- !query 93 schema +struct<> +-- !query 93 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS STRING), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS STRING), CAST('2' AS BINARY)))' (string and binary).; line 1 pos 7 + + +-- !query 94 +SELECT IF(true, cast(1 as string), cast(2 as boolean)) FROM t +-- !query 94 schema +struct<> +-- !query 94 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS STRING), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS STRING), CAST(2 AS BOOLEAN)))' (string and boolean).; line 1 pos 7 + + +-- !query 95 +SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 95 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING))):string> +-- !query 95 output +1 + + +-- !query 96 +SELECT IF(true, cast(1 as string), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 96 schema +struct<(IF(true, CAST(1 AS STRING), CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING))):string> +-- !query 96 output +1 + + +-- !query 97 +SELECT IF(true, cast('1' as binary), cast(2 as tinyint)) FROM t +-- !query 97 schema +struct<> +-- !query 97 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS TINYINT)))' (binary and tinyint).; line 1 pos 7 + + +-- !query 98 +SELECT IF(true, cast('1' as binary), cast(2 as smallint)) FROM t +-- !query 98 schema +struct<> +-- !query 98 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS SMALLINT)))' (binary and smallint).; line 1 pos 7 + + +-- !query 99 +SELECT IF(true, cast('1' as binary), cast(2 as int)) FROM t +-- !query 99 schema +struct<> +-- !query 99 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS INT)))' (binary and int).; line 1 pos 7 + + +-- !query 100 +SELECT IF(true, cast('1' as binary), cast(2 as bigint)) FROM t +-- !query 100 schema +struct<> +-- !query 100 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS BIGINT)))' (binary and bigint).; line 1 pos 7 + + +-- !query 101 +SELECT IF(true, cast('1' as binary), cast(2 as float)) FROM t +-- !query 101 schema +struct<> +-- !query 101 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS FLOAT)))' (binary and float).; line 1 pos 7 + + +-- !query 102 +SELECT IF(true, cast('1' as binary), cast(2 as double)) FROM t +-- !query 102 schema +struct<> +-- !query 102 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS DOUBLE)))' (binary and double).; line 1 pos 7 + + +-- !query 103 +SELECT IF(true, cast('1' as binary), cast(2 as decimal(10, 0))) FROM t +-- !query 103 schema +struct<> +-- !query 103 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS DECIMAL(10,0))))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 104 +SELECT IF(true, cast('1' as binary), cast(2 as string)) FROM t +-- !query 104 schema +struct<> +-- !query 104 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS STRING)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS STRING)))' (binary and string).; line 1 pos 7 + + +-- !query 105 +SELECT IF(true, cast('1' as binary), cast('2' as binary)) FROM t +-- !query 105 schema +struct<(IF(true, CAST(1 AS BINARY), CAST(2 AS BINARY))):binary> +-- !query 105 output +1 + + +-- !query 106 +SELECT IF(true, cast('1' as binary), cast(2 as boolean)) FROM t +-- !query 106 schema +struct<> +-- !query 106 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST(2 AS BOOLEAN)))' (binary and boolean).; line 1 pos 7 + + +-- !query 107 +SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 107 schema +struct<> +-- !query 107 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (binary and timestamp).; line 1 pos 7 + + +-- !query 108 +SELECT IF(true, cast('1' as binary), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 108 schema +struct<> +-- !query 108 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' (binary and date).; line 1 pos 7 + + +-- !query 109 +SELECT IF(true, cast(1 as boolean), cast(2 as tinyint)) FROM t +-- !query 109 schema +struct<> +-- !query 109 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS TINYINT)))' (boolean and tinyint).; line 1 pos 7 + + +-- !query 110 +SELECT IF(true, cast(1 as boolean), cast(2 as smallint)) FROM t +-- !query 110 schema +struct<> +-- !query 110 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS SMALLINT)))' (boolean and smallint).; line 1 pos 7 + + +-- !query 111 +SELECT IF(true, cast(1 as boolean), cast(2 as int)) FROM t +-- !query 111 schema +struct<> +-- !query 111 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS INT)))' (boolean and int).; line 1 pos 7 + + +-- !query 112 +SELECT IF(true, cast(1 as boolean), cast(2 as bigint)) FROM t +-- !query 112 schema +struct<> +-- !query 112 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS BIGINT)))' (boolean and bigint).; line 1 pos 7 + + +-- !query 113 +SELECT IF(true, cast(1 as boolean), cast(2 as float)) FROM t +-- !query 113 schema +struct<> +-- !query 113 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS 
FLOAT)))' (boolean and float).; line 1 pos 7 + + +-- !query 114 +SELECT IF(true, cast(1 as boolean), cast(2 as double)) FROM t +-- !query 114 schema +struct<> +-- !query 114 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DOUBLE)))' (boolean and double).; line 1 pos 7 + + +-- !query 115 +SELECT IF(true, cast(1 as boolean), cast(2 as decimal(10, 0))) FROM t +-- !query 115 schema +struct<> +-- !query 115 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS DECIMAL(10,0))))' (boolean and decimal(10,0)).; line 1 pos 7 + + +-- !query 116 +SELECT IF(true, cast(1 as boolean), cast(2 as string)) FROM t +-- !query 116 schema +struct<> +-- !query 116 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS STRING)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS STRING)))' (boolean and string).; line 1 pos 7 + + +-- !query 117 +SELECT IF(true, cast(1 as boolean), cast('2' as binary)) FROM t +-- !query 117 schema +struct<> +-- !query 117 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST('2' AS BINARY)))' (boolean and binary).; line 1 pos 7 + + +-- !query 118 +SELECT IF(true, cast(1 as boolean), cast(2 as boolean)) FROM t +-- !query 118 schema +struct<(IF(true, CAST(1 AS BOOLEAN), CAST(2 AS BOOLEAN))):boolean> +-- !query 118 output +true + + +-- !query 119 +SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 119 schema +struct<> +-- !query 119 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' (boolean and timestamp).; line 1 pos 7 + + +-- !query 120 +SELECT IF(true, cast(1 as boolean), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 120 schema +struct<> +-- !query 120 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: differing types in '(IF(true, CAST(1 AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' (boolean and date).; line 1 pos 7 + + +-- !query 121 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as tinyint)) FROM t +-- !query 121 schema +struct<> +-- !query 121 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS TINYINT)))' (timestamp and tinyint).; line 1 pos 7 + + +-- !query 122 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as smallint)) FROM t +-- !query 122 schema +struct<> +-- !query 122 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS SMALLINT)))' 
(timestamp and smallint).; line 1 pos 7 + + +-- !query 123 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as int)) FROM t +-- !query 123 schema +struct<> +-- !query 123 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS INT)))' (timestamp and int).; line 1 pos 7 + + +-- !query 124 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as bigint)) FROM t +-- !query 124 schema +struct<> +-- !query 124 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BIGINT)))' (timestamp and bigint).; line 1 pos 7 + + +-- !query 125 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as float)) FROM t +-- !query 125 schema +struct<> +-- !query 125 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS FLOAT)))' (timestamp and float).; line 1 pos 7 + + +-- !query 126 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as double)) FROM t +-- !query 126 schema +struct<> +-- !query 126 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DOUBLE)))' (timestamp and double).; line 1 pos 7 + + +-- !query 127 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as decimal(10, 0))) FROM t +-- !query 127 schema +struct<> +-- !query 127 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS DECIMAL(10,0))))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 128 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as string)) FROM t +-- !query 128 schema +struct<(IF(true, CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING), CAST(2 AS STRING))):string> +-- !query 128 output +2017-12-12 09:30:00 + + +-- !query 129 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2' as binary)) FROM t +-- !query 129 schema +struct<> +-- !query 129 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('2' AS BINARY)))' (timestamp and binary).; line 1 pos 7 + + +-- !query 130 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast(2 as boolean)) FROM t +-- !query 130 schema +struct<> +-- !query 130 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(2 AS BOOLEAN)))' (timestamp and boolean).; line 1 pos 7 + + +-- !query 131 +SELECT IF(true, cast('2017-12-12 
09:30:00.0' as timestamp), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 131 schema +struct<(IF(true, CAST(2017-12-12 09:30:00.0 AS TIMESTAMP), CAST(2017-12-11 09:30:00.0 AS TIMESTAMP))):timestamp> +-- !query 131 output +2017-12-12 09:30:00 + + +-- !query 132 +SELECT IF(true, cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 132 schema +struct<(IF(true, CAST(2017-12-12 09:30:00.0 AS TIMESTAMP), CAST(CAST(2017-12-11 09:30:00 AS DATE) AS TIMESTAMP))):timestamp> +-- !query 132 output +2017-12-12 09:30:00 + + +-- !query 133 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as tinyint)) FROM t +-- !query 133 schema +struct<> +-- !query 133 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS TINYINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS TINYINT)))' (date and tinyint).; line 1 pos 7 + + +-- !query 134 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as smallint)) FROM t +-- !query 134 schema +struct<> +-- !query 134 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS SMALLINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS SMALLINT)))' (date and smallint).; line 1 pos 7 + + +-- !query 135 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as int)) FROM t +-- !query 135 schema +struct<> +-- !query 135 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS INT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS INT)))' (date and int).; line 1 pos 7 + + +-- !query 136 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as bigint)) FROM t +-- !query 136 schema +struct<> +-- !query 136 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BIGINT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BIGINT)))' (date and bigint).; line 1 pos 7 + + +-- !query 137 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as float)) FROM t +-- !query 137 schema +struct<> +-- !query 137 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS FLOAT)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS FLOAT)))' (date and float).; line 1 pos 7 + + +-- !query 138 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as double)) FROM t +-- !query 138 schema +struct<> +-- !query 138 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DOUBLE)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DOUBLE)))' (date and double).; line 1 pos 7 + + +-- !query 139 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as decimal(10, 0))) FROM t +-- !query 139 schema +struct<> +-- !query 139 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS DECIMAL(10,0))))' (date and 
decimal(10,0)).; line 1 pos 7 + + +-- !query 140 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as string)) FROM t +-- !query 140 schema +struct<(IF(true, CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING), CAST(2 AS STRING))):string> +-- !query 140 output +2017-12-12 + + +-- !query 141 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2' as binary)) FROM t +-- !query 141 schema +struct<> +-- !query 141 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST('2' AS BINARY)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST('2' AS BINARY)))' (date and binary).; line 1 pos 7 + + +-- !query 142 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast(2 as boolean)) FROM t +-- !query 142 schema +struct<> +-- !query 142 output +org.apache.spark.sql.AnalysisException +cannot resolve '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BOOLEAN)))' due to data type mismatch: differing types in '(IF(true, CAST('2017-12-12 09:30:00' AS DATE), CAST(2 AS BOOLEAN)))' (date and boolean).; line 1 pos 7 + + +-- !query 143 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 143 schema +struct<(IF(true, CAST(CAST(2017-12-12 09:30:00 AS DATE) AS TIMESTAMP), CAST(2017-12-11 09:30:00.0 AS TIMESTAMP))):timestamp> +-- !query 143 output +2017-12-12 00:00:00 + + +-- !query 144 +SELECT IF(true, cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 144 schema +struct<(IF(true, CAST(2017-12-12 09:30:00 AS DATE), CAST(2017-12-11 09:30:00 AS DATE))):date> +-- !query 144 output +2017-12-12 From 0c8fca4608643ed9e1eb3ae8620e6f4f6a017a87 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Sat, 16 Dec 2017 10:57:35 +0900 Subject: [PATCH 440/712] [SPARK-22811][PYSPARK][ML] Fix pyspark.ml.tests failure when Hive is not available. ## What changes were proposed in this pull request? pyspark.ml.tests is missing a py4j import. I've added the import and fixed the test that uses it. This test was only failing when testing without Hive. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #19997 from MrBago/fix-ImageReaderTest2. --- python/pyspark/ml/tests.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 3a0b816c367ec..be1521154f042 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -44,6 +44,7 @@ import numpy as np from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros import inspect +import py4j from pyspark import keyword_only, SparkContext from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer @@ -1859,8 +1860,9 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): - PySparkTestCase.setUpClass() + super(ImageReaderTest2, cls).setUpClass() # Note that here we enable Hive's support. 
+ cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: @@ -1873,8 +1875,10 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - PySparkTestCase.tearDownClass() - cls.spark.sparkSession.stop() + super(ImageReaderTest2, cls).tearDownClass() + if cls.spark is not None: + cls.spark.sparkSession.stop() + cls.spark = None def test_read_images_multiple_times(self): # This test case is to check if `ImageSchema.readImages` tries to From c2aeddf9eae2f8f72c244a4b16af264362d6cf5d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 17 Dec 2017 14:40:41 +0900 Subject: [PATCH 441/712] [SPARK-22817][R] Use fixed testthat version for SparkR tests in AppVeyor ## What changes were proposed in this pull request? `testthat` 2.0.0 has been released and AppVeyor has started to use it instead of 1.0.2, after which the R tests in AppVeyor began to fail. See - https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark/build/1967-master ``` Error in get(name, envir = asNamespace(pkg), inherits = FALSE) : object 'run_tests' not found Calls: ::: -> get ``` This seems to be because we rely on the internal `testthat:::run_tests` here: https://github.com/r-lib/testthat/blob/v1.0.2/R/test-package.R#L62-L75 https://github.com/apache/spark/blob/dc4c351837879dab26ad8fb471dc51c06832a9e4/R/pkg/tests/run-all.R#L49-L52 However, it seems to have been removed in 2.0.0. I tried a few other exposed APIs such as `test_dir`, but could not find a good compatible fix. It seems better to pin the `testthat` version first so that the build passes. ## How was this patch tested? Manually tested and AppVeyor tests. Author: hyukjinkwon Closes #20003 from HyukjinKwon/SPARK-22817. --- appveyor.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 48740920cd09b..aee94c59612d2 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -42,7 +42,9 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'devtools', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + # Here, we use the fixed version of testthat. For more details, please see SPARK-22817. + - cmd: R -e "devtools::install_version('testthat', version = '1.0.2', repos='http://cran.us.r-project.org')" - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: From 77988a9d0d553f26034f7206e5e6314acab2dec5 Mon Sep 17 00:00:00 2001 From: Mahmut CAVDAR Date: Sun, 17 Dec 2017 10:52:01 -0600 Subject: [PATCH 442/712] [MINOR][DOC] Fix the link of 'Getting Started' ## What changes were proposed in this pull request? A simple fix to the link. ## How was this patch tested? Tested manually Author: Mahmut CAVDAR Closes #19996 from mcavdar/master.
--- docs/mllib-decision-tree.md | 2 +- docs/running-on-mesos.md | 4 ++-- docs/spark-standalone.md | 2 +- docs/sql-programming-guide.md | 1 + docs/tuning.md | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 0e753b8dd04a2..ec13b81f85557 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -91,7 +91,7 @@ For a categorical feature with `$M$` possible values (categories), one could com `$2^{M-1}-1$` split candidates. For binary (0/1) classification and regression, we can reduce the number of split candidates to `$M-1$` by ordering the categorical feature values by the average label. (See Section 9.2.4 in -[Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for +[Elements of Statistical Machine Learning](https://web.stanford.edu/~hastie/ElemStatLearn/) for details.) For example, for a binary classification problem with one categorical feature with three categories A, B and C whose corresponding proportions of label 1 are 0.2, 0.6 and 0.4, the categorical features are ordered as A, C, B. The two split candidates are A \| C, B diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 19ec7c1e0aeee..382cbfd5301b0 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -47,7 +47,7 @@ To install Apache Mesos from source, follow these steps: 1. Download a Mesos release from a [mirror](http://www.apache.org/dyn/closer.lua/mesos/{{site.MESOS_VERSION}}/) -2. Follow the Mesos [Getting Started](http://mesos.apache.org/gettingstarted) page for compiling and +2. Follow the Mesos [Getting Started](http://mesos.apache.org/getting-started) page for compiling and installing Mesos **Note:** If you want to run Mesos without installing it into the default paths on your system @@ -159,7 +159,7 @@ By setting the Mesos proxy config property (requires mesos version >= 1.4), `--c If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. -For more information about these configurations please refer to the configurations [doc](configurations.html#deploy). +For more information about these configurations please refer to the configurations [doc](configuration.html#deploy). You can also specify any additional jars required by the `MesosClusterDispatcher` in the classpath by setting the environment variable SPARK_DAEMON_CLASSPATH in spark-env. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f51c5cc38f4de..8fa643abf1373 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -364,7 +364,7 @@ By default, standalone scheduling clusters are resilient to Worker failures (ins Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. 
If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. -Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/trunk/zookeeperStarted.html). +Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/current/zookeeperStarted.html). **Configuration** diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b76be9132dd03..f02f46236e2b0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -501,6 +501,7 @@ To load a CSV file you can use: + ### Run SQL on files directly Instead of using read API to load a file into DataFrame and query it, you can also query that diff --git a/docs/tuning.md b/docs/tuning.md index 7d5f97a02fe6e..fc27713f28d46 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -219,7 +219,7 @@ temporary objects created during task execution. Some steps which may be useful * Try the G1GC garbage collector with `-XX:+UseG1GC`. It can improve performance in some situations where garbage collection is a bottleneck. Note that with large executor heap sizes, it may be important to - increase the [G1 region size](https://blogs.oracle.com/g1gc/entry/g1_gc_tuning_a_case) + increase the [G1 region size](http://www.oracle.com/technetwork/articles/java/g1gc-1984535.html) with `-XX:G1HeapRegionSize` * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using From 7f6d10a7376594cf7e5225a05fa9e58b2011ad3d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 17 Dec 2017 09:15:10 -0800 Subject: [PATCH 443/712] [SPARK-22816][TEST] Basic tests for PromoteStrings and InConversion ## What changes were proposed in this pull request? Test coverage for `PromoteStrings` and `InConversion`; this is a sub-task of [SPARK-22722](https://issues.apache.org/jira/browse/SPARK-22722). ## How was this patch tested? N/A Author: Yuming Wang Closes #20001 from wangyum/SPARK-22816. --- .../typeCoercion/native/inConversion.sql | 330 +++ .../typeCoercion/native/promoteStrings.sql | 364 +++ .../typeCoercion/native/inConversion.sql.out | 2454 ++++++++++++++++ .../native/promoteStrings.sql.out | 2578 +++++++++++++++++ 4 files changed, 5726 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/inConversion.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/promoteStrings.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/promoteStrings.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/inConversion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/inConversion.sql new file mode 100644 index 0000000000000..39dbe7268fba0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/inConversion.sql @@ -0,0 +1,330 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT cast(1 as tinyint) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as int)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as float)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as double)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as string)) FROM t; +SELECT cast(1 as tinyint) in (cast('1' as binary)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as smallint) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as int)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as float)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as double)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as smallint) in (cast(1 as string)) FROM t; +SELECT cast(1 as smallint) in (cast('1' as binary)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as int) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as int) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as int) in (cast(1 as int)) FROM t; +SELECT cast(1 as int) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as int) in (cast(1 as float)) FROM t; +SELECT cast(1 as int) in (cast(1 as double)) FROM t; +SELECT cast(1 as int) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as int) in (cast(1 as string)) FROM t; +SELECT cast(1 as int) in (cast('1' as binary)) FROM t; +SELECT cast(1 as int) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as int) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as int) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as bigint) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as int)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as float)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as double)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as bigint) in (cast(1 as string)) FROM t; +SELECT cast(1 as bigint) in (cast('1' as binary)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as bigint) in 
(cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as float) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as float) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as float) in (cast(1 as int)) FROM t; +SELECT cast(1 as float) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as float) in (cast(1 as float)) FROM t; +SELECT cast(1 as float) in (cast(1 as double)) FROM t; +SELECT cast(1 as float) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as float) in (cast(1 as string)) FROM t; +SELECT cast(1 as float) in (cast('1' as binary)) FROM t; +SELECT cast(1 as float) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as float) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as float) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as double) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as double) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as double) in (cast(1 as int)) FROM t; +SELECT cast(1 as double) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as double) in (cast(1 as float)) FROM t; +SELECT cast(1 as double) in (cast(1 as double)) FROM t; +SELECT cast(1 as double) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as double) in (cast(1 as string)) FROM t; +SELECT cast(1 as double) in (cast('1' as binary)) FROM t; +SELECT cast(1 as double) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as double) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as double) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as decimal(10, 0)) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as int)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as float)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as double)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as string)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast('1' as binary)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as string) in (cast(1 as tinyint)) FROM t; +SELECT cast(1 as string) in (cast(1 as smallint)) FROM t; +SELECT cast(1 as string) in (cast(1 as int)) FROM t; +SELECT cast(1 as string) in (cast(1 as bigint)) FROM t; +SELECT cast(1 as string) in (cast(1 as float)) FROM t; +SELECT cast(1 as string) in (cast(1 as double)) FROM t; +SELECT cast(1 as string) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as string) in (cast(1 as string)) FROM t; +SELECT cast(1 as string) in (cast('1' as binary)) FROM t; +SELECT cast(1 as string) in (cast(1 as boolean)) FROM t; +SELECT cast(1 as string) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as string) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast('1' as binary) in (cast(1 as tinyint)) FROM t; +SELECT cast('1' as binary) in (cast(1 as smallint)) FROM t; +SELECT cast('1' as binary) in (cast(1 as int)) FROM t; +SELECT cast('1' as binary) in (cast(1 as bigint)) FROM t; +SELECT cast('1' as binary) in (cast(1 as float)) FROM t; +SELECT cast('1' as binary) in (cast(1 as double)) FROM t; +SELECT cast('1' as binary) in (cast(1 as decimal(10, 0))) FROM t; +SELECT cast('1' as binary) in (cast(1 as 
string)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary)) FROM t; +SELECT cast('1' as binary) in (cast(1 as boolean)) FROM t; +SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT true in (cast(1 as tinyint)) FROM t; +SELECT true in (cast(1 as smallint)) FROM t; +SELECT true in (cast(1 as int)) FROM t; +SELECT true in (cast(1 as bigint)) FROM t; +SELECT true in (cast(1 as float)) FROM t; +SELECT true in (cast(1 as double)) FROM t; +SELECT true in (cast(1 as decimal(10, 0))) FROM t; +SELECT true in (cast(1 as string)) FROM t; +SELECT true in (cast('1' as binary)) FROM t; +SELECT true in (cast(1 as boolean)) FROM t; +SELECT true in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT true in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as tinyint)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as smallint)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as int)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as bigint)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as float)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as double)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as decimal(10, 0))) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as string)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2' as binary)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as boolean)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as tinyint)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as smallint)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as int)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as bigint)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as float)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as double)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as decimal(10, 0))) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as string)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2' as binary)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as boolean)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as tinyint)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as smallint)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as int)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as bigint)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as float)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as double)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as string)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as 
tinyint), cast('1' as binary)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as boolean)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as tinyint)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as smallint)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as int)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as bigint)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as float)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as double)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as string)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast('1' as binary)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as boolean)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as int) in (cast(1 as int), cast(1 as tinyint)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as smallint)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as int)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as bigint)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as float)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as double)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as string)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast('1' as binary)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast(1 as boolean)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as tinyint)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as smallint)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as int)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as bigint)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as float)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as double)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as string)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast('1' as binary)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as boolean)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as float) in (cast(1 as float), cast(1 as tinyint)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as smallint)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as int)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as bigint)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as 
float)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as double)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as string)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast('1' as binary)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast(1 as boolean)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as double) in (cast(1 as double), cast(1 as tinyint)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as smallint)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as int)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as bigint)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as float)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as double)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as string)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast('1' as binary)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast(1 as boolean)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as tinyint)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as smallint)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as int)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as bigint)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as float)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as double)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as string)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('1' as binary)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as boolean)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as string) in (cast(1 as string), cast(1 as tinyint)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as smallint)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as int)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as bigint)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as float)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as double)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as decimal(10, 0))) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as string)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast('1' as binary)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast(1 as boolean)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast('2017-12-11 
09:30:00.0' as timestamp)) FROM t; +SELECT cast(1 as string) in (cast(1 as string), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as tinyint)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as smallint)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as int)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as bigint)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as float)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as double)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as decimal(10, 0))) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as string)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast('1' as binary)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as boolean)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as tinyint)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as smallint)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as int)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as bigint)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as float)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as double)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as decimal(10, 0))) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as string)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast('1' as binary)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as boolean)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as tinyint)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as smallint)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as int)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as bigint)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as float)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as double)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as decimal(10, 0))) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as string)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('1' as binary)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as boolean)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), 
cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as tinyint)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as smallint)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as int)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as bigint)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as float)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as double)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as decimal(10, 0))) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as string)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('1' as binary)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as boolean)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00' as date)) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/promoteStrings.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/promoteStrings.sql new file mode 100644 index 0000000000000..a5603a184578d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/promoteStrings.sql @@ -0,0 +1,364 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +-- Binary arithmetic +SELECT '1' + cast(1 as tinyint) FROM t; +SELECT '1' + cast(1 as smallint) FROM t; +SELECT '1' + cast(1 as int) FROM t; +SELECT '1' + cast(1 as bigint) FROM t; +SELECT '1' + cast(1 as float) FROM t; +SELECT '1' + cast(1 as double) FROM t; +SELECT '1' + cast(1 as decimal(10, 0)) FROM t; +SELECT '1' + '1' FROM t; +SELECT '1' + cast('1' as binary) FROM t; +SELECT '1' + cast(1 as boolean) FROM t; +SELECT '1' + cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' + cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' - cast(1 as tinyint) FROM t; +SELECT '1' - cast(1 as smallint) FROM t; +SELECT '1' - cast(1 as int) FROM t; +SELECT '1' - cast(1 as bigint) FROM t; +SELECT '1' - cast(1 as float) FROM t; +SELECT '1' - cast(1 as double) FROM t; +SELECT '1' - cast(1 as decimal(10, 0)) FROM t; +SELECT '1' - '1' FROM t; +SELECT '1' - cast('1' as binary) FROM t; +SELECT '1' - cast(1 as boolean) FROM t; +SELECT '1' - cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' - cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' * cast(1 as tinyint) FROM t; +SELECT '1' * cast(1 as smallint) FROM t; +SELECT '1' * cast(1 as int) FROM t; +SELECT '1' * cast(1 as bigint) FROM t; +SELECT '1' * cast(1 as float) FROM t; +SELECT '1' * cast(1 as double) FROM t; +SELECT '1' * cast(1 as decimal(10, 0)) FROM t; +SELECT '1' * '1' FROM t; +SELECT '1' * cast('1' as binary) FROM t; +SELECT '1' * cast(1 as boolean) FROM t; +SELECT '1' * cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' * cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' / cast(1 as tinyint) FROM t; +SELECT '1' / cast(1 as smallint) FROM t; +SELECT '1' / cast(1 as int) FROM t; +SELECT '1' / cast(1 as bigint) FROM t; +SELECT '1' / cast(1 as float) FROM t; +SELECT '1' / cast(1 as double) FROM t; +SELECT '1' / cast(1 as decimal(10, 0)) FROM t; +SELECT '1' / '1' FROM t; +SELECT '1' / cast('1' as binary) FROM t; +SELECT '1' / cast(1 as boolean) FROM t; +SELECT '1' / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' % cast(1 as tinyint) FROM t; +SELECT '1' % cast(1 as smallint) FROM t; +SELECT '1' % cast(1 as int) FROM t; +SELECT '1' % cast(1 as bigint) FROM t; +SELECT '1' % cast(1 as float) FROM t; +SELECT '1' % cast(1 as double) FROM t; +SELECT '1' % cast(1 as decimal(10, 0)) FROM t; +SELECT '1' % '1' FROM t; +SELECT '1' % cast('1' as binary) FROM t; +SELECT '1' % cast(1 as boolean) FROM t; +SELECT '1' % cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' % cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT pmod('1', cast(1 as tinyint)) FROM t; +SELECT pmod('1', cast(1 as smallint)) FROM t; +SELECT pmod('1', cast(1 as int)) FROM t; +SELECT pmod('1', cast(1 as bigint)) FROM t; +SELECT pmod('1', cast(1 as float)) FROM t; +SELECT pmod('1', cast(1 as double)) FROM t; +SELECT pmod('1', cast(1 as decimal(10, 0))) FROM t; +SELECT pmod('1', '1') FROM t; +SELECT pmod('1', cast('1' as binary)) FROM t; +SELECT pmod('1', cast(1 as boolean)) FROM t; +SELECT pmod('1', cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT pmod('1', cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as tinyint) + '1' FROM t; +SELECT cast(1 as smallint) + '1' FROM t; +SELECT cast(1 as int) + '1' FROM t; +SELECT cast(1 as bigint) + '1' FROM t; +SELECT cast(1 as float) + '1' FROM t; +SELECT cast(1 as double) + '1' FROM t; +SELECT cast(1 as decimal(10, 0)) + '1' FROM t; +SELECT 
cast('1' as binary) + '1' FROM t; +SELECT cast(1 as boolean) + '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) + '1' FROM t; + +SELECT cast(1 as tinyint) - '1' FROM t; +SELECT cast(1 as smallint) - '1' FROM t; +SELECT cast(1 as int) - '1' FROM t; +SELECT cast(1 as bigint) - '1' FROM t; +SELECT cast(1 as float) - '1' FROM t; +SELECT cast(1 as double) - '1' FROM t; +SELECT cast(1 as decimal(10, 0)) - '1' FROM t; +SELECT cast('1' as binary) - '1' FROM t; +SELECT cast(1 as boolean) - '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) - '1' FROM t; + +SELECT cast(1 as tinyint) * '1' FROM t; +SELECT cast(1 as smallint) * '1' FROM t; +SELECT cast(1 as int) * '1' FROM t; +SELECT cast(1 as bigint) * '1' FROM t; +SELECT cast(1 as float) * '1' FROM t; +SELECT cast(1 as double) * '1' FROM t; +SELECT cast(1 as decimal(10, 0)) * '1' FROM t; +SELECT cast('1' as binary) * '1' FROM t; +SELECT cast(1 as boolean) * '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) * '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) * '1' FROM t; + +SELECT cast(1 as tinyint) / '1' FROM t; +SELECT cast(1 as smallint) / '1' FROM t; +SELECT cast(1 as int) / '1' FROM t; +SELECT cast(1 as bigint) / '1' FROM t; +SELECT cast(1 as float) / '1' FROM t; +SELECT cast(1 as double) / '1' FROM t; +SELECT cast(1 as decimal(10, 0)) / '1' FROM t; +SELECT cast('1' as binary) / '1' FROM t; +SELECT cast(1 as boolean) / '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / '1' FROM t; + +SELECT cast(1 as tinyint) % '1' FROM t; +SELECT cast(1 as smallint) % '1' FROM t; +SELECT cast(1 as int) % '1' FROM t; +SELECT cast(1 as bigint) % '1' FROM t; +SELECT cast(1 as float) % '1' FROM t; +SELECT cast(1 as double) % '1' FROM t; +SELECT cast(1 as decimal(10, 0)) % '1' FROM t; +SELECT cast('1' as binary) % '1' FROM t; +SELECT cast(1 as boolean) % '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) % '1' FROM t; + +SELECT pmod(cast(1 as tinyint), '1') FROM t; +SELECT pmod(cast(1 as smallint), '1') FROM t; +SELECT pmod(cast(1 as int), '1') FROM t; +SELECT pmod(cast(1 as bigint), '1') FROM t; +SELECT pmod(cast(1 as float), '1') FROM t; +SELECT pmod(cast(1 as double), '1') FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), '1') FROM t; +SELECT pmod(cast('1' as binary), '1') FROM t; +SELECT pmod(cast(1 as boolean), '1') FROM t; +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), '1') FROM t; +SELECT pmod(cast('2017-12-11 09:30:00' as date), '1') FROM t; + +-- Equality +SELECT '1' = cast(1 as tinyint) FROM t; +SELECT '1' = cast(1 as smallint) FROM t; +SELECT '1' = cast(1 as int) FROM t; +SELECT '1' = cast(1 as bigint) FROM t; +SELECT '1' = cast(1 as float) FROM t; +SELECT '1' = cast(1 as double) FROM t; +SELECT '1' = cast(1 as decimal(10, 0)) FROM t; +SELECT '1' = '1' FROM t; +SELECT '1' = cast('1' as binary) FROM t; +SELECT '1' = cast(1 as boolean) FROM t; +SELECT '1' = cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' = cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) = '1' FROM t; +SELECT cast(1 as smallint) = '1' FROM t; +SELECT cast(1 as int) = '1' FROM t; +SELECT cast(1 as bigint) = '1' FROM t; +SELECT cast(1 as float) = '1' FROM t; +SELECT cast(1 as double) = '1' FROM t; +SELECT cast(1 as decimal(10, 0)) = '1' FROM t; 
+SELECT cast('1' as binary) = '1' FROM t; +SELECT cast(1 as boolean) = '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) = '1' FROM t; + +SELECT '1' <=> cast(1 as tinyint) FROM t; +SELECT '1' <=> cast(1 as smallint) FROM t; +SELECT '1' <=> cast(1 as int) FROM t; +SELECT '1' <=> cast(1 as bigint) FROM t; +SELECT '1' <=> cast(1 as float) FROM t; +SELECT '1' <=> cast(1 as double) FROM t; +SELECT '1' <=> cast(1 as decimal(10, 0)) FROM t; +SELECT '1' <=> '1' FROM t; +SELECT '1' <=> cast('1' as binary) FROM t; +SELECT '1' <=> cast(1 as boolean) FROM t; +SELECT '1' <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' <=> cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) <=> '1' FROM t; +SELECT cast(1 as smallint) <=> '1' FROM t; +SELECT cast(1 as int) <=> '1' FROM t; +SELECT cast(1 as bigint) <=> '1' FROM t; +SELECT cast(1 as float) <=> '1' FROM t; +SELECT cast(1 as double) <=> '1' FROM t; +SELECT cast(1 as decimal(10, 0)) <=> '1' FROM t; +SELECT cast('1' as binary) <=> '1' FROM t; +SELECT cast(1 as boolean) <=> '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <=> '1' FROM t; + +-- Binary comparison +SELECT '1' < cast(1 as tinyint) FROM t; +SELECT '1' < cast(1 as smallint) FROM t; +SELECT '1' < cast(1 as int) FROM t; +SELECT '1' < cast(1 as bigint) FROM t; +SELECT '1' < cast(1 as float) FROM t; +SELECT '1' < cast(1 as double) FROM t; +SELECT '1' < cast(1 as decimal(10, 0)) FROM t; +SELECT '1' < '1' FROM t; +SELECT '1' < cast('1' as binary) FROM t; +SELECT '1' < cast(1 as boolean) FROM t; +SELECT '1' < cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' < cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' <= cast(1 as tinyint) FROM t; +SELECT '1' <= cast(1 as smallint) FROM t; +SELECT '1' <= cast(1 as int) FROM t; +SELECT '1' <= cast(1 as bigint) FROM t; +SELECT '1' <= cast(1 as float) FROM t; +SELECT '1' <= cast(1 as double) FROM t; +SELECT '1' <= cast(1 as decimal(10, 0)) FROM t; +SELECT '1' <= '1' FROM t; +SELECT '1' <= cast('1' as binary) FROM t; +SELECT '1' <= cast(1 as boolean) FROM t; +SELECT '1' <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' <= cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' > cast(1 as tinyint) FROM t; +SELECT '1' > cast(1 as smallint) FROM t; +SELECT '1' > cast(1 as int) FROM t; +SELECT '1' > cast(1 as bigint) FROM t; +SELECT '1' > cast(1 as float) FROM t; +SELECT '1' > cast(1 as double) FROM t; +SELECT '1' > cast(1 as decimal(10, 0)) FROM t; +SELECT '1' > '1' FROM t; +SELECT '1' > cast('1' as binary) FROM t; +SELECT '1' > cast(1 as boolean) FROM t; +SELECT '1' > cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' > cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' >= cast(1 as tinyint) FROM t; +SELECT '1' >= cast(1 as smallint) FROM t; +SELECT '1' >= cast(1 as int) FROM t; +SELECT '1' >= cast(1 as bigint) FROM t; +SELECT '1' >= cast(1 as float) FROM t; +SELECT '1' >= cast(1 as double) FROM t; +SELECT '1' >= cast(1 as decimal(10, 0)) FROM t; +SELECT '1' >= '1' FROM t; +SELECT '1' >= cast('1' as binary) FROM t; +SELECT '1' >= cast(1 as boolean) FROM t; +SELECT '1' >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' >= cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT '1' <> cast(1 as tinyint) FROM t; +SELECT '1' <> cast(1 as smallint) FROM t; +SELECT '1' <> cast(1 as int) FROM t; +SELECT '1' <> cast(1 as 
bigint) FROM t; +SELECT '1' <> cast(1 as float) FROM t; +SELECT '1' <> cast(1 as double) FROM t; +SELECT '1' <> cast(1 as decimal(10, 0)) FROM t; +SELECT '1' <> '1' FROM t; +SELECT '1' <> cast('1' as binary) FROM t; +SELECT '1' <> cast(1 as boolean) FROM t; +SELECT '1' <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT '1' <> cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) < '1' FROM t; +SELECT cast(1 as smallint) < '1' FROM t; +SELECT cast(1 as int) < '1' FROM t; +SELECT cast(1 as bigint) < '1' FROM t; +SELECT cast(1 as float) < '1' FROM t; +SELECT cast(1 as double) < '1' FROM t; +SELECT cast(1 as decimal(10, 0)) < '1' FROM t; +SELECT '1' < '1' FROM t; +SELECT cast('1' as binary) < '1' FROM t; +SELECT cast(1 as boolean) < '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) < '1' FROM t; + +SELECT cast(1 as tinyint) <= '1' FROM t; +SELECT cast(1 as smallint) <= '1' FROM t; +SELECT cast(1 as int) <= '1' FROM t; +SELECT cast(1 as bigint) <= '1' FROM t; +SELECT cast(1 as float) <= '1' FROM t; +SELECT cast(1 as double) <= '1' FROM t; +SELECT cast(1 as decimal(10, 0)) <= '1' FROM t; +SELECT '1' <= '1' FROM t; +SELECT cast('1' as binary) <= '1' FROM t; +SELECT cast(1 as boolean) <= '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <= '1' FROM t; + +SELECT cast(1 as tinyint) > '1' FROM t; +SELECT cast(1 as smallint) > '1' FROM t; +SELECT cast(1 as int) > '1' FROM t; +SELECT cast(1 as bigint) > '1' FROM t; +SELECT cast(1 as float) > '1' FROM t; +SELECT cast(1 as double) > '1' FROM t; +SELECT cast(1 as decimal(10, 0)) > '1' FROM t; +SELECT '1' > '1' FROM t; +SELECT cast('1' as binary) > '1' FROM t; +SELECT cast(1 as boolean) > '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) > '1' FROM t; + +SELECT cast(1 as tinyint) >= '1' FROM t; +SELECT cast(1 as smallint) >= '1' FROM t; +SELECT cast(1 as int) >= '1' FROM t; +SELECT cast(1 as bigint) >= '1' FROM t; +SELECT cast(1 as float) >= '1' FROM t; +SELECT cast(1 as double) >= '1' FROM t; +SELECT cast(1 as decimal(10, 0)) >= '1' FROM t; +SELECT '1' >= '1' FROM t; +SELECT cast('1' as binary) >= '1' FROM t; +SELECT cast(1 as boolean) >= '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) >= '1' FROM t; + +SELECT cast(1 as tinyint) <> '1' FROM t; +SELECT cast(1 as smallint) <> '1' FROM t; +SELECT cast(1 as int) <> '1' FROM t; +SELECT cast(1 as bigint) <> '1' FROM t; +SELECT cast(1 as float) <> '1' FROM t; +SELECT cast(1 as double) <> '1' FROM t; +SELECT cast(1 as decimal(10, 0)) <> '1' FROM t; +SELECT '1' <> '1' FROM t; +SELECT cast('1' as binary) <> '1' FROM t; +SELECT cast(1 as boolean) <> '1' FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> '1' FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <> '1' FROM t; + +-- Functions +SELECT abs('1') FROM t; +SELECT sum('1') FROM t; +SELECT avg('1') FROM t; +SELECT stddev_pop('1') FROM t; +SELECT stddev_samp('1') FROM t; +SELECT - '1' FROM t; +SELECT + '1' FROM t; +SELECT var_pop('1') FROM t; +SELECT var_samp('1') FROM t; +SELECT skewness('1') FROM t; +SELECT kurtosis('1') FROM t; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out new file mode 100644 
index 0000000000000..bf8ddee89b798 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out @@ -0,0 +1,2454 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 289 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT cast(1 as tinyint) in (cast(1 as tinyint)) FROM t +-- !query 1 schema +struct<(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT))):boolean> +-- !query 1 output +true + + +-- !query 2 +SELECT cast(1 as tinyint) in (cast(1 as smallint)) FROM t +-- !query 2 schema +struct<(CAST(CAST(1 AS TINYINT) AS SMALLINT) IN (CAST(CAST(1 AS SMALLINT) AS SMALLINT))):boolean> +-- !query 2 output +true + + +-- !query 3 +SELECT cast(1 as tinyint) in (cast(1 as int)) FROM t +-- !query 3 schema +struct<(CAST(CAST(1 AS TINYINT) AS INT) IN (CAST(CAST(1 AS INT) AS INT))):boolean> +-- !query 3 output +true + + +-- !query 4 +SELECT cast(1 as tinyint) in (cast(1 as bigint)) FROM t +-- !query 4 schema +struct<(CAST(CAST(1 AS TINYINT) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT))):boolean> +-- !query 4 output +true + + +-- !query 5 +SELECT cast(1 as tinyint) in (cast(1 as float)) FROM t +-- !query 5 schema +struct<(CAST(CAST(1 AS TINYINT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> +-- !query 5 output +true + + +-- !query 6 +SELECT cast(1 as tinyint) in (cast(1 as double)) FROM t +-- !query 6 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 6 output +true + + +-- !query 7 +SELECT cast(1 as tinyint) in (cast(1 as decimal(10, 0))) FROM t +-- !query 7 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 7 output +true + + +-- !query 8 +SELECT cast(1 as tinyint) in (cast(1 as string)) FROM t +-- !query 8 schema +struct<(CAST(CAST(1 AS TINYINT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 8 output +true + + +-- !query 9 +SELECT cast(1 as tinyint) in (cast('1' as binary)) FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 + + +-- !query 10 +SELECT cast(1 as tinyint) in (cast(1 as boolean)) FROM t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 + + +-- !query 11 +SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 + + +-- !query 12 +SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 + + +-- !query 13 +SELECT cast(1 as smallint) in (cast(1 as tinyint)) FROM t +-- !query 13 schema 
+struct<(CAST(CAST(1 AS SMALLINT) AS SMALLINT) IN (CAST(CAST(1 AS TINYINT) AS SMALLINT))):boolean> +-- !query 13 output +true + + +-- !query 14 +SELECT cast(1 as smallint) in (cast(1 as smallint)) FROM t +-- !query 14 schema +struct<(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT))):boolean> +-- !query 14 output +true + + +-- !query 15 +SELECT cast(1 as smallint) in (cast(1 as int)) FROM t +-- !query 15 schema +struct<(CAST(CAST(1 AS SMALLINT) AS INT) IN (CAST(CAST(1 AS INT) AS INT))):boolean> +-- !query 15 output +true + + +-- !query 16 +SELECT cast(1 as smallint) in (cast(1 as bigint)) FROM t +-- !query 16 schema +struct<(CAST(CAST(1 AS SMALLINT) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT))):boolean> +-- !query 16 output +true + + +-- !query 17 +SELECT cast(1 as smallint) in (cast(1 as float)) FROM t +-- !query 17 schema +struct<(CAST(CAST(1 AS SMALLINT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> +-- !query 17 output +true + + +-- !query 18 +SELECT cast(1 as smallint) in (cast(1 as double)) FROM t +-- !query 18 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 18 output +true + + +-- !query 19 +SELECT cast(1 as smallint) in (cast(1 as decimal(10, 0))) FROM t +-- !query 19 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 19 output +true + + +-- !query 20 +SELECT cast(1 as smallint) in (cast(1 as string)) FROM t +-- !query 20 schema +struct<(CAST(CAST(1 AS SMALLINT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 20 output +true + + +-- !query 21 +SELECT cast(1 as smallint) in (cast('1' as binary)) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 + + +-- !query 22 +SELECT cast(1 as smallint) in (cast(1 as boolean)) FROM t +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 + + +-- !query 23 +SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 + + +-- !query 24 +SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 + + +-- !query 25 +SELECT cast(1 as int) in (cast(1 as tinyint)) FROM t +-- !query 25 schema +struct<(CAST(CAST(1 AS INT) AS INT) IN (CAST(CAST(1 AS TINYINT) AS INT))):boolean> +-- !query 25 output +true + + +-- !query 26 +SELECT cast(1 as int) in (cast(1 as smallint)) FROM t +-- !query 26 schema +struct<(CAST(CAST(1 AS INT) AS INT) IN (CAST(CAST(1 AS SMALLINT) AS INT))):boolean> +-- !query 26 output +true + + +-- !query 27 +SELECT cast(1 as int) in (cast(1 as int)) FROM t +-- 
!query 27 schema +struct<(CAST(1 AS INT) IN (CAST(1 AS INT))):boolean> +-- !query 27 output +true + + +-- !query 28 +SELECT cast(1 as int) in (cast(1 as bigint)) FROM t +-- !query 28 schema +struct<(CAST(CAST(1 AS INT) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT))):boolean> +-- !query 28 output +true + + +-- !query 29 +SELECT cast(1 as int) in (cast(1 as float)) FROM t +-- !query 29 schema +struct<(CAST(CAST(1 AS INT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> +-- !query 29 output +true + + +-- !query 30 +SELECT cast(1 as int) in (cast(1 as double)) FROM t +-- !query 30 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 30 output +true + + +-- !query 31 +SELECT cast(1 as int) in (cast(1 as decimal(10, 0))) FROM t +-- !query 31 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 31 output +true + + +-- !query 32 +SELECT cast(1 as int) in (cast(1 as string)) FROM t +-- !query 32 schema +struct<(CAST(CAST(1 AS INT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 32 output +true + + +-- !query 33 +SELECT cast(1 as int) in (cast('1' as binary)) FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BinaryType; line 1 pos 22 + + +-- !query 34 +SELECT cast(1 as int) in (cast(1 as boolean)) FROM t +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 + + +-- !query 35 +SELECT cast(1 as int) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != TimestampType; line 1 pos 22 + + +-- !query 36 +SELECT cast(1 as int) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 + + +-- !query 37 +SELECT cast(1 as bigint) in (cast(1 as tinyint)) FROM t +-- !query 37 schema +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS TINYINT) AS BIGINT))):boolean> +-- !query 37 output +true + + +-- !query 38 +SELECT cast(1 as bigint) in (cast(1 as smallint)) FROM t +-- !query 38 schema +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS SMALLINT) AS BIGINT))):boolean> +-- !query 38 output +true + + +-- !query 39 +SELECT cast(1 as bigint) in (cast(1 as int)) FROM t +-- !query 39 schema +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS INT) AS BIGINT))):boolean> +-- !query 39 output +true + + +-- !query 40 +SELECT cast(1 as bigint) in (cast(1 as bigint)) FROM t +-- !query 40 schema +struct<(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT))):boolean> +-- !query 40 output +true + + +-- !query 41 +SELECT cast(1 as bigint) in (cast(1 as float)) FROM t +-- !query 41 schema +struct<(CAST(CAST(1 AS BIGINT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS 
FLOAT))):boolean> +-- !query 41 output +true + + +-- !query 42 +SELECT cast(1 as bigint) in (cast(1 as double)) FROM t +-- !query 42 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 42 output +true + + +-- !query 43 +SELECT cast(1 as bigint) in (cast(1 as decimal(10, 0))) FROM t +-- !query 43 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean> +-- !query 43 output +true + + +-- !query 44 +SELECT cast(1 as bigint) in (cast(1 as string)) FROM t +-- !query 44 schema +struct<(CAST(CAST(1 AS BIGINT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 44 output +true + + +-- !query 45 +SELECT cast(1 as bigint) in (cast('1' as binary)) FROM t +-- !query 45 schema +struct<> +-- !query 45 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 + + +-- !query 46 +SELECT cast(1 as bigint) in (cast(1 as boolean)) FROM t +-- !query 46 schema +struct<> +-- !query 46 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 + + +-- !query 47 +SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 47 schema +struct<> +-- !query 47 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 + + +-- !query 48 +SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 + + +-- !query 49 +SELECT cast(1 as float) in (cast(1 as tinyint)) FROM t +-- !query 49 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS TINYINT) AS FLOAT))):boolean> +-- !query 49 output +true + + +-- !query 50 +SELECT cast(1 as float) in (cast(1 as smallint)) FROM t +-- !query 50 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS SMALLINT) AS FLOAT))):boolean> +-- !query 50 output +true + + +-- !query 51 +SELECT cast(1 as float) in (cast(1 as int)) FROM t +-- !query 51 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS INT) AS FLOAT))):boolean> +-- !query 51 output +true + + +-- !query 52 +SELECT cast(1 as float) in (cast(1 as bigint)) FROM t +-- !query 52 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS BIGINT) AS FLOAT))):boolean> +-- !query 52 output +true + + +-- !query 53 +SELECT cast(1 as float) in (cast(1 as float)) FROM t +-- !query 53 schema +struct<(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT))):boolean> +-- !query 53 output +true + + +-- !query 54 +SELECT cast(1 as float) in (cast(1 as double)) FROM t +-- !query 54 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 54 output +true + + +-- !query 55 +SELECT cast(1 as float) in (cast(1 as decimal(10, 0))) FROM t +-- !query 55 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS 
DOUBLE))):boolean> +-- !query 55 output +true + + +-- !query 56 +SELECT cast(1 as float) in (cast(1 as string)) FROM t +-- !query 56 schema +struct<(CAST(CAST(1 AS FLOAT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 56 output +false + + +-- !query 57 +SELECT cast(1 as float) in (cast('1' as binary)) FROM t +-- !query 57 schema +struct<> +-- !query 57 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 + + +-- !query 58 +SELECT cast(1 as float) in (cast(1 as boolean)) FROM t +-- !query 58 schema +struct<> +-- !query 58 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType != BooleanType; line 1 pos 24 + + +-- !query 59 +SELECT cast(1 as float) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 + + +-- !query 60 +SELECT cast(1 as float) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 + + +-- !query 61 +SELECT cast(1 as double) in (cast(1 as tinyint)) FROM t +-- !query 61 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS TINYINT) AS DOUBLE))):boolean> +-- !query 61 output +true + + +-- !query 62 +SELECT cast(1 as double) in (cast(1 as smallint)) FROM t +-- !query 62 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS SMALLINT) AS DOUBLE))):boolean> +-- !query 62 output +true + + +-- !query 63 +SELECT cast(1 as double) in (cast(1 as int)) FROM t +-- !query 63 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS INT) AS DOUBLE))):boolean> +-- !query 63 output +true + + +-- !query 64 +SELECT cast(1 as double) in (cast(1 as bigint)) FROM t +-- !query 64 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS BIGINT) AS DOUBLE))):boolean> +-- !query 64 output +true + + +-- !query 65 +SELECT cast(1 as double) in (cast(1 as float)) FROM t +-- !query 65 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 65 output +true + + +-- !query 66 +SELECT cast(1 as double) in (cast(1 as double)) FROM t +-- !query 66 schema +struct<(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE))):boolean> +-- !query 66 output +true + + +-- !query 67 +SELECT cast(1 as double) in (cast(1 as decimal(10, 0))) FROM t +-- !query 67 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 67 output +true + + +-- !query 68 +SELECT cast(1 as double) in (cast(1 as string)) FROM t +-- !query 68 schema +struct<(CAST(CAST(1 AS DOUBLE) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 68 output +false + + +-- !query 69 +SELECT cast(1 as double) in (cast('1' as binary)) FROM t +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.sql.AnalysisException +cannot resolve 
'(CAST(1 AS DOUBLE) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 + + +-- !query 70 +SELECT cast(1 as double) in (cast(1 as boolean)) FROM t +-- !query 70 schema +struct<> +-- !query 70 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 + + +-- !query 71 +SELECT cast(1 as double) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 + + +-- !query 72 +SELECT cast(1 as double) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 72 schema +struct<> +-- !query 72 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 + + +-- !query 73 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as tinyint)) FROM t +-- !query 73 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS TINYINT) AS DECIMAL(10,0)))):boolean> +-- !query 73 output +true + + +-- !query 74 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as smallint)) FROM t +-- !query 74 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS SMALLINT) AS DECIMAL(10,0)))):boolean> +-- !query 74 output +true + + +-- !query 75 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as int)) FROM t +-- !query 75 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS INT) AS DECIMAL(10,0)))):boolean> +-- !query 75 output +true + + +-- !query 76 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as bigint)) FROM t +-- !query 76 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) IN (CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)))):boolean> +-- !query 76 output +true + + +-- !query 77 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as float)) FROM t +-- !query 77 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) IN (CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 77 output +true + + +-- !query 78 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as double)) FROM t +-- !query 78 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 78 output +true + + +-- !query 79 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0))) FROM t +-- !query 79 schema +struct<(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)))):boolean> +-- !query 79 output +true + + +-- !query 80 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as string)) FROM t +-- !query 80 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 80 output +true + + +-- !query 81 +SELECT cast(1 as decimal(10, 0)) in (cast('1' as binary)) FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 + + +-- !query 82 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as boolean)) FROM t +-- 
!query 82 schema +struct<> +-- !query 82 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 + + +-- !query 83 +SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != TimestampType; line 1 pos 33 + + +-- !query 84 +SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 84 schema +struct<> +-- !query 84 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 + + +-- !query 85 +SELECT cast(1 as string) in (cast(1 as tinyint)) FROM t +-- !query 85 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS TINYINT) AS STRING))):boolean> +-- !query 85 output +true + + +-- !query 86 +SELECT cast(1 as string) in (cast(1 as smallint)) FROM t +-- !query 86 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS SMALLINT) AS STRING))):boolean> +-- !query 86 output +true + + +-- !query 87 +SELECT cast(1 as string) in (cast(1 as int)) FROM t +-- !query 87 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS INT) AS STRING))):boolean> +-- !query 87 output +true + + +-- !query 88 +SELECT cast(1 as string) in (cast(1 as bigint)) FROM t +-- !query 88 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS BIGINT) AS STRING))):boolean> +-- !query 88 output +true + + +-- !query 89 +SELECT cast(1 as string) in (cast(1 as float)) FROM t +-- !query 89 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS FLOAT) AS STRING))):boolean> +-- !query 89 output +false + + +-- !query 90 +SELECT cast(1 as string) in (cast(1 as double)) FROM t +-- !query 90 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS DOUBLE) AS STRING))):boolean> +-- !query 90 output +false + + +-- !query 91 +SELECT cast(1 as string) in (cast(1 as decimal(10, 0))) FROM t +-- !query 91 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS STRING))):boolean> +-- !query 91 output +true + + +-- !query 92 +SELECT cast(1 as string) in (cast(1 as string)) FROM t +-- !query 92 schema +struct<(CAST(1 AS STRING) IN (CAST(1 AS STRING))):boolean> +-- !query 92 output +true + + +-- !query 93 +SELECT cast(1 as string) in (cast('1' as binary)) FROM t +-- !query 93 schema +struct<> +-- !query 93 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != BinaryType; line 1 pos 25 + + +-- !query 94 +SELECT cast(1 as string) in (cast(1 as boolean)) FROM t +-- !query 94 schema +struct<> +-- !query 94 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 + + +-- !query 95 +SELECT cast(1 as string) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 95 schema 
+struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING))):boolean> +-- !query 95 output +false + + +-- !query 96 +SELECT cast(1 as string) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 96 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING))):boolean> +-- !query 96 output +false + + +-- !query 97 +SELECT cast('1' as binary) in (cast(1 as tinyint)) FROM t +-- !query 97 schema +struct<> +-- !query 97 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 + + +-- !query 98 +SELECT cast('1' as binary) in (cast(1 as smallint)) FROM t +-- !query 98 schema +struct<> +-- !query 98 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 + + +-- !query 99 +SELECT cast('1' as binary) in (cast(1 as int)) FROM t +-- !query 99 schema +struct<> +-- !query 99 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 + + +-- !query 100 +SELECT cast('1' as binary) in (cast(1 as bigint)) FROM t +-- !query 100 schema +struct<> +-- !query 100 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != LongType; line 1 pos 27 + + +-- !query 101 +SELECT cast('1' as binary) in (cast(1 as float)) FROM t +-- !query 101 schema +struct<> +-- !query 101 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 + + +-- !query 102 +SELECT cast('1' as binary) in (cast(1 as double)) FROM t +-- !query 102 schema +struct<> +-- !query 102 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 + + +-- !query 103 +SELECT cast('1' as binary) in (cast(1 as decimal(10, 0))) FROM t +-- !query 103 schema +struct<> +-- !query 103 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 + + +-- !query 104 +SELECT cast('1' as binary) in (cast(1 as string)) FROM t +-- !query 104 schema +struct<> +-- !query 104 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 + + +-- !query 105 +SELECT cast('1' as binary) in (cast('1' as binary)) FROM t +-- !query 105 schema +struct<(CAST(1 AS BINARY) IN (CAST(1 AS BINARY))):boolean> +-- !query 105 output +true + + +-- !query 106 +SELECT cast('1' as binary) in (cast(1 as boolean)) FROM t +-- !query 106 schema +struct<> +-- !query 106 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BOOLEAN)))' due to data type 
mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 + + +-- !query 107 +SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 107 schema +struct<> +-- !query 107 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 + + +-- !query 108 +SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 108 schema +struct<> +-- !query 108 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 + + +-- !query 109 +SELECT true in (cast(1 as tinyint)) FROM t +-- !query 109 schema +struct<> +-- !query 109 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 12 + + +-- !query 110 +SELECT true in (cast(1 as smallint)) FROM t +-- !query 110 schema +struct<> +-- !query 110 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 12 + + +-- !query 111 +SELECT true in (cast(1 as int)) FROM t +-- !query 111 schema +struct<> +-- !query 111 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 12 + + +-- !query 112 +SELECT true in (cast(1 as bigint)) FROM t +-- !query 112 schema +struct<> +-- !query 112 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 12 + + +-- !query 113 +SELECT true in (cast(1 as float)) FROM t +-- !query 113 schema +struct<> +-- !query 113 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 12 + + +-- !query 114 +SELECT true in (cast(1 as double)) FROM t +-- !query 114 schema +struct<> +-- !query 114 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 12 + + +-- !query 115 +SELECT true in (cast(1 as decimal(10, 0))) FROM t +-- !query 115 schema +struct<> +-- !query 115 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BooleanType != DecimalType(10,0); line 1 pos 12 + + +-- !query 116 +SELECT true in (cast(1 as string)) FROM t +-- !query 116 schema +struct<> +-- !query 116 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 12 + + +-- !query 117 +SELECT true in (cast('1' as binary)) FROM t +-- !query 117 schema +struct<> +-- !query 117 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST('1' AS 
BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 12 + + +-- !query 118 +SELECT true in (cast(1 as boolean)) FROM t +-- !query 118 schema +struct<(true IN (CAST(1 AS BOOLEAN))):boolean> +-- !query 118 output +true + + +-- !query 119 +SELECT true in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 119 schema +struct<> +-- !query 119 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 12 + + +-- !query 120 +SELECT true in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 120 schema +struct<> +-- !query 120 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 12 + + +-- !query 121 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as tinyint)) FROM t +-- !query 121 schema +struct<> +-- !query 121 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 + + +-- !query 122 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as smallint)) FROM t +-- !query 122 schema +struct<> +-- !query 122 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 + + +-- !query 123 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as int)) FROM t +-- !query 123 schema +struct<> +-- !query 123 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 + + +-- !query 124 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as bigint)) FROM t +-- !query 124 schema +struct<> +-- !query 124 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 + + +-- !query 125 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as float)) FROM t +-- !query 125 schema +struct<> +-- !query 125 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 + + +-- !query 126 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as double)) FROM t +-- !query 126 schema +struct<> +-- !query 126 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 + + +-- !query 127 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as decimal(10, 0))) FROM t +-- !query 127 schema +struct<> +-- !query 127 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' 
AS TIMESTAMP) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 + + +-- !query 128 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as string)) FROM t +-- !query 128 schema +struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING) IN (CAST(CAST(2 AS STRING) AS STRING))):boolean> +-- !query 128 output +false + + +-- !query 129 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2' as binary)) FROM t +-- !query 129 schema +struct<> +-- !query 129 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 + + +-- !query 130 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as boolean)) FROM t +-- !query 130 schema +struct<> +-- !query 130 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BooleanType; line 1 pos 50 + + +-- !query 131 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 131 schema +struct<(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) IN (CAST(2017-12-11 09:30:00.0 AS TIMESTAMP))):boolean> +-- !query 131 output +false + + +-- !query 132 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 132 schema +struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP) IN (CAST(CAST(2017-12-11 09:30:00 AS DATE) AS TIMESTAMP))):boolean> +-- !query 132 output +false + + +-- !query 133 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as tinyint)) FROM t +-- !query 133 schema +struct<> +-- !query 133 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 + + +-- !query 134 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as smallint)) FROM t +-- !query 134 schema +struct<> +-- !query 134 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 + + +-- !query 135 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as int)) FROM t +-- !query 135 schema +struct<> +-- !query 135 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 + + +-- !query 136 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as bigint)) FROM t +-- !query 136 schema +struct<> +-- !query 136 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 + + +-- !query 137 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as float)) FROM t +-- !query 137 schema +struct<> +-- !query 137 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS FLOAT)))' due to data type 
mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 + + +-- !query 138 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as double)) FROM t +-- !query 138 schema +struct<> +-- !query 138 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 + + +-- !query 139 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as decimal(10, 0))) FROM t +-- !query 139 schema +struct<> +-- !query 139 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: DateType != DecimalType(10,0); line 1 pos 43 + + +-- !query 140 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as string)) FROM t +-- !query 140 schema +struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING) IN (CAST(CAST(2 AS STRING) AS STRING))):boolean> +-- !query 140 output +false + + +-- !query 141 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2' as binary)) FROM t +-- !query 141 schema +struct<> +-- !query 141 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 + + +-- !query 142 +SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as boolean)) FROM t +-- !query 142 schema +struct<> +-- !query 142 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 + + +-- !query 143 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 143 schema +struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS TIMESTAMP) IN (CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP))):boolean> +-- !query 143 output +false + + +-- !query 144 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 144 schema +struct<(CAST(2017-12-12 09:30:00 AS DATE) IN (CAST(2017-12-11 09:30:00 AS DATE))):boolean> +-- !query 144 output +false + + +-- !query 145 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as tinyint)) FROM t +-- !query 145 schema +struct<(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS TINYINT))):boolean> +-- !query 145 output +true + + +-- !query 146 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as smallint)) FROM t +-- !query 146 schema +struct<(CAST(CAST(1 AS TINYINT) AS SMALLINT) IN (CAST(CAST(1 AS TINYINT) AS SMALLINT), CAST(CAST(1 AS SMALLINT) AS SMALLINT))):boolean> +-- !query 146 output +true + + +-- !query 147 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as int)) FROM t +-- !query 147 schema +struct<(CAST(CAST(1 AS TINYINT) AS INT) IN (CAST(CAST(1 AS TINYINT) AS INT), CAST(CAST(1 AS INT) AS INT))):boolean> +-- !query 147 output +true + + +-- !query 148 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as bigint)) FROM t +-- !query 148 schema +struct<(CAST(CAST(1 AS TINYINT) AS BIGINT) IN (CAST(CAST(1 AS TINYINT) AS BIGINT), CAST(CAST(1 AS BIGINT) AS BIGINT))):boolean> +-- !query 148 output +true + + +-- !query 149 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as float)) FROM t +-- 
!query 149 schema +struct<(CAST(CAST(1 AS TINYINT) AS FLOAT) IN (CAST(CAST(1 AS TINYINT) AS FLOAT), CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> +-- !query 149 output +true + + +-- !query 150 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as double)) FROM t +-- !query 150 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) IN (CAST(CAST(1 AS TINYINT) AS DOUBLE), CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 150 output +true + + +-- !query 151 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as decimal(10, 0))) FROM t +-- !query 151 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS TINYINT) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 151 output +true + + +-- !query 152 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as string)) FROM t +-- !query 152 schema +struct<(CAST(CAST(1 AS TINYINT) AS STRING) IN (CAST(CAST(1 AS TINYINT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 152 output +true + + +-- !query 153 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('1' as binary)) FROM t +-- !query 153 schema +struct<> +-- !query 153 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 + + +-- !query 154 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as boolean)) FROM t +-- !query 154 schema +struct<> +-- !query 154 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 + + +-- !query 155 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 155 schema +struct<> +-- !query 155 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 + + +-- !query 156 +SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 156 schema +struct<> +-- !query 156 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 + + +-- !query 157 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as tinyint)) FROM t +-- !query 157 schema +struct<(CAST(CAST(1 AS SMALLINT) AS SMALLINT) IN (CAST(CAST(1 AS SMALLINT) AS SMALLINT), CAST(CAST(1 AS TINYINT) AS SMALLINT))):boolean> +-- !query 157 output +true + + +-- !query 158 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as smallint)) FROM t +-- !query 158 schema +struct<(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS SMALLINT))):boolean> +-- !query 158 output +true + + +-- !query 159 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as int)) FROM t +-- !query 159 schema +struct<(CAST(CAST(1 AS SMALLINT) AS INT) IN (CAST(CAST(1 AS SMALLINT) AS INT), CAST(CAST(1 AS INT) AS INT))):boolean> +-- !query 159 output +true + + +-- !query 160 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as bigint)) FROM t +-- !query 160 schema +struct<(CAST(CAST(1 
AS SMALLINT) AS BIGINT) IN (CAST(CAST(1 AS SMALLINT) AS BIGINT), CAST(CAST(1 AS BIGINT) AS BIGINT))):boolean> +-- !query 160 output +true + + +-- !query 161 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as float)) FROM t +-- !query 161 schema +struct<(CAST(CAST(1 AS SMALLINT) AS FLOAT) IN (CAST(CAST(1 AS SMALLINT) AS FLOAT), CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> +-- !query 161 output +true + + +-- !query 162 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as double)) FROM t +-- !query 162 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) IN (CAST(CAST(1 AS SMALLINT) AS DOUBLE), CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 162 output +true + + +-- !query 163 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as decimal(10, 0))) FROM t +-- !query 163 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS SMALLINT) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 163 output +true + + +-- !query 164 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as string)) FROM t +-- !query 164 schema +struct<(CAST(CAST(1 AS SMALLINT) AS STRING) IN (CAST(CAST(1 AS SMALLINT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 164 output +true + + +-- !query 165 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast('1' as binary)) FROM t +-- !query 165 schema +struct<> +-- !query 165 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 + + +-- !query 166 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as boolean)) FROM t +-- !query 166 schema +struct<> +-- !query 166 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 + + +-- !query 167 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 167 schema +struct<> +-- !query 167 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 + + +-- !query 168 +SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 168 schema +struct<> +-- !query 168 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 + + +-- !query 169 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as tinyint)) FROM t +-- !query 169 schema +struct<(CAST(CAST(1 AS INT) AS INT) IN (CAST(CAST(1 AS INT) AS INT), CAST(CAST(1 AS TINYINT) AS INT))):boolean> +-- !query 169 output +true + + +-- !query 170 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as smallint)) FROM t +-- !query 170 schema +struct<(CAST(CAST(1 AS INT) AS INT) IN (CAST(CAST(1 AS INT) AS INT), CAST(CAST(1 AS SMALLINT) AS INT))):boolean> +-- !query 170 output +true + + +-- !query 171 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as int)) FROM t +-- !query 171 schema +struct<(CAST(1 AS INT) IN (CAST(1 AS 
INT), CAST(1 AS INT))):boolean> +-- !query 171 output +true + + +-- !query 172 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as bigint)) FROM t +-- !query 172 schema +struct<(CAST(CAST(1 AS INT) AS BIGINT) IN (CAST(CAST(1 AS INT) AS BIGINT), CAST(CAST(1 AS BIGINT) AS BIGINT))):boolean> +-- !query 172 output +true + + +-- !query 173 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as float)) FROM t +-- !query 173 schema +struct<(CAST(CAST(1 AS INT) AS FLOAT) IN (CAST(CAST(1 AS INT) AS FLOAT), CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> +-- !query 173 output +true + + +-- !query 174 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as double)) FROM t +-- !query 174 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) IN (CAST(CAST(1 AS INT) AS DOUBLE), CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 174 output +true + + +-- !query 175 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as decimal(10, 0))) FROM t +-- !query 175 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS INT) AS DECIMAL(10,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 175 output +true + + +-- !query 176 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as string)) FROM t +-- !query 176 schema +struct<(CAST(CAST(1 AS INT) AS STRING) IN (CAST(CAST(1 AS INT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 176 output +true + + +-- !query 177 +SELECT cast(1 as int) in (cast(1 as int), cast('1' as binary)) FROM t +-- !query 177 schema +struct<> +-- !query 177 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BinaryType; line 1 pos 22 + + +-- !query 178 +SELECT cast(1 as int) in (cast(1 as int), cast(1 as boolean)) FROM t +-- !query 178 schema +struct<> +-- !query 178 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 + + +-- !query 179 +SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 179 schema +struct<> +-- !query 179 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != TimestampType; line 1 pos 22 + + +-- !query 180 +SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 180 schema +struct<> +-- !query 180 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 + + +-- !query 181 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as tinyint)) FROM t +-- !query 181 schema +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT), CAST(CAST(1 AS TINYINT) AS BIGINT))):boolean> +-- !query 181 output +true + + +-- !query 182 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as smallint)) FROM t +-- !query 182 schema +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT), CAST(CAST(1 AS SMALLINT) AS BIGINT))):boolean> +-- !query 182 output +true + + +-- !query 183 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as 
int)) FROM t +-- !query 183 schema +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT), CAST(CAST(1 AS INT) AS BIGINT))):boolean> +-- !query 183 output +true + + +-- !query 184 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as bigint)) FROM t +-- !query 184 schema +struct<(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BIGINT))):boolean> +-- !query 184 output +true + + +-- !query 185 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as float)) FROM t +-- !query 185 schema +struct<(CAST(CAST(1 AS BIGINT) AS FLOAT) IN (CAST(CAST(1 AS BIGINT) AS FLOAT), CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> +-- !query 185 output +true + + +-- !query 186 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as double)) FROM t +-- !query 186 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) IN (CAST(CAST(1 AS BIGINT) AS DOUBLE), CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 186 output +true + + +-- !query 187 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as decimal(10, 0))) FROM t +-- !query 187 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) IN (CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)), CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean> +-- !query 187 output +true + + +-- !query 188 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as string)) FROM t +-- !query 188 schema +struct<(CAST(CAST(1 AS BIGINT) AS STRING) IN (CAST(CAST(1 AS BIGINT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 188 output +true + + +-- !query 189 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast('1' as binary)) FROM t +-- !query 189 schema +struct<> +-- !query 189 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 + + +-- !query 190 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as boolean)) FROM t +-- !query 190 schema +struct<> +-- !query 190 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 + + +-- !query 191 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 191 schema +struct<> +-- !query 191 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 + + +-- !query 192 +SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 192 schema +struct<> +-- !query 192 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 + + +-- !query 193 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as tinyint)) FROM t +-- !query 193 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT), CAST(CAST(1 AS TINYINT) AS FLOAT))):boolean> +-- !query 193 output +true + + +-- !query 194 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as smallint)) FROM t +-- !query 194 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS 
FLOAT), CAST(CAST(1 AS SMALLINT) AS FLOAT))):boolean> +-- !query 194 output +true + + +-- !query 195 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as int)) FROM t +-- !query 195 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT), CAST(CAST(1 AS INT) AS FLOAT))):boolean> +-- !query 195 output +true + + +-- !query 196 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as bigint)) FROM t +-- !query 196 schema +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT), CAST(CAST(1 AS BIGINT) AS FLOAT))):boolean> +-- !query 196 output +true + + +-- !query 197 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as float)) FROM t +-- !query 197 schema +struct<(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS FLOAT))):boolean> +-- !query 197 output +true + + +-- !query 198 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as double)) FROM t +-- !query 198 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) IN (CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 198 output +true + + +-- !query 199 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as decimal(10, 0))) FROM t +-- !query 199 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) IN (CAST(CAST(1 AS FLOAT) AS DOUBLE), CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 199 output +true + + +-- !query 200 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as string)) FROM t +-- !query 200 schema +struct<(CAST(CAST(1 AS FLOAT) AS STRING) IN (CAST(CAST(1 AS FLOAT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 200 output +true + + +-- !query 201 +SELECT cast(1 as float) in (cast(1 as float), cast('1' as binary)) FROM t +-- !query 201 schema +struct<> +-- !query 201 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 + + +-- !query 202 +SELECT cast(1 as float) in (cast(1 as float), cast(1 as boolean)) FROM t +-- !query 202 schema +struct<> +-- !query 202 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType != BooleanType; line 1 pos 24 + + +-- !query 203 +SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 203 schema +struct<> +-- !query 203 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 + + +-- !query 204 +SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 204 schema +struct<> +-- !query 204 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 + + +-- !query 205 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as tinyint)) FROM t +-- !query 205 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE), CAST(CAST(1 AS TINYINT) AS DOUBLE))):boolean> +-- !query 205 output +true + + +-- !query 206 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as 
smallint)) FROM t +-- !query 206 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE), CAST(CAST(1 AS SMALLINT) AS DOUBLE))):boolean> +-- !query 206 output +true + + +-- !query 207 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as int)) FROM t +-- !query 207 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE), CAST(CAST(1 AS INT) AS DOUBLE))):boolean> +-- !query 207 output +true + + +-- !query 208 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as bigint)) FROM t +-- !query 208 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE), CAST(CAST(1 AS BIGINT) AS DOUBLE))):boolean> +-- !query 208 output +true + + +-- !query 209 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as float)) FROM t +-- !query 209 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE), CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 209 output +true + + +-- !query 210 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as double)) FROM t +-- !query 210 schema +struct<(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS DOUBLE))):boolean> +-- !query 210 output +true + + +-- !query 211 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as decimal(10, 0))) FROM t +-- !query 211 schema +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE), CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 211 output +true + + +-- !query 212 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as string)) FROM t +-- !query 212 schema +struct<(CAST(CAST(1 AS DOUBLE) AS STRING) IN (CAST(CAST(1 AS DOUBLE) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 212 output +true + + +-- !query 213 +SELECT cast(1 as double) in (cast(1 as double), cast('1' as binary)) FROM t +-- !query 213 schema +struct<> +-- !query 213 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 + + +-- !query 214 +SELECT cast(1 as double) in (cast(1 as double), cast(1 as boolean)) FROM t +-- !query 214 schema +struct<> +-- !query 214 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 + + +-- !query 215 +SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 215 schema +struct<> +-- !query 215 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 + + +-- !query 216 +SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 216 schema +struct<> +-- !query 216 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 + + +-- !query 217 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as tinyint)) FROM t +-- !query 217 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS 
DECIMAL(10,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS TINYINT) AS DECIMAL(10,0)))):boolean> +-- !query 217 output +true + + +-- !query 218 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as smallint)) FROM t +-- !query 218 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS SMALLINT) AS DECIMAL(10,0)))):boolean> +-- !query 218 output +true + + +-- !query 219 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as int)) FROM t +-- !query 219 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)), CAST(CAST(1 AS INT) AS DECIMAL(10,0)))):boolean> +-- !query 219 output +true + + +-- !query 220 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as bigint)) FROM t +-- !query 220 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)), CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)))):boolean> +-- !query 220 output +true + + +-- !query 221 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as float)) FROM t +-- !query 221 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 221 output +true + + +-- !query 222 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as double)) FROM t +-- !query 222 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> +-- !query 222 output +true + + +-- !query 223 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t +-- !query 223 schema +struct<(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS DECIMAL(10,0)))):boolean> +-- !query 223 output +true + + +-- !query 224 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as string)) FROM t +-- !query 224 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS STRING) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 224 output +true + + +-- !query 225 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('1' as binary)) FROM t +-- !query 225 schema +struct<> +-- !query 225 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 + + +-- !query 226 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as boolean)) FROM t +-- !query 226 schema +struct<> +-- !query 226 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 + + +-- !query 227 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 227 schema +struct<> +-- !query 227 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != 
TimestampType; line 1 pos 33 + + +-- !query 228 +SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 228 schema +struct<> +-- !query 228 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 + + +-- !query 229 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as tinyint)) FROM t +-- !query 229 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(1 AS TINYINT) AS STRING))):boolean> +-- !query 229 output +true + + +-- !query 230 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as smallint)) FROM t +-- !query 230 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(1 AS SMALLINT) AS STRING))):boolean> +-- !query 230 output +true + + +-- !query 231 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as int)) FROM t +-- !query 231 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(1 AS INT) AS STRING))):boolean> +-- !query 231 output +true + + +-- !query 232 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as bigint)) FROM t +-- !query 232 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(1 AS BIGINT) AS STRING))):boolean> +-- !query 232 output +true + + +-- !query 233 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as float)) FROM t +-- !query 233 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(1 AS FLOAT) AS STRING))):boolean> +-- !query 233 output +true + + +-- !query 234 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as double)) FROM t +-- !query 234 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(1 AS DOUBLE) AS STRING))):boolean> +-- !query 234 output +true + + +-- !query 235 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as decimal(10, 0))) FROM t +-- !query 235 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(1 AS DECIMAL(10,0)) AS STRING))):boolean> +-- !query 235 output +true + + +-- !query 236 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as string)) FROM t +-- !query 236 schema +struct<(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS STRING))):boolean> +-- !query 236 output +true + + +-- !query 237 +SELECT cast(1 as string) in (cast(1 as string), cast('1' as binary)) FROM t +-- !query 237 schema +struct<> +-- !query 237 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != BinaryType; line 1 pos 25 + + +-- !query 238 +SELECT cast(1 as string) in (cast(1 as string), cast(1 as boolean)) FROM t +-- !query 238 schema +struct<> +-- !query 238 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 + + +-- !query 239 +SELECT cast(1 as string) in (cast(1 as string), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 239 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN 
(CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING))):boolean> +-- !query 239 output +true + + +-- !query 240 +SELECT cast(1 as string) in (cast(1 as string), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 240 schema +struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING), CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING))):boolean> +-- !query 240 output +true + + +-- !query 241 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as tinyint)) FROM t +-- !query 241 schema +struct<> +-- !query 241 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 + + +-- !query 242 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as smallint)) FROM t +-- !query 242 schema +struct<> +-- !query 242 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 + + +-- !query 243 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as int)) FROM t +-- !query 243 schema +struct<> +-- !query 243 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 + + +-- !query 244 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as bigint)) FROM t +-- !query 244 schema +struct<> +-- !query 244 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != LongType; line 1 pos 27 + + +-- !query 245 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as float)) FROM t +-- !query 245 schema +struct<> +-- !query 245 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 + + +-- !query 246 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as double)) FROM t +-- !query 246 schema +struct<> +-- !query 246 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 + + +-- !query 247 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as decimal(10, 0))) FROM t +-- !query 247 schema +struct<> +-- !query 247 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 + + +-- !query 248 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as string)) FROM t +-- !query 248 schema +struct<> +-- !query 248 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 + + +-- !query 249 +SELECT cast('1' as binary) in 
(cast('1' as binary), cast('1' as binary)) FROM t +-- !query 249 schema +struct<(CAST(1 AS BINARY) IN (CAST(1 AS BINARY), CAST(1 AS BINARY))):boolean> +-- !query 249 output +true + + +-- !query 250 +SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as boolean)) FROM t +-- !query 250 schema +struct<> +-- !query 250 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 + + +-- !query 251 +SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 251 schema +struct<> +-- !query 251 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 + + +-- !query 252 +SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 252 schema +struct<> +-- !query 252 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 + + +-- !query 253 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as tinyint)) FROM t +-- !query 253 schema +struct<> +-- !query 253 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 28 + + +-- !query 254 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as smallint)) FROM t +-- !query 254 schema +struct<> +-- !query 254 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 28 + + +-- !query 255 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as int)) FROM t +-- !query 255 schema +struct<> +-- !query 255 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 28 + + +-- !query 256 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as bigint)) FROM t +-- !query 256 schema +struct<> +-- !query 256 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 28 + + +-- !query 257 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as float)) FROM t +-- !query 257 schema +struct<> +-- !query 257 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 28 + + +-- !query 258 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as double)) FROM t +-- !query 258 schema +struct<> +-- !query 258 output +org.apache.spark.sql.AnalysisException +cannot resolve 
'(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 28 + + +-- !query 259 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as decimal(10, 0))) FROM t +-- !query 259 schema +struct<> +-- !query 259 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BooleanType != DecimalType(10,0); line 1 pos 28 + + +-- !query 260 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as string)) FROM t +-- !query 260 schema +struct<> +-- !query 260 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 28 + + +-- !query 261 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast('1' as binary)) FROM t +-- !query 261 schema +struct<> +-- !query 261 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 28 + + +-- !query 262 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as boolean)) FROM t +-- !query 262 schema +struct<(CAST(1 AS BOOLEAN) IN (CAST(1 AS BOOLEAN), CAST(1 AS BOOLEAN))):boolean> +-- !query 262 output +true + + +-- !query 263 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 263 schema +struct<> +-- !query 263 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 28 + + +-- !query 264 +SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 264 schema +struct<> +-- !query 264 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 28 + + +-- !query 265 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as tinyint)) FROM t +-- !query 265 schema +struct<> +-- !query 265 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 + + +-- !query 266 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as smallint)) FROM t +-- !query 266 schema +struct<> +-- !query 266 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 + + +-- !query 267 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as int)) FROM t +-- 
!query 267 schema +struct<> +-- !query 267 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 + + +-- !query 268 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as bigint)) FROM t +-- !query 268 schema +struct<> +-- !query 268 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 + + +-- !query 269 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as float)) FROM t +-- !query 269 schema +struct<> +-- !query 269 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 + + +-- !query 270 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as double)) FROM t +-- !query 270 schema +struct<> +-- !query 270 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 + + +-- !query 271 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as decimal(10, 0))) FROM t +-- !query 271 schema +struct<> +-- !query 271 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 + + +-- !query 272 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as string)) FROM t +-- !query 272 schema +struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING) IN (CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 272 output +true + + +-- !query 273 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('1' as binary)) FROM t +-- !query 273 schema +struct<> +-- !query 273 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 + + +-- !query 274 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as boolean)) FROM t +-- !query 274 schema +struct<> +-- !query 274 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != 
BooleanType; line 1 pos 50 + + +-- !query 275 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 275 schema +struct<(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) IN (CAST(2017-12-12 09:30:00.0 AS TIMESTAMP), CAST(2017-12-11 09:30:00.0 AS TIMESTAMP))):boolean> +-- !query 275 output +true + + +-- !query 276 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 276 schema +struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP) IN (CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP), CAST(CAST(2017-12-11 09:30:00 AS DATE) AS TIMESTAMP))):boolean> +-- !query 276 output +true + + +-- !query 277 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as tinyint)) FROM t +-- !query 277 schema +struct<> +-- !query 277 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 + + +-- !query 278 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as smallint)) FROM t +-- !query 278 schema +struct<> +-- !query 278 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 + + +-- !query 279 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as int)) FROM t +-- !query 279 schema +struct<> +-- !query 279 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 + + +-- !query 280 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as bigint)) FROM t +-- !query 280 schema +struct<> +-- !query 280 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 + + +-- !query 281 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as float)) FROM t +-- !query 281 schema +struct<> +-- !query 281 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 + + +-- !query 282 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as double)) FROM t +-- !query 282 schema +struct<> +-- !query 282 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 + + +-- !query 283 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as 
date), cast(1 as decimal(10, 0))) FROM t +-- !query 283 schema +struct<> +-- !query 283 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: DateType != DecimalType(10,0); line 1 pos 43 + + +-- !query 284 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as string)) FROM t +-- !query 284 schema +struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING) IN (CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +-- !query 284 output +true + + +-- !query 285 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('1' as binary)) FROM t +-- !query 285 schema +struct<> +-- !query 285 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 + + +-- !query 286 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as boolean)) FROM t +-- !query 286 schema +struct<> +-- !query 286 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 + + +-- !query 287 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 287 schema +struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS TIMESTAMP) IN (CAST(CAST(2017-12-12 09:30:00 AS DATE) AS TIMESTAMP), CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP))):boolean> +-- !query 287 output +true + + +-- !query 288 +SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 288 schema +struct<(CAST(2017-12-12 09:30:00 AS DATE) IN (CAST(2017-12-12 09:30:00 AS DATE), CAST(2017-12-11 09:30:00 AS DATE))):boolean> +-- !query 288 output +true diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/promoteStrings.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/promoteStrings.sql.out new file mode 100644 index 0000000000000..0beb1f6263d2c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/promoteStrings.sql.out @@ -0,0 +1,2578 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 316 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT '1' + cast(1 as tinyint) FROM t +-- !query 1 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 1 output +2.0 + + +-- !query 2 +SELECT '1' + cast(1 as smallint) FROM t +-- !query 2 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 2 output +2.0 + + +-- !query 3 +SELECT '1' + cast(1 as int) FROM t +-- !query 3 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 3 output +2.0 + + +-- !query 4 +SELECT '1' + cast(1 as bigint) FROM t +-- !query 4 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- 
!query 4 output +2.0 + + +-- !query 5 +SELECT '1' + cast(1 as float) FROM t +-- !query 5 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 5 output +2.0 + + +-- !query 6 +SELECT '1' + cast(1 as double) FROM t +-- !query 6 schema +struct<(CAST(1 AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 6 output +2.0 + + +-- !query 7 +SELECT '1' + cast(1 as decimal(10, 0)) FROM t +-- !query 7 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 7 output +2.0 + + +-- !query 8 +SELECT '1' + '1' FROM t +-- !query 8 schema +struct<(CAST(1 AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 8 output +2.0 + + +-- !query 9 +SELECT '1' + cast('1' as binary) FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) + CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 10 +SELECT '1' + cast(1 as boolean) FROM t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) + CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 11 +SELECT '1' + cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 12 +SELECT '1' + cast('2017-12-11 09:30:00' as date) FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) + CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 13 +SELECT '1' - cast(1 as tinyint) FROM t +-- !query 13 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 13 output +0.0 + + +-- !query 14 +SELECT '1' - cast(1 as smallint) FROM t +-- !query 14 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 14 output +0.0 + + +-- !query 15 +SELECT '1' - cast(1 as int) FROM t +-- !query 15 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 15 output +0.0 + + +-- !query 16 +SELECT '1' - cast(1 as bigint) FROM t +-- !query 16 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 16 output +0.0 + + +-- !query 17 +SELECT '1' - cast(1 as float) FROM t +-- !query 17 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 17 output +0.0 + + +-- !query 18 +SELECT '1' - cast(1 as double) FROM t +-- !query 18 schema +struct<(CAST(1 AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 18 output +0.0 + + +-- !query 19 +SELECT '1' - cast(1 as decimal(10, 0)) FROM t +-- !query 19 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 19 output +0.0 + + +-- !query 20 +SELECT '1' - '1' FROM t +-- !query 20 schema +struct<(CAST(1 AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 20 output +0.0 + + +-- !query 21 +SELECT 
'1' - cast('1' as binary) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) - CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 22 +SELECT '1' - cast(1 as boolean) FROM t +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) - CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 23 +SELECT '1' - cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 24 +SELECT '1' - cast('2017-12-11 09:30:00' as date) FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) - CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 25 +SELECT '1' * cast(1 as tinyint) FROM t +-- !query 25 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 25 output +1.0 + + +-- !query 26 +SELECT '1' * cast(1 as smallint) FROM t +-- !query 26 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 26 output +1.0 + + +-- !query 27 +SELECT '1' * cast(1 as int) FROM t +-- !query 27 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 27 output +1.0 + + +-- !query 28 +SELECT '1' * cast(1 as bigint) FROM t +-- !query 28 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 28 output +1.0 + + +-- !query 29 +SELECT '1' * cast(1 as float) FROM t +-- !query 29 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 29 output +1.0 + + +-- !query 30 +SELECT '1' * cast(1 as double) FROM t +-- !query 30 schema +struct<(CAST(1 AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 30 output +1.0 + + +-- !query 31 +SELECT '1' * cast(1 as decimal(10, 0)) FROM t +-- !query 31 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 31 output +1.0 + + +-- !query 32 +SELECT '1' * '1' FROM t +-- !query 32 schema +struct<(CAST(1 AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 32 output +1.0 + + +-- !query 33 +SELECT '1' * cast('1' as binary) FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) * CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 34 +SELECT '1' * cast(1 as boolean) FROM t +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) * CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 35 +SELECT '1' * cast('2017-12-11 
09:30:00.0' as timestamp) FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) * CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) * CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 36 +SELECT '1' * cast('2017-12-11 09:30:00' as date) FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) * CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) * CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 37 +SELECT '1' / cast(1 as tinyint) FROM t +-- !query 37 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 37 output +1.0 + + +-- !query 38 +SELECT '1' / cast(1 as smallint) FROM t +-- !query 38 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 38 output +1.0 + + +-- !query 39 +SELECT '1' / cast(1 as int) FROM t +-- !query 39 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 39 output +1.0 + + +-- !query 40 +SELECT '1' / cast(1 as bigint) FROM t +-- !query 40 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 40 output +1.0 + + +-- !query 41 +SELECT '1' / cast(1 as float) FROM t +-- !query 41 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 41 output +1.0 + + +-- !query 42 +SELECT '1' / cast(1 as double) FROM t +-- !query 42 schema +struct<(CAST(1 AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 42 output +1.0 + + +-- !query 43 +SELECT '1' / cast(1 as decimal(10, 0)) FROM t +-- !query 43 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 43 output +1.0 + + +-- !query 44 +SELECT '1' / '1' FROM t +-- !query 44 schema +struct<(CAST(1 AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 44 output +1.0 + + +-- !query 45 +SELECT '1' / cast('1' as binary) FROM t +-- !query 45 schema +struct<> +-- !query 45 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) / CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 46 +SELECT '1' / cast(1 as boolean) FROM t +-- !query 46 schema +struct<> +-- !query 46 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) / CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 47 +SELECT '1' / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 47 schema +struct<> +-- !query 47 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 48 +SELECT '1' / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) / CAST('2017-12-11 
09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 49 +SELECT '1' % cast(1 as tinyint) FROM t +-- !query 49 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 49 output +0.0 + + +-- !query 50 +SELECT '1' % cast(1 as smallint) FROM t +-- !query 50 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 50 output +0.0 + + +-- !query 51 +SELECT '1' % cast(1 as int) FROM t +-- !query 51 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 51 output +0.0 + + +-- !query 52 +SELECT '1' % cast(1 as bigint) FROM t +-- !query 52 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 52 output +0.0 + + +-- !query 53 +SELECT '1' % cast(1 as float) FROM t +-- !query 53 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 53 output +0.0 + + +-- !query 54 +SELECT '1' % cast(1 as double) FROM t +-- !query 54 schema +struct<(CAST(1 AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 54 output +0.0 + + +-- !query 55 +SELECT '1' % cast(1 as decimal(10, 0)) FROM t +-- !query 55 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 55 output +0.0 + + +-- !query 56 +SELECT '1' % '1' FROM t +-- !query 56 schema +struct<(CAST(1 AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 56 output +0.0 + + +-- !query 57 +SELECT '1' % cast('1' as binary) FROM t +-- !query 57 schema +struct<> +-- !query 57 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) % CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 58 +SELECT '1' % cast(1 as boolean) FROM t +-- !query 58 schema +struct<> +-- !query 58 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) % CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 59 +SELECT '1' % cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 60 +SELECT '1' % cast('2017-12-11 09:30:00' as date) FROM t +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS DOUBLE) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('1' AS DOUBLE) % CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 61 +SELECT pmod('1', cast(1 as tinyint)) FROM t +-- !query 61 schema +struct +-- !query 61 output +0.0 + + +-- !query 62 +SELECT pmod('1', cast(1 as smallint)) FROM t +-- !query 62 schema +struct +-- !query 62 output +0.0 + + +-- !query 63 +SELECT pmod('1', cast(1 as int)) FROM t +-- !query 63 schema +struct +-- !query 63 output +0.0 + + +-- !query 64 +SELECT pmod('1', cast(1 as bigint)) FROM t +-- !query 64 schema +struct +-- !query 64 output +0.0 + + +-- !query 65 +SELECT pmod('1', cast(1 as float)) FROM t +-- !query 65 schema +struct +-- !query 65 output +0.0 + + +-- !query 66 +SELECT pmod('1', cast(1 as double)) FROM t +-- 
!query 66 schema +struct +-- !query 66 output +0.0 + + +-- !query 67 +SELECT pmod('1', cast(1 as decimal(10, 0))) FROM t +-- !query 67 schema +struct +-- !query 67 output +0.0 + + +-- !query 68 +SELECT pmod('1', '1') FROM t +-- !query 68 schema +struct +-- !query 68 output +0.0 + + +-- !query 69 +SELECT pmod('1', cast('1' as binary)) FROM t +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS DOUBLE), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST('1' AS DOUBLE), CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 70 +SELECT pmod('1', cast(1 as boolean)) FROM t +-- !query 70 schema +struct<> +-- !query 70 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS DOUBLE), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST('1' AS DOUBLE), CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 71 +SELECT pmod('1', cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST('1' AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 72 +SELECT pmod('1', cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 72 schema +struct<> +-- !query 72 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST('1' AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 73 +SELECT cast(1 as tinyint) + '1' FROM t +-- !query 73 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 73 output +2.0 + + +-- !query 74 +SELECT cast(1 as smallint) + '1' FROM t +-- !query 74 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 74 output +2.0 + + +-- !query 75 +SELECT cast(1 as int) + '1' FROM t +-- !query 75 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 75 output +2.0 + + +-- !query 76 +SELECT cast(1 as bigint) + '1' FROM t +-- !query 76 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 76 output +2.0 + + +-- !query 77 +SELECT cast(1 as float) + '1' FROM t +-- !query 77 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 77 output +2.0 + + +-- !query 78 +SELECT cast(1 as double) + '1' FROM t +-- !query 78 schema +struct<(CAST(1 AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 78 output +2.0 + + +-- !query 79 +SELECT cast(1 as decimal(10, 0)) + '1' FROM t +-- !query 79 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 79 output +2.0 + + +-- !query 80 +SELECT cast('1' as binary) + '1' FROM t +-- !query 80 schema +struct<> +-- !query 80 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) + CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST('1' AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 81 +SELECT cast(1 as boolean) + '1' FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) 
+ CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) + CAST('1' AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 82 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + '1' FROM t +-- !query 82 schema +struct<> +-- !query 82 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST('1' AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 83 +SELECT cast('2017-12-11 09:30:00' as date) + '1' FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST('1' AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 84 +SELECT cast(1 as tinyint) - '1' FROM t +-- !query 84 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 84 output +0.0 + + +-- !query 85 +SELECT cast(1 as smallint) - '1' FROM t +-- !query 85 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 85 output +0.0 + + +-- !query 86 +SELECT cast(1 as int) - '1' FROM t +-- !query 86 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 86 output +0.0 + + +-- !query 87 +SELECT cast(1 as bigint) - '1' FROM t +-- !query 87 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 87 output +0.0 + + +-- !query 88 +SELECT cast(1 as float) - '1' FROM t +-- !query 88 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 88 output +0.0 + + +-- !query 89 +SELECT cast(1 as double) - '1' FROM t +-- !query 89 schema +struct<(CAST(1 AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 89 output +0.0 + + +-- !query 90 +SELECT cast(1 as decimal(10, 0)) - '1' FROM t +-- !query 90 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 90 output +0.0 + + +-- !query 91 +SELECT cast('1' as binary) - '1' FROM t +-- !query 91 schema +struct<> +-- !query 91 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST('1' AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 92 +SELECT cast(1 as boolean) - '1' FROM t +-- !query 92 schema +struct<> +-- !query 92 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) - CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) - CAST('1' AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 93 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - '1' FROM t +-- !query 93 schema +struct<> +-- !query 93 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST('1' AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 94 +SELECT cast('2017-12-11 09:30:00' as date) - '1' FROM t +-- !query 94 schema +struct<> +-- !query 94 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST('1' AS DOUBLE))' due to data type mismatch: differing 
types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST('1' AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 95 +SELECT cast(1 as tinyint) * '1' FROM t +-- !query 95 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 95 output +1.0 + + +-- !query 96 +SELECT cast(1 as smallint) * '1' FROM t +-- !query 96 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 96 output +1.0 + + +-- !query 97 +SELECT cast(1 as int) * '1' FROM t +-- !query 97 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 97 output +1.0 + + +-- !query 98 +SELECT cast(1 as bigint) * '1' FROM t +-- !query 98 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 98 output +1.0 + + +-- !query 99 +SELECT cast(1 as float) * '1' FROM t +-- !query 99 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 99 output +1.0 + + +-- !query 100 +SELECT cast(1 as double) * '1' FROM t +-- !query 100 schema +struct<(CAST(1 AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 100 output +1.0 + + +-- !query 101 +SELECT cast(1 as decimal(10, 0)) * '1' FROM t +-- !query 101 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 101 output +1.0 + + +-- !query 102 +SELECT cast('1' as binary) * '1' FROM t +-- !query 102 schema +struct<> +-- !query 102 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) * CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST('1' AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 103 +SELECT cast(1 as boolean) * '1' FROM t +-- !query 103 schema +struct<> +-- !query 103 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) * CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) * CAST('1' AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 104 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) * '1' FROM t +-- !query 104 schema +struct<> +-- !query 104 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) * CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) * CAST('1' AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 105 +SELECT cast('2017-12-11 09:30:00' as date) * '1' FROM t +-- !query 105 schema +struct<> +-- !query 105 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) * CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) * CAST('1' AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 106 +SELECT cast(1 as tinyint) / '1' FROM t +-- !query 106 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 106 output +1.0 + + +-- !query 107 +SELECT cast(1 as smallint) / '1' FROM t +-- !query 107 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 107 output +1.0 + + +-- !query 108 +SELECT cast(1 as int) / '1' FROM t +-- !query 108 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 108 output +1.0 + + +-- !query 109 +SELECT cast(1 as bigint) / '1' FROM t +-- !query 109 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS 
DOUBLE) AS DOUBLE)):double> +-- !query 109 output +1.0 + + +-- !query 110 +SELECT cast(1 as float) / '1' FROM t +-- !query 110 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 110 output +1.0 + + +-- !query 111 +SELECT cast(1 as double) / '1' FROM t +-- !query 111 schema +struct<(CAST(1 AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 111 output +1.0 + + +-- !query 112 +SELECT cast(1 as decimal(10, 0)) / '1' FROM t +-- !query 112 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 112 output +1.0 + + +-- !query 113 +SELECT cast('1' as binary) / '1' FROM t +-- !query 113 schema +struct<> +-- !query 113 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST('1' AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 114 +SELECT cast(1 as boolean) / '1' FROM t +-- !query 114 schema +struct<> +-- !query 114 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST('1' AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 115 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / '1' FROM t +-- !query 115 schema +struct<> +-- !query 115 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('1' AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 116 +SELECT cast('2017-12-11 09:30:00' as date) / '1' FROM t +-- !query 116 schema +struct<> +-- !query 116 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('1' AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 117 +SELECT cast(1 as tinyint) % '1' FROM t +-- !query 117 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 117 output +0.0 + + +-- !query 118 +SELECT cast(1 as smallint) % '1' FROM t +-- !query 118 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 118 output +0.0 + + +-- !query 119 +SELECT cast(1 as int) % '1' FROM t +-- !query 119 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 119 output +0.0 + + +-- !query 120 +SELECT cast(1 as bigint) % '1' FROM t +-- !query 120 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 120 output +0.0 + + +-- !query 121 +SELECT cast(1 as float) % '1' FROM t +-- !query 121 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 121 output +0.0 + + +-- !query 122 +SELECT cast(1 as double) % '1' FROM t +-- !query 122 schema +struct<(CAST(1 AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 122 output +0.0 + + +-- !query 123 +SELECT cast(1 as decimal(10, 0)) % '1' FROM t +-- !query 123 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 123 output +0.0 + + +-- !query 124 +SELECT cast('1' as binary) % '1' FROM t +-- !query 124 schema +struct<> +-- !query 124 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) % CAST('1' AS 
DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST('1' AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 125 +SELECT cast(1 as boolean) % '1' FROM t +-- !query 125 schema +struct<> +-- !query 125 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) % CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) % CAST('1' AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 126 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % '1' FROM t +-- !query 126 schema +struct<> +-- !query 126 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST('1' AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 127 +SELECT cast('2017-12-11 09:30:00' as date) % '1' FROM t +-- !query 127 schema +struct<> +-- !query 127 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST('1' AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST('1' AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 128 +SELECT pmod(cast(1 as tinyint), '1') FROM t +-- !query 128 schema +struct +-- !query 128 output +0.0 + + +-- !query 129 +SELECT pmod(cast(1 as smallint), '1') FROM t +-- !query 129 schema +struct +-- !query 129 output +0.0 + + +-- !query 130 +SELECT pmod(cast(1 as int), '1') FROM t +-- !query 130 schema +struct +-- !query 130 output +0.0 + + +-- !query 131 +SELECT pmod(cast(1 as bigint), '1') FROM t +-- !query 131 schema +struct +-- !query 131 output +0.0 + + +-- !query 132 +SELECT pmod(cast(1 as float), '1') FROM t +-- !query 132 schema +struct +-- !query 132 output +0.0 + + +-- !query 133 +SELECT pmod(cast(1 as double), '1') FROM t +-- !query 133 schema +struct +-- !query 133 output +0.0 + + +-- !query 134 +SELECT pmod(cast(1 as decimal(10, 0)), '1') FROM t +-- !query 134 schema +struct +-- !query 134 output +0.0 + + +-- !query 135 +SELECT pmod(cast('1' as binary), '1') FROM t +-- !query 135 schema +struct<> +-- !query 135 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS BINARY), CAST('1' AS DOUBLE))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST('1' AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 136 +SELECT pmod(cast(1 as boolean), '1') FROM t +-- !query 136 schema +struct<> +-- !query 136 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS BOOLEAN), CAST('1' AS DOUBLE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS BOOLEAN), CAST('1' AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 137 +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), '1') FROM t +-- !query 137 schema +struct<> +-- !query 137 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST('1' AS DOUBLE))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST('1' AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 138 +SELECT pmod(cast('2017-12-11 09:30:00' as date), '1') FROM t +-- !query 138 schema +struct<> +-- !query 138 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST('1' AS DOUBLE))' due to data type mismatch: differing 
types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST('1' AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 139 +SELECT '1' = cast(1 as tinyint) FROM t +-- !query 139 schema +struct<(CAST(1 AS TINYINT) = CAST(1 AS TINYINT)):boolean> +-- !query 139 output +true + + +-- !query 140 +SELECT '1' = cast(1 as smallint) FROM t +-- !query 140 schema +struct<(CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT)):boolean> +-- !query 140 output +true + + +-- !query 141 +SELECT '1' = cast(1 as int) FROM t +-- !query 141 schema +struct<(CAST(1 AS INT) = CAST(1 AS INT)):boolean> +-- !query 141 output +true + + +-- !query 142 +SELECT '1' = cast(1 as bigint) FROM t +-- !query 142 schema +struct<(CAST(1 AS BIGINT) = CAST(1 AS BIGINT)):boolean> +-- !query 142 output +true + + +-- !query 143 +SELECT '1' = cast(1 as float) FROM t +-- !query 143 schema +struct<(CAST(1 AS FLOAT) = CAST(1 AS FLOAT)):boolean> +-- !query 143 output +true + + +-- !query 144 +SELECT '1' = cast(1 as double) FROM t +-- !query 144 schema +struct<(CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 144 output +true + + +-- !query 145 +SELECT '1' = cast(1 as decimal(10, 0)) FROM t +-- !query 145 schema +struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 145 output +true + + +-- !query 146 +SELECT '1' = '1' FROM t +-- !query 146 schema +struct<(1 = 1):boolean> +-- !query 146 output +true + + +-- !query 147 +SELECT '1' = cast('1' as binary) FROM t +-- !query 147 schema +struct<(CAST(1 AS BINARY) = CAST(1 AS BINARY)):boolean> +-- !query 147 output +true + + +-- !query 148 +SELECT '1' = cast(1 as boolean) FROM t +-- !query 148 schema +struct<(CAST(1 AS BOOLEAN) = CAST(1 AS BOOLEAN)):boolean> +-- !query 148 output +true + + +-- !query 149 +SELECT '1' = cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 149 schema +struct<(CAST(1 AS TIMESTAMP) = CAST(2017-12-11 09:30:00.0 AS TIMESTAMP)):boolean> +-- !query 149 output +NULL + + +-- !query 150 +SELECT '1' = cast('2017-12-11 09:30:00' as date) FROM t +-- !query 150 schema +struct<(1 = CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING)):boolean> +-- !query 150 output +false + + +-- !query 151 +SELECT cast(1 as tinyint) = '1' FROM t +-- !query 151 schema +struct<(CAST(1 AS TINYINT) = CAST(1 AS TINYINT)):boolean> +-- !query 151 output +true + + +-- !query 152 +SELECT cast(1 as smallint) = '1' FROM t +-- !query 152 schema +struct<(CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT)):boolean> +-- !query 152 output +true + + +-- !query 153 +SELECT cast(1 as int) = '1' FROM t +-- !query 153 schema +struct<(CAST(1 AS INT) = CAST(1 AS INT)):boolean> +-- !query 153 output +true + + +-- !query 154 +SELECT cast(1 as bigint) = '1' FROM t +-- !query 154 schema +struct<(CAST(1 AS BIGINT) = CAST(1 AS BIGINT)):boolean> +-- !query 154 output +true + + +-- !query 155 +SELECT cast(1 as float) = '1' FROM t +-- !query 155 schema +struct<(CAST(1 AS FLOAT) = CAST(1 AS FLOAT)):boolean> +-- !query 155 output +true + + +-- !query 156 +SELECT cast(1 as double) = '1' FROM t +-- !query 156 schema +struct<(CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 156 output +true + + +-- !query 157 +SELECT cast(1 as decimal(10, 0)) = '1' FROM t +-- !query 157 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 157 output +true + + +-- !query 158 +SELECT cast('1' as binary) = '1' FROM t +-- !query 158 schema +struct<(CAST(1 AS BINARY) = CAST(1 AS BINARY)):boolean> +-- !query 158 output +true + + +-- !query 159 +SELECT cast(1 as 
boolean) = '1' FROM t +-- !query 159 schema +struct<(CAST(1 AS BOOLEAN) = CAST(1 AS BOOLEAN)):boolean> +-- !query 159 output +true + + +-- !query 160 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = '1' FROM t +-- !query 160 schema +struct<(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) = CAST(1 AS TIMESTAMP)):boolean> +-- !query 160 output +NULL + + +-- !query 161 +SELECT cast('2017-12-11 09:30:00' as date) = '1' FROM t +-- !query 161 schema +struct<(CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING) = 1):boolean> +-- !query 161 output +false + + +-- !query 162 +SELECT '1' <=> cast(1 as tinyint) FROM t +-- !query 162 schema +struct<(CAST(1 AS TINYINT) <=> CAST(1 AS TINYINT)):boolean> +-- !query 162 output +true + + +-- !query 163 +SELECT '1' <=> cast(1 as smallint) FROM t +-- !query 163 schema +struct<(CAST(1 AS SMALLINT) <=> CAST(1 AS SMALLINT)):boolean> +-- !query 163 output +true + + +-- !query 164 +SELECT '1' <=> cast(1 as int) FROM t +-- !query 164 schema +struct<(CAST(1 AS INT) <=> CAST(1 AS INT)):boolean> +-- !query 164 output +true + + +-- !query 165 +SELECT '1' <=> cast(1 as bigint) FROM t +-- !query 165 schema +struct<(CAST(1 AS BIGINT) <=> CAST(1 AS BIGINT)):boolean> +-- !query 165 output +true + + +-- !query 166 +SELECT '1' <=> cast(1 as float) FROM t +-- !query 166 schema +struct<(CAST(1 AS FLOAT) <=> CAST(1 AS FLOAT)):boolean> +-- !query 166 output +true + + +-- !query 167 +SELECT '1' <=> cast(1 as double) FROM t +-- !query 167 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 167 output +true + + +-- !query 168 +SELECT '1' <=> cast(1 as decimal(10, 0)) FROM t +-- !query 168 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 168 output +true + + +-- !query 169 +SELECT '1' <=> '1' FROM t +-- !query 169 schema +struct<(1 <=> 1):boolean> +-- !query 169 output +true + + +-- !query 170 +SELECT '1' <=> cast('1' as binary) FROM t +-- !query 170 schema +struct<(CAST(1 AS BINARY) <=> CAST(1 AS BINARY)):boolean> +-- !query 170 output +true + + +-- !query 171 +SELECT '1' <=> cast(1 as boolean) FROM t +-- !query 171 schema +struct<(CAST(1 AS BOOLEAN) <=> CAST(1 AS BOOLEAN)):boolean> +-- !query 171 output +true + + +-- !query 172 +SELECT '1' <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 172 schema +struct<(CAST(1 AS TIMESTAMP) <=> CAST(2017-12-11 09:30:00.0 AS TIMESTAMP)):boolean> +-- !query 172 output +false + + +-- !query 173 +SELECT '1' <=> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 173 schema +struct<(1 <=> CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING)):boolean> +-- !query 173 output +false + + +-- !query 174 +SELECT cast(1 as tinyint) <=> '1' FROM t +-- !query 174 schema +struct<(CAST(1 AS TINYINT) <=> CAST(1 AS TINYINT)):boolean> +-- !query 174 output +true + + +-- !query 175 +SELECT cast(1 as smallint) <=> '1' FROM t +-- !query 175 schema +struct<(CAST(1 AS SMALLINT) <=> CAST(1 AS SMALLINT)):boolean> +-- !query 175 output +true + + +-- !query 176 +SELECT cast(1 as int) <=> '1' FROM t +-- !query 176 schema +struct<(CAST(1 AS INT) <=> CAST(1 AS INT)):boolean> +-- !query 176 output +true + + +-- !query 177 +SELECT cast(1 as bigint) <=> '1' FROM t +-- !query 177 schema +struct<(CAST(1 AS BIGINT) <=> CAST(1 AS BIGINT)):boolean> +-- !query 177 output +true + + +-- !query 178 +SELECT cast(1 as float) <=> '1' FROM t +-- !query 178 schema +struct<(CAST(1 AS FLOAT) <=> CAST(1 AS FLOAT)):boolean> +-- !query 178 output +true + + +-- !query 179 +SELECT cast(1 as double) <=> '1' FROM t +-- 
!query 179 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 179 output +true + + +-- !query 180 +SELECT cast(1 as decimal(10, 0)) <=> '1' FROM t +-- !query 180 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 180 output +true + + +-- !query 181 +SELECT cast('1' as binary) <=> '1' FROM t +-- !query 181 schema +struct<(CAST(1 AS BINARY) <=> CAST(1 AS BINARY)):boolean> +-- !query 181 output +true + + +-- !query 182 +SELECT cast(1 as boolean) <=> '1' FROM t +-- !query 182 schema +struct<(CAST(1 AS BOOLEAN) <=> CAST(1 AS BOOLEAN)):boolean> +-- !query 182 output +true + + +-- !query 183 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> '1' FROM t +-- !query 183 schema +struct<(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) <=> CAST(1 AS TIMESTAMP)):boolean> +-- !query 183 output +false + + +-- !query 184 +SELECT cast('2017-12-11 09:30:00' as date) <=> '1' FROM t +-- !query 184 schema +struct<(CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING) <=> 1):boolean> +-- !query 184 output +false + + +-- !query 185 +SELECT '1' < cast(1 as tinyint) FROM t +-- !query 185 schema +struct<(CAST(1 AS TINYINT) < CAST(1 AS TINYINT)):boolean> +-- !query 185 output +false + + +-- !query 186 +SELECT '1' < cast(1 as smallint) FROM t +-- !query 186 schema +struct<(CAST(1 AS SMALLINT) < CAST(1 AS SMALLINT)):boolean> +-- !query 186 output +false + + +-- !query 187 +SELECT '1' < cast(1 as int) FROM t +-- !query 187 schema +struct<(CAST(1 AS INT) < CAST(1 AS INT)):boolean> +-- !query 187 output +false + + +-- !query 188 +SELECT '1' < cast(1 as bigint) FROM t +-- !query 188 schema +struct<(CAST(1 AS BIGINT) < CAST(1 AS BIGINT)):boolean> +-- !query 188 output +false + + +-- !query 189 +SELECT '1' < cast(1 as float) FROM t +-- !query 189 schema +struct<(CAST(1 AS FLOAT) < CAST(1 AS FLOAT)):boolean> +-- !query 189 output +false + + +-- !query 190 +SELECT '1' < cast(1 as double) FROM t +-- !query 190 schema +struct<(CAST(1 AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 190 output +false + + +-- !query 191 +SELECT '1' < cast(1 as decimal(10, 0)) FROM t +-- !query 191 schema +struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 191 output +false + + +-- !query 192 +SELECT '1' < '1' FROM t +-- !query 192 schema +struct<(1 < 1):boolean> +-- !query 192 output +false + + +-- !query 193 +SELECT '1' < cast('1' as binary) FROM t +-- !query 193 schema +struct<(CAST(1 AS BINARY) < CAST(1 AS BINARY)):boolean> +-- !query 193 output +false + + +-- !query 194 +SELECT '1' < cast(1 as boolean) FROM t +-- !query 194 schema +struct<(CAST(1 AS BOOLEAN) < CAST(1 AS BOOLEAN)):boolean> +-- !query 194 output +false + + +-- !query 195 +SELECT '1' < cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 195 schema +struct<(1 < CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING)):boolean> +-- !query 195 output +true + + +-- !query 196 +SELECT '1' < cast('2017-12-11 09:30:00' as date) FROM t +-- !query 196 schema +struct<(1 < CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING)):boolean> +-- !query 196 output +true + + +-- !query 197 +SELECT '1' <= cast(1 as tinyint) FROM t +-- !query 197 schema +struct<(CAST(1 AS TINYINT) <= CAST(1 AS TINYINT)):boolean> +-- !query 197 output +true + + +-- !query 198 +SELECT '1' <= cast(1 as smallint) FROM t +-- !query 198 schema +struct<(CAST(1 AS SMALLINT) <= CAST(1 AS SMALLINT)):boolean> +-- !query 198 output +true + + +-- !query 199 +SELECT '1' <= cast(1 as int) FROM t +-- !query 199 schema 
+struct<(CAST(1 AS INT) <= CAST(1 AS INT)):boolean> +-- !query 199 output +true + + +-- !query 200 +SELECT '1' <= cast(1 as bigint) FROM t +-- !query 200 schema +struct<(CAST(1 AS BIGINT) <= CAST(1 AS BIGINT)):boolean> +-- !query 200 output +true + + +-- !query 201 +SELECT '1' <= cast(1 as float) FROM t +-- !query 201 schema +struct<(CAST(1 AS FLOAT) <= CAST(1 AS FLOAT)):boolean> +-- !query 201 output +true + + +-- !query 202 +SELECT '1' <= cast(1 as double) FROM t +-- !query 202 schema +struct<(CAST(1 AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 202 output +true + + +-- !query 203 +SELECT '1' <= cast(1 as decimal(10, 0)) FROM t +-- !query 203 schema +struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 203 output +true + + +-- !query 204 +SELECT '1' <= '1' FROM t +-- !query 204 schema +struct<(1 <= 1):boolean> +-- !query 204 output +true + + +-- !query 205 +SELECT '1' <= cast('1' as binary) FROM t +-- !query 205 schema +struct<(CAST(1 AS BINARY) <= CAST(1 AS BINARY)):boolean> +-- !query 205 output +true + + +-- !query 206 +SELECT '1' <= cast(1 as boolean) FROM t +-- !query 206 schema +struct<(CAST(1 AS BOOLEAN) <= CAST(1 AS BOOLEAN)):boolean> +-- !query 206 output +true + + +-- !query 207 +SELECT '1' <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 207 schema +struct<(1 <= CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING)):boolean> +-- !query 207 output +true + + +-- !query 208 +SELECT '1' <= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 208 schema +struct<(1 <= CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING)):boolean> +-- !query 208 output +true + + +-- !query 209 +SELECT '1' > cast(1 as tinyint) FROM t +-- !query 209 schema +struct<(CAST(1 AS TINYINT) > CAST(1 AS TINYINT)):boolean> +-- !query 209 output +false + + +-- !query 210 +SELECT '1' > cast(1 as smallint) FROM t +-- !query 210 schema +struct<(CAST(1 AS SMALLINT) > CAST(1 AS SMALLINT)):boolean> +-- !query 210 output +false + + +-- !query 211 +SELECT '1' > cast(1 as int) FROM t +-- !query 211 schema +struct<(CAST(1 AS INT) > CAST(1 AS INT)):boolean> +-- !query 211 output +false + + +-- !query 212 +SELECT '1' > cast(1 as bigint) FROM t +-- !query 212 schema +struct<(CAST(1 AS BIGINT) > CAST(1 AS BIGINT)):boolean> +-- !query 212 output +false + + +-- !query 213 +SELECT '1' > cast(1 as float) FROM t +-- !query 213 schema +struct<(CAST(1 AS FLOAT) > CAST(1 AS FLOAT)):boolean> +-- !query 213 output +false + + +-- !query 214 +SELECT '1' > cast(1 as double) FROM t +-- !query 214 schema +struct<(CAST(1 AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 214 output +false + + +-- !query 215 +SELECT '1' > cast(1 as decimal(10, 0)) FROM t +-- !query 215 schema +struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 215 output +false + + +-- !query 216 +SELECT '1' > '1' FROM t +-- !query 216 schema +struct<(1 > 1):boolean> +-- !query 216 output +false + + +-- !query 217 +SELECT '1' > cast('1' as binary) FROM t +-- !query 217 schema +struct<(CAST(1 AS BINARY) > CAST(1 AS BINARY)):boolean> +-- !query 217 output +false + + +-- !query 218 +SELECT '1' > cast(1 as boolean) FROM t +-- !query 218 schema +struct<(CAST(1 AS BOOLEAN) > CAST(1 AS BOOLEAN)):boolean> +-- !query 218 output +false + + +-- !query 219 +SELECT '1' > cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 219 schema +struct<(1 > CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING)):boolean> +-- !query 219 output +false + + +-- !query 220 +SELECT '1' > 
cast('2017-12-11 09:30:00' as date) FROM t +-- !query 220 schema +struct<(1 > CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING)):boolean> +-- !query 220 output +false + + +-- !query 221 +SELECT '1' >= cast(1 as tinyint) FROM t +-- !query 221 schema +struct<(CAST(1 AS TINYINT) >= CAST(1 AS TINYINT)):boolean> +-- !query 221 output +true + + +-- !query 222 +SELECT '1' >= cast(1 as smallint) FROM t +-- !query 222 schema +struct<(CAST(1 AS SMALLINT) >= CAST(1 AS SMALLINT)):boolean> +-- !query 222 output +true + + +-- !query 223 +SELECT '1' >= cast(1 as int) FROM t +-- !query 223 schema +struct<(CAST(1 AS INT) >= CAST(1 AS INT)):boolean> +-- !query 223 output +true + + +-- !query 224 +SELECT '1' >= cast(1 as bigint) FROM t +-- !query 224 schema +struct<(CAST(1 AS BIGINT) >= CAST(1 AS BIGINT)):boolean> +-- !query 224 output +true + + +-- !query 225 +SELECT '1' >= cast(1 as float) FROM t +-- !query 225 schema +struct<(CAST(1 AS FLOAT) >= CAST(1 AS FLOAT)):boolean> +-- !query 225 output +true + + +-- !query 226 +SELECT '1' >= cast(1 as double) FROM t +-- !query 226 schema +struct<(CAST(1 AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 226 output +true + + +-- !query 227 +SELECT '1' >= cast(1 as decimal(10, 0)) FROM t +-- !query 227 schema +struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 227 output +true + + +-- !query 228 +SELECT '1' >= '1' FROM t +-- !query 228 schema +struct<(1 >= 1):boolean> +-- !query 228 output +true + + +-- !query 229 +SELECT '1' >= cast('1' as binary) FROM t +-- !query 229 schema +struct<(CAST(1 AS BINARY) >= CAST(1 AS BINARY)):boolean> +-- !query 229 output +true + + +-- !query 230 +SELECT '1' >= cast(1 as boolean) FROM t +-- !query 230 schema +struct<(CAST(1 AS BOOLEAN) >= CAST(1 AS BOOLEAN)):boolean> +-- !query 230 output +true + + +-- !query 231 +SELECT '1' >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 231 schema +struct<(1 >= CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING)):boolean> +-- !query 231 output +false + + +-- !query 232 +SELECT '1' >= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 232 schema +struct<(1 >= CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING)):boolean> +-- !query 232 output +false + + +-- !query 233 +SELECT '1' <> cast(1 as tinyint) FROM t +-- !query 233 schema +struct<(NOT (CAST(1 AS TINYINT) = CAST(1 AS TINYINT))):boolean> +-- !query 233 output +false + + +-- !query 234 +SELECT '1' <> cast(1 as smallint) FROM t +-- !query 234 schema +struct<(NOT (CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT))):boolean> +-- !query 234 output +false + + +-- !query 235 +SELECT '1' <> cast(1 as int) FROM t +-- !query 235 schema +struct<(NOT (CAST(1 AS INT) = CAST(1 AS INT))):boolean> +-- !query 235 output +false + + +-- !query 236 +SELECT '1' <> cast(1 as bigint) FROM t +-- !query 236 schema +struct<(NOT (CAST(1 AS BIGINT) = CAST(1 AS BIGINT))):boolean> +-- !query 236 output +false + + +-- !query 237 +SELECT '1' <> cast(1 as float) FROM t +-- !query 237 schema +struct<(NOT (CAST(1 AS FLOAT) = CAST(1 AS FLOAT))):boolean> +-- !query 237 output +false + + +-- !query 238 +SELECT '1' <> cast(1 as double) FROM t +-- !query 238 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 238 output +false + + +-- !query 239 +SELECT '1' <> cast(1 as decimal(10, 0)) FROM t +-- !query 239 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 239 output +false + + +-- !query 240 +SELECT '1' <> '1' FROM t +-- !query 240 schema 
+struct<(NOT (1 = 1)):boolean> +-- !query 240 output +false + + +-- !query 241 +SELECT '1' <> cast('1' as binary) FROM t +-- !query 241 schema +struct<(NOT (CAST(1 AS BINARY) = CAST(1 AS BINARY))):boolean> +-- !query 241 output +false + + +-- !query 242 +SELECT '1' <> cast(1 as boolean) FROM t +-- !query 242 schema +struct<(NOT (CAST(1 AS BOOLEAN) = CAST(1 AS BOOLEAN))):boolean> +-- !query 242 output +false + + +-- !query 243 +SELECT '1' <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 243 schema +struct<(NOT (CAST(1 AS TIMESTAMP) = CAST(2017-12-11 09:30:00.0 AS TIMESTAMP))):boolean> +-- !query 243 output +NULL + + +-- !query 244 +SELECT '1' <> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 244 schema +struct<(NOT (1 = CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING))):boolean> +-- !query 244 output +true + + +-- !query 245 +SELECT cast(1 as tinyint) < '1' FROM t +-- !query 245 schema +struct<(CAST(1 AS TINYINT) < CAST(1 AS TINYINT)):boolean> +-- !query 245 output +false + + +-- !query 246 +SELECT cast(1 as smallint) < '1' FROM t +-- !query 246 schema +struct<(CAST(1 AS SMALLINT) < CAST(1 AS SMALLINT)):boolean> +-- !query 246 output +false + + +-- !query 247 +SELECT cast(1 as int) < '1' FROM t +-- !query 247 schema +struct<(CAST(1 AS INT) < CAST(1 AS INT)):boolean> +-- !query 247 output +false + + +-- !query 248 +SELECT cast(1 as bigint) < '1' FROM t +-- !query 248 schema +struct<(CAST(1 AS BIGINT) < CAST(1 AS BIGINT)):boolean> +-- !query 248 output +false + + +-- !query 249 +SELECT cast(1 as float) < '1' FROM t +-- !query 249 schema +struct<(CAST(1 AS FLOAT) < CAST(1 AS FLOAT)):boolean> +-- !query 249 output +false + + +-- !query 250 +SELECT cast(1 as double) < '1' FROM t +-- !query 250 schema +struct<(CAST(1 AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 250 output +false + + +-- !query 251 +SELECT cast(1 as decimal(10, 0)) < '1' FROM t +-- !query 251 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 251 output +false + + +-- !query 252 +SELECT '1' < '1' FROM t +-- !query 252 schema +struct<(1 < 1):boolean> +-- !query 252 output +false + + +-- !query 253 +SELECT cast('1' as binary) < '1' FROM t +-- !query 253 schema +struct<(CAST(1 AS BINARY) < CAST(1 AS BINARY)):boolean> +-- !query 253 output +false + + +-- !query 254 +SELECT cast(1 as boolean) < '1' FROM t +-- !query 254 schema +struct<(CAST(1 AS BOOLEAN) < CAST(1 AS BOOLEAN)):boolean> +-- !query 254 output +false + + +-- !query 255 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < '1' FROM t +-- !query 255 schema +struct<(CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING) < 1):boolean> +-- !query 255 output +false + + +-- !query 256 +SELECT cast('2017-12-11 09:30:00' as date) < '1' FROM t +-- !query 256 schema +struct<(CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING) < 1):boolean> +-- !query 256 output +false + + +-- !query 257 +SELECT cast(1 as tinyint) <= '1' FROM t +-- !query 257 schema +struct<(CAST(1 AS TINYINT) <= CAST(1 AS TINYINT)):boolean> +-- !query 257 output +true + + +-- !query 258 +SELECT cast(1 as smallint) <= '1' FROM t +-- !query 258 schema +struct<(CAST(1 AS SMALLINT) <= CAST(1 AS SMALLINT)):boolean> +-- !query 258 output +true + + +-- !query 259 +SELECT cast(1 as int) <= '1' FROM t +-- !query 259 schema +struct<(CAST(1 AS INT) <= CAST(1 AS INT)):boolean> +-- !query 259 output +true + + +-- !query 260 +SELECT cast(1 as bigint) <= '1' FROM t +-- !query 260 schema +struct<(CAST(1 AS BIGINT) <= CAST(1 AS BIGINT)):boolean> +-- !query 
260 output +true + + +-- !query 261 +SELECT cast(1 as float) <= '1' FROM t +-- !query 261 schema +struct<(CAST(1 AS FLOAT) <= CAST(1 AS FLOAT)):boolean> +-- !query 261 output +true + + +-- !query 262 +SELECT cast(1 as double) <= '1' FROM t +-- !query 262 schema +struct<(CAST(1 AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 262 output +true + + +-- !query 263 +SELECT cast(1 as decimal(10, 0)) <= '1' FROM t +-- !query 263 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 263 output +true + + +-- !query 264 +SELECT '1' <= '1' FROM t +-- !query 264 schema +struct<(1 <= 1):boolean> +-- !query 264 output +true + + +-- !query 265 +SELECT cast('1' as binary) <= '1' FROM t +-- !query 265 schema +struct<(CAST(1 AS BINARY) <= CAST(1 AS BINARY)):boolean> +-- !query 265 output +true + + +-- !query 266 +SELECT cast(1 as boolean) <= '1' FROM t +-- !query 266 schema +struct<(CAST(1 AS BOOLEAN) <= CAST(1 AS BOOLEAN)):boolean> +-- !query 266 output +true + + +-- !query 267 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= '1' FROM t +-- !query 267 schema +struct<(CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING) <= 1):boolean> +-- !query 267 output +false + + +-- !query 268 +SELECT cast('2017-12-11 09:30:00' as date) <= '1' FROM t +-- !query 268 schema +struct<(CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING) <= 1):boolean> +-- !query 268 output +false + + +-- !query 269 +SELECT cast(1 as tinyint) > '1' FROM t +-- !query 269 schema +struct<(CAST(1 AS TINYINT) > CAST(1 AS TINYINT)):boolean> +-- !query 269 output +false + + +-- !query 270 +SELECT cast(1 as smallint) > '1' FROM t +-- !query 270 schema +struct<(CAST(1 AS SMALLINT) > CAST(1 AS SMALLINT)):boolean> +-- !query 270 output +false + + +-- !query 271 +SELECT cast(1 as int) > '1' FROM t +-- !query 271 schema +struct<(CAST(1 AS INT) > CAST(1 AS INT)):boolean> +-- !query 271 output +false + + +-- !query 272 +SELECT cast(1 as bigint) > '1' FROM t +-- !query 272 schema +struct<(CAST(1 AS BIGINT) > CAST(1 AS BIGINT)):boolean> +-- !query 272 output +false + + +-- !query 273 +SELECT cast(1 as float) > '1' FROM t +-- !query 273 schema +struct<(CAST(1 AS FLOAT) > CAST(1 AS FLOAT)):boolean> +-- !query 273 output +false + + +-- !query 274 +SELECT cast(1 as double) > '1' FROM t +-- !query 274 schema +struct<(CAST(1 AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 274 output +false + + +-- !query 275 +SELECT cast(1 as decimal(10, 0)) > '1' FROM t +-- !query 275 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 275 output +false + + +-- !query 276 +SELECT '1' > '1' FROM t +-- !query 276 schema +struct<(1 > 1):boolean> +-- !query 276 output +false + + +-- !query 277 +SELECT cast('1' as binary) > '1' FROM t +-- !query 277 schema +struct<(CAST(1 AS BINARY) > CAST(1 AS BINARY)):boolean> +-- !query 277 output +false + + +-- !query 278 +SELECT cast(1 as boolean) > '1' FROM t +-- !query 278 schema +struct<(CAST(1 AS BOOLEAN) > CAST(1 AS BOOLEAN)):boolean> +-- !query 278 output +false + + +-- !query 279 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > '1' FROM t +-- !query 279 schema +struct<(CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING) > 1):boolean> +-- !query 279 output +true + + +-- !query 280 +SELECT cast('2017-12-11 09:30:00' as date) > '1' FROM t +-- !query 280 schema +struct<(CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING) > 1):boolean> +-- !query 280 output +true + + +-- !query 281 +SELECT cast(1 as tinyint) >= '1' FROM t +-- !query 
281 schema +struct<(CAST(1 AS TINYINT) >= CAST(1 AS TINYINT)):boolean> +-- !query 281 output +true + + +-- !query 282 +SELECT cast(1 as smallint) >= '1' FROM t +-- !query 282 schema +struct<(CAST(1 AS SMALLINT) >= CAST(1 AS SMALLINT)):boolean> +-- !query 282 output +true + + +-- !query 283 +SELECT cast(1 as int) >= '1' FROM t +-- !query 283 schema +struct<(CAST(1 AS INT) >= CAST(1 AS INT)):boolean> +-- !query 283 output +true + + +-- !query 284 +SELECT cast(1 as bigint) >= '1' FROM t +-- !query 284 schema +struct<(CAST(1 AS BIGINT) >= CAST(1 AS BIGINT)):boolean> +-- !query 284 output +true + + +-- !query 285 +SELECT cast(1 as float) >= '1' FROM t +-- !query 285 schema +struct<(CAST(1 AS FLOAT) >= CAST(1 AS FLOAT)):boolean> +-- !query 285 output +true + + +-- !query 286 +SELECT cast(1 as double) >= '1' FROM t +-- !query 286 schema +struct<(CAST(1 AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 286 output +true + + +-- !query 287 +SELECT cast(1 as decimal(10, 0)) >= '1' FROM t +-- !query 287 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 287 output +true + + +-- !query 288 +SELECT '1' >= '1' FROM t +-- !query 288 schema +struct<(1 >= 1):boolean> +-- !query 288 output +true + + +-- !query 289 +SELECT cast('1' as binary) >= '1' FROM t +-- !query 289 schema +struct<(CAST(1 AS BINARY) >= CAST(1 AS BINARY)):boolean> +-- !query 289 output +true + + +-- !query 290 +SELECT cast(1 as boolean) >= '1' FROM t +-- !query 290 schema +struct<(CAST(1 AS BOOLEAN) >= CAST(1 AS BOOLEAN)):boolean> +-- !query 290 output +true + + +-- !query 291 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= '1' FROM t +-- !query 291 schema +struct<(CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING) >= 1):boolean> +-- !query 291 output +true + + +-- !query 292 +SELECT cast('2017-12-11 09:30:00' as date) >= '1' FROM t +-- !query 292 schema +struct<(CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING) >= 1):boolean> +-- !query 292 output +true + + +-- !query 293 +SELECT cast(1 as tinyint) <> '1' FROM t +-- !query 293 schema +struct<(NOT (CAST(1 AS TINYINT) = CAST(1 AS TINYINT))):boolean> +-- !query 293 output +false + + +-- !query 294 +SELECT cast(1 as smallint) <> '1' FROM t +-- !query 294 schema +struct<(NOT (CAST(1 AS SMALLINT) = CAST(1 AS SMALLINT))):boolean> +-- !query 294 output +false + + +-- !query 295 +SELECT cast(1 as int) <> '1' FROM t +-- !query 295 schema +struct<(NOT (CAST(1 AS INT) = CAST(1 AS INT))):boolean> +-- !query 295 output +false + + +-- !query 296 +SELECT cast(1 as bigint) <> '1' FROM t +-- !query 296 schema +struct<(NOT (CAST(1 AS BIGINT) = CAST(1 AS BIGINT))):boolean> +-- !query 296 output +false + + +-- !query 297 +SELECT cast(1 as float) <> '1' FROM t +-- !query 297 schema +struct<(NOT (CAST(1 AS FLOAT) = CAST(1 AS FLOAT))):boolean> +-- !query 297 output +false + + +-- !query 298 +SELECT cast(1 as double) <> '1' FROM t +-- !query 298 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 298 output +false + + +-- !query 299 +SELECT cast(1 as decimal(10, 0)) <> '1' FROM t +-- !query 299 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 299 output +false + + +-- !query 300 +SELECT '1' <> '1' FROM t +-- !query 300 schema +struct<(NOT (1 = 1)):boolean> +-- !query 300 output +false + + +-- !query 301 +SELECT cast('1' as binary) <> '1' FROM t +-- !query 301 schema +struct<(NOT (CAST(1 AS BINARY) = CAST(1 AS BINARY))):boolean> +-- !query 301 output +false + + +-- 
!query 302 +SELECT cast(1 as boolean) <> '1' FROM t +-- !query 302 schema +struct<(NOT (CAST(1 AS BOOLEAN) = CAST(1 AS BOOLEAN))):boolean> +-- !query 302 output +false + + +-- !query 303 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> '1' FROM t +-- !query 303 schema +struct<(NOT (CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) = CAST(1 AS TIMESTAMP))):boolean> +-- !query 303 output +NULL + + +-- !query 304 +SELECT cast('2017-12-11 09:30:00' as date) <> '1' FROM t +-- !query 304 schema +struct<(NOT (CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING) = 1)):boolean> +-- !query 304 output +true + + +-- !query 305 +SELECT abs('1') FROM t +-- !query 305 schema +struct +-- !query 305 output +1.0 + + +-- !query 306 +SELECT sum('1') FROM t +-- !query 306 schema +struct +-- !query 306 output +1.0 + + +-- !query 307 +SELECT avg('1') FROM t +-- !query 307 schema +struct +-- !query 307 output +1.0 + + +-- !query 308 +SELECT stddev_pop('1') FROM t +-- !query 308 schema +struct +-- !query 308 output +0.0 + + +-- !query 309 +SELECT stddev_samp('1') FROM t +-- !query 309 schema +struct +-- !query 309 output +NaN + + +-- !query 310 +SELECT - '1' FROM t +-- !query 310 schema +struct<(- CAST(1 AS DOUBLE)):double> +-- !query 310 output +-1.0 + + +-- !query 311 +SELECT + '1' FROM t +-- !query 311 schema +struct<1:string> +-- !query 311 output +1 + + +-- !query 312 +SELECT var_pop('1') FROM t +-- !query 312 schema +struct +-- !query 312 output +0.0 + + +-- !query 313 +SELECT var_samp('1') FROM t +-- !query 313 schema +struct +-- !query 313 output +NaN + + +-- !query 314 +SELECT skewness('1') FROM t +-- !query 314 schema +struct +-- !query 314 output +NaN + + +-- !query 315 +SELECT kurtosis('1') FROM t +-- !query 315 schema +struct +-- !query 315 output +NaN From fb3636b482be3d0940345b1528c1d5090bbc25e6 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 18 Dec 2017 11:29:32 -0800 Subject: [PATCH 444/712] [SPARK-22807][SCHEDULER] Remove config that says docker and replace with container ## What changes were proposed in this pull request? Changes discussed in https://github.com/apache/spark/pull/19946#discussion_r157063535 docker -> container, since with CRI, we are not limited to running only docker images. ## How was this patch tested? Manual testing Author: foxish Closes #19995 from foxish/make-docker-container. 
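For illustration (not part of the patch), a minimal sketch of how the renamed options would be set on a `SparkConf` after this change; the config keys are taken from the diff below, while the image names and pull policy are placeholder values:

```scala
import org.apache.spark.SparkConf

// Hypothetical usage of the renamed Kubernetes image settings; values are placeholders.
val conf = new SparkConf()
  // formerly spark.kubernetes.driver.docker.image
  .set("spark.kubernetes.driver.container.image", "myrepo/spark-driver:latest")
  // formerly spark.kubernetes.executor.docker.image
  .set("spark.kubernetes.executor.container.image", "myrepo/spark-executor:latest")
  // formerly spark.kubernetes.docker.image.pullPolicy
  .set("spark.kubernetes.container.image.pullPolicy", "IfNotPresent")
```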
--- .../apache/spark/deploy/SparkSubmitSuite.scala | 4 ++-- .../org/apache/spark/deploy/k8s/Config.scala | 17 ++++++++--------- .../DriverConfigurationStepsOrchestrator.scala | 4 ++-- .../steps/BaseDriverConfigurationStep.scala | 12 ++++++------ .../cluster/k8s/ExecutorPodFactory.scala | 12 ++++++------ ...verConfigurationStepsOrchestratorSuite.scala | 6 +++--- .../BaseDriverConfigurationStepSuite.scala | 8 ++++---- .../cluster/k8s/ExecutorPodFactorySuite.scala | 2 +- 8 files changed, 32 insertions(+), 33 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2eb8a1fee104c..27dd435332348 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -396,7 +396,7 @@ class SparkSubmitSuite "--class", "org.SomeClass", "--driver-memory", "4g", "--conf", "spark.kubernetes.namespace=spark", - "--conf", "spark.kubernetes.driver.docker.image=bar", + "--conf", "spark.kubernetes.driver.container.image=bar", "/home/thejar.jar", "arg1") val appArgs = new SparkSubmitArguments(clArgs) @@ -412,7 +412,7 @@ class SparkSubmitSuite conf.get("spark.executor.memory") should be ("5g") conf.get("spark.driver.memory") should be ("4g") conf.get("spark.kubernetes.namespace") should be ("spark") - conf.get("spark.kubernetes.driver.docker.image") should be ("bar") + conf.get("spark.kubernetes.driver.container.image") should be ("bar") } test("handles confs with flag equivalents") { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index f35fb38798218..04aadb4b06af4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -30,21 +30,20 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("default") - val DRIVER_DOCKER_IMAGE = - ConfigBuilder("spark.kubernetes.driver.docker.image") - .doc("Docker image to use for the driver. Specify this using the standard Docker tag format.") + val DRIVER_CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.driver.container.image") + .doc("Container image to use for the driver.") .stringConf .createOptional - val EXECUTOR_DOCKER_IMAGE = - ConfigBuilder("spark.kubernetes.executor.docker.image") - .doc("Docker image to use for the executors. Specify this using the standard Docker tag " + - "format.") + val EXECUTOR_CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.executor.container.image") + .doc("Container image to use for the executors.") .stringConf .createOptional - val DOCKER_IMAGE_PULL_POLICY = - ConfigBuilder("spark.kubernetes.docker.image.pullPolicy") + val CONTAINER_IMAGE_PULL_POLICY = + ConfigBuilder("spark.kubernetes.container.image.pullPolicy") .doc("Kubernetes image pull policy. 
Valid values are Always, Never, and IfNotPresent.") .stringConf .checkValues(Set("Always", "Never", "IfNotPresent")) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala index c563fc5bfbadf..1411e6f40b468 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala @@ -49,7 +49,7 @@ private[spark] class DriverConfigurationStepsOrchestrator( s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") } - private val dockerImagePullPolicy = submissionSparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val imagePullPolicy = submissionSparkConf.get(CONTAINER_IMAGE_PULL_POLICY) private val jarsDownloadPath = submissionSparkConf.get(JARS_DOWNLOAD_LOCATION) private val filesDownloadPath = submissionSparkConf.get(FILES_DOWNLOAD_LOCATION) @@ -72,7 +72,7 @@ private[spark] class DriverConfigurationStepsOrchestrator( kubernetesAppId, kubernetesResourceNamePrefix, allDriverLabels, - dockerImagePullPolicy, + imagePullPolicy, appName, mainClass, appArgs, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala index ba2a11b9e6689..c335fcce4036e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala @@ -34,7 +34,7 @@ private[spark] class BaseDriverConfigurationStep( kubernetesAppId: String, kubernetesResourceNamePrefix: String, driverLabels: Map[String, String], - dockerImagePullPolicy: String, + imagePullPolicy: String, appName: String, mainClass: String, appArgs: Array[String], @@ -46,9 +46,9 @@ private[spark] class BaseDriverConfigurationStep( private val driverExtraClasspath = submissionSparkConf.get( DRIVER_CLASS_PATH) - private val driverDockerImage = submissionSparkConf - .get(DRIVER_DOCKER_IMAGE) - .getOrElse(throw new SparkException("Must specify the driver Docker image")) + private val driverContainerImage = submissionSparkConf + .get(DRIVER_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver container image")) // CPU settings private val driverCpuCores = submissionSparkConf.getOption("spark.driver.cores").getOrElse("1") @@ -110,8 +110,8 @@ private[spark] class BaseDriverConfigurationStep( val driverContainer = new ContainerBuilder(driverSpec.driverContainer) .withName(DRIVER_CONTAINER_NAME) - .withImage(driverDockerImage) - .withImagePullPolicy(dockerImagePullPolicy) + .withImage(driverContainerImage) + .withImagePullPolicy(imagePullPolicy) .addAllToEnv(driverCustomEnvs.asJava) .addToEnv(driverExtraClasspathEnv.toSeq: _*) .addNewEnv() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 9d8f3b912c33d..70226157dd68b 100644 --- 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -72,10 +72,10 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) - private val executorDockerImage = sparkConf - .get(EXECUTOR_DOCKER_IMAGE) - .getOrElse(throw new SparkException("Must specify the executor Docker image")) - private val dockerImagePullPolicy = sparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val executorContainerImage = sparkConf + .get(EXECUTOR_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor container image")) + private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) private val blockManagerPort = sparkConf .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) @@ -166,8 +166,8 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) val executorContainer = new ContainerBuilder() .withName("executor") - .withImage(executorDockerImage) - .withImagePullPolicy(dockerImagePullPolicy) + .withImage(executorContainerImage) + .withImagePullPolicy(imagePullPolicy) .withNewResources() .addToRequests("memory", executorMemoryQuantity) .addToLimits("memory", executorMemoryLimitQuantity) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala index c7291d49b465e..98f9f27da5cde 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config.DRIVER_DOCKER_IMAGE +import org.apache.spark.deploy.k8s.Config.DRIVER_CONTAINER_IMAGE import org.apache.spark.deploy.k8s.submit.steps._ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { @@ -32,7 +32,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { test("Base submission steps with a main app resource.") { val sparkConf = new SparkConf(false) - .set(DRIVER_DOCKER_IMAGE, DRIVER_IMAGE) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigurationStepsOrchestrator( NAMESPACE, @@ -54,7 +54,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { test("Base submission steps without a main app resource.") { val sparkConf = new SparkConf(false) - .set(DRIVER_DOCKER_IMAGE, DRIVER_IMAGE) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigurationStepsOrchestrator( NAMESPACE, APP_ID, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala index 83c5f98254829..f7c1b3142cf71 100644 --- 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala @@ -30,7 +30,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { private val APP_ID = "spark-app-id" private val RESOURCE_NAME_PREFIX = "spark" private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") - private val DOCKER_IMAGE_PULL_POLICY = "IfNotPresent" + private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2", "arg 3") @@ -47,7 +47,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(DRIVER_DOCKER_IMAGE, "spark-driver:latest") + .set(DRIVER_CONTAINER_IMAGE, "spark-driver:latest") .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") @@ -56,7 +56,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { APP_ID, RESOURCE_NAME_PREFIX, DRIVER_LABELS, - DOCKER_IMAGE_PULL_POLICY, + CONTAINER_IMAGE_PULL_POLICY, APP_NAME, MAIN_CLASS, APP_ARGS, @@ -71,7 +71,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") - assert(preparedDriverSpec.driverContainer.getImagePullPolicy === DOCKER_IMAGE_PULL_POLICY) + assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) assert(preparedDriverSpec.driverContainer.getEnv.size === 7) val envs = preparedDriverSpec.driverContainer diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 1c7717c238096..3a55d7cb37b1f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -50,7 +50,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef baseConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(EXECUTOR_DOCKER_IMAGE, executorImage) + .set(EXECUTOR_CONTAINER_IMAGE, executorImage) } test("basic executor pod has reasonable defaults") { From 772e4648d95bda3353723337723543c741ea8476 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 18 Dec 2017 14:08:48 -0600 Subject: [PATCH 445/712] [SPARK-20653][CORE] Add cleaning of old elements from the status store. This change restores the functionality that keeps a limited number of different types (jobs, stages, etc) depending on configuration, to avoid the store growing indefinitely over time. 
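To make the retention mechanism concrete before the diff, here is a hedged, much-simplified sketch of the trigger pattern this change introduces. The real ElementTrackingStore added in the diff below wraps a KVStore; the class, names, and API here are illustrative only, not the actual implementation:

```scala
import scala.collection.mutable

// Illustration only: a store that counts writes per element type and fires a registered
// cleanup action once the count for that type exceeds a configured threshold.
class SimpleTrackingStore {
  private case class Trigger(threshold: Long, action: Long => Unit)
  private val triggers = mutable.Map.empty[Class[_], List[Trigger]].withDefaultValue(Nil)
  private val counts = mutable.Map.empty[Class[_], Long].withDefaultValue(0L)

  // Register a cleanup action for `klass`, invoked with the current count once it
  // passes `threshold`.
  def addTrigger(klass: Class[_], threshold: Long)(action: Long => Unit): Unit =
    triggers(klass) = Trigger(threshold, action) :: triggers(klass)

  // Record one written element of `klass` and fire any triggers whose threshold is
  // now exceeded.
  def write(klass: Class[_]): Unit = {
    counts(klass) += 1
    val count = counts(klass)
    triggers(klass).filter(t => count > t.threshold).foreach(_.action(count))
  }
}

// Usage in the spirit of the listener registrations added in this patch
// (the element type and threshold here are placeholders):
// store.addTrigger(classOf[AnyRef], threshold = 1000) { count => /* delete oldest */ }
```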
The feature is implemented by creating a new type (ElementTrackingStore) that wraps a KVStore and allows triggers to be set up for when elements of a certain type meet a certain threshold. Triggers don't need to necessarily only delete elements, but the current API is set up in a way that makes that use case easier. The new store also has a trigger for the "close" call, which makes it easier for listeners to register code for cleaning things up and flushing partial state to the store. The old configurations for cleaning up the stored elements from the core and SQL UIs are now active again, and the old unit tests are re-enabled. Author: Marcelo Vanzin Closes #19751 from vanzin/SPARK-20653. --- .../deploy/history/FsHistoryProvider.scala | 17 +- .../spark/internal/config/package.scala | 5 - .../spark/status/AppStatusListener.scala | 188 ++++++++++++++++-- .../apache/spark/status/AppStatusPlugin.scala | 2 +- .../apache/spark/status/AppStatusStore.scala | 6 +- .../spark/status/ElementTrackingStore.scala | 160 +++++++++++++++ .../org/apache/spark/status/KVUtils.scala | 14 ++ .../org/apache/spark/status/LiveEntity.scala | 13 +- .../spark/status/api/v1/StagesResource.scala | 10 +- .../org/apache/spark/status/config.scala | 20 ++ .../org/apache/spark/status/storeTypes.scala | 16 ++ .../scala/org/apache/spark/ui/SparkUI.scala | 2 - .../apache/spark/ui/jobs/AllJobsPage.scala | 8 +- .../apache/spark/ui/jobs/AllStagesPage.scala | 8 +- .../deploy/history/HistoryServerSuite.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 148 +++++++++++--- .../status/ElementTrackingStoreSuite.scala | 91 +++++++++ .../org/apache/spark/ui/UISeleniumSuite.scala | 8 +- .../spark/sql/internal/StaticSQLConf.scala | 7 + .../execution/ui/SQLAppStatusListener.scala | 33 ++- .../sql/execution/ui/SQLAppStatusStore.scala | 7 +- .../sql/execution/ui/SQLListenerSuite.scala | 29 +-- 22 files changed, 713 insertions(+), 81 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala create mode 100644 core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a8e1becc56ab7..fa2c5194aa41b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -44,6 +44,7 @@ import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.status._ import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo} +import org.apache.spark.status.config._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} import org.apache.spark.util.kvstore._ @@ -304,6 +305,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val (kvstore, needReplay) = uiStorePath match { case Some(path) => try { + // The store path is not guaranteed to exist - maybe it hasn't been created, or was + // invalidated because changes to the event log were detected. Need to replay in that + // case. 
val _replay = !path.isDirectory() (createDiskStore(path, conf), _replay) } catch { @@ -318,24 +322,23 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) (new InMemoryStore(), true) } + val trackingStore = new ElementTrackingStore(kvstore, conf) if (needReplay) { val replayBus = new ReplayListenerBus() - val listener = new AppStatusListener(kvstore, conf, false, + val listener = new AppStatusListener(trackingStore, conf, false, lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) replayBus.addListener(listener) AppStatusPlugin.loadPlugins().foreach { plugin => - plugin.setupListeners(conf, kvstore, l => replayBus.addListener(l), false) + plugin.setupListeners(conf, trackingStore, l => replayBus.addListener(l), false) } try { val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) - listener.flush() + trackingStore.close(false) } catch { case e: Exception => - try { - kvstore.close() - } catch { - case _e: Exception => logInfo("Error closing store.", _e) + Utils.tryLogNonFatalError { + trackingStore.close() } uiStorePath.foreach(Utils.deleteRecursively) if (e.isInstanceOf[FileNotFoundException]) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 172ba85359da7..eb12ddf961314 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -240,11 +240,6 @@ package object config { .stringConf .createOptional - // To limit memory usage, we only track information for a fixed number of tasks - private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks") - .intConf - .createWithDefault(100000) - // To limit how many applications are shown in the History Server summary ui private[spark] val HISTORY_UI_MAX_APPS = ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 6da44cbc44c4d..1fb7b76d43d04 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -32,7 +32,6 @@ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.scope._ -import org.apache.spark.util.kvstore.KVStore /** * A Spark listener that writes application information to a data store. The types written to the @@ -42,7 +41,7 @@ import org.apache.spark.util.kvstore.KVStore * unfinished tasks can be more accurately calculated (see SPARK-21922). */ private[spark] class AppStatusListener( - kvstore: KVStore, + kvstore: ElementTrackingStore, conf: SparkConf, live: Boolean, lastUpdateTime: Option[Long] = None) extends SparkListener with Logging { @@ -51,6 +50,7 @@ private[spark] class AppStatusListener( private var sparkVersion = SPARK_VERSION private var appInfo: v1.ApplicationInfo = null + private var appSummary = new AppSummary(0, 0) private var coresPerTask: Int = 1 // How often to update live entities. -1 means "never update" when replaying applications, @@ -58,6 +58,7 @@ private[spark] class AppStatusListener( // operations that we can live without when rapidly processing incoming task events. 
private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + private val maxTasksPerStage = conf.get(MAX_RETAINED_TASKS_PER_STAGE) private val maxGraphRootNodes = conf.get(MAX_RETAINED_ROOT_NODES) // Keep track of live entities, so that task metrics can be efficiently updated (without @@ -68,10 +69,25 @@ private[spark] class AppStatusListener( private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() private val pools = new HashMap[String, SchedulerPool]() + // Keep the active executor count as a separate variable to avoid having to do synchronization + // around liveExecutors. + @volatile private var activeExecutorCount = 0 - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerLogStart(version) => sparkVersion = version - case _ => + kvstore.addTrigger(classOf[ExecutorSummaryWrapper], conf.get(MAX_RETAINED_DEAD_EXECUTORS)) + { count => cleanupExecutors(count) } + + kvstore.addTrigger(classOf[JobDataWrapper], conf.get(MAX_RETAINED_JOBS)) { count => + cleanupJobs(count) + } + + kvstore.addTrigger(classOf[StageDataWrapper], conf.get(MAX_RETAINED_STAGES)) { count => + cleanupStages(count) + } + + kvstore.onFlush { + if (!live) { + flush() + } } override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { @@ -97,6 +113,7 @@ private[spark] class AppStatusListener( Seq(attempt)) kvstore.write(new ApplicationInfoWrapper(appInfo)) + kvstore.write(appSummary) } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { @@ -158,10 +175,11 @@ private[spark] class AppStatusListener( override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { liveExecutors.remove(event.executorId).foreach { exec => val now = System.nanoTime() + activeExecutorCount = math.max(0, activeExecutorCount - 1) exec.isActive = false exec.removeTime = new Date(event.time) exec.removeReason = event.reason - update(exec, now) + update(exec, now, last = true) // Remove all RDD distributions that reference the removed executor, in case there wasn't // a corresponding event. 
@@ -290,8 +308,11 @@ private[spark] class AppStatusListener( } job.completionTime = if (event.time > 0) Some(new Date(event.time)) else None - update(job, now) + update(job, now, last = true) } + + appSummary = new AppSummary(appSummary.numCompletedJobs + 1, appSummary.numCompletedStages) + kvstore.write(appSummary) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -350,6 +371,13 @@ private[spark] class AppStatusListener( job.activeTasks += 1 maybeUpdate(job, now) } + + if (stage.savedTasks.incrementAndGet() > maxTasksPerStage && !stage.cleaning) { + stage.cleaning = true + kvstore.doAsync { + cleanupTasks(stage) + } + } } liveExecutors.get(event.taskInfo.executorId).foreach { exec => @@ -449,6 +477,13 @@ private[spark] class AppStatusListener( esummary.metrics.update(metricsDelta) } maybeUpdate(esummary, now) + + if (!stage.cleaning && stage.savedTasks.get() > maxTasksPerStage) { + stage.cleaning = true + kvstore.doAsync { + cleanupTasks(stage) + } + } } liveExecutors.get(event.taskInfo.executorId).foreach { exec => @@ -516,8 +551,11 @@ private[spark] class AppStatusListener( } stage.executorSummaries.values.foreach(update(_, now)) - update(stage, now) + update(stage, now, last = true) } + + appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) + kvstore.write(appSummary) } override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { @@ -573,7 +611,7 @@ private[spark] class AppStatusListener( } /** Flush all live entities' data to the underlying store. */ - def flush(): Unit = { + private def flush(): Unit = { val now = System.nanoTime() liveStages.values.asScala.foreach { stage => update(stage, now) @@ -708,7 +746,10 @@ private[spark] class AppStatusListener( } private def getOrCreateExecutor(executorId: String, addTime: Long): LiveExecutor = { - liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId, addTime)) + liveExecutors.getOrElseUpdate(executorId, { + activeExecutorCount += 1 + new LiveExecutor(executorId, addTime) + }) } private def updateStreamBlock(event: SparkListenerBlockUpdated, stream: StreamBlockId): Unit = { @@ -754,8 +795,8 @@ private[spark] class AppStatusListener( } } - private def update(entity: LiveEntity, now: Long): Unit = { - entity.write(kvstore, now) + private def update(entity: LiveEntity, now: Long, last: Boolean = false): Unit = { + entity.write(kvstore, now, checkTriggers = last) } /** Update a live entity only if it hasn't been updated in the last configured period. */ @@ -772,4 +813,127 @@ private[spark] class AppStatusListener( } } + private def cleanupExecutors(count: Long): Unit = { + // Because the limit is on the number of *dead* executors, we need to calculate whether + // there are actually enough dead executors to be deleted. 
+ val threshold = conf.get(MAX_RETAINED_DEAD_EXECUTORS) + val dead = count - activeExecutorCount + + if (dead > threshold) { + val countToDelete = calculateNumberToRemove(dead, threshold) + val toDelete = kvstore.view(classOf[ExecutorSummaryWrapper]).index("active") + .max(countToDelete).first(false).last(false).asScala.toSeq + toDelete.foreach { e => kvstore.delete(e.getClass(), e.info.id) } + } + } + + private def cleanupJobs(count: Long): Unit = { + val countToDelete = calculateNumberToRemove(count, conf.get(MAX_RETAINED_JOBS)) + if (countToDelete <= 0L) { + return + } + + val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[JobDataWrapper]), + countToDelete.toInt) { j => + j.info.status != JobExecutionStatus.RUNNING && j.info.status != JobExecutionStatus.UNKNOWN + } + toDelete.foreach { j => kvstore.delete(j.getClass(), j.info.jobId) } + } + + private def cleanupStages(count: Long): Unit = { + val countToDelete = calculateNumberToRemove(count, conf.get(MAX_RETAINED_STAGES)) + if (countToDelete <= 0L) { + return + } + + val stages = KVUtils.viewToSeq(kvstore.view(classOf[StageDataWrapper]), + countToDelete.toInt) { s => + s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING + } + + stages.foreach { s => + val key = s.id + kvstore.delete(s.getClass(), key) + + val execSummaries = kvstore.view(classOf[ExecutorStageSummaryWrapper]) + .index("stage") + .first(key) + .last(key) + .asScala + .toSeq + execSummaries.foreach { e => + kvstore.delete(e.getClass(), e.id) + } + + val tasks = kvstore.view(classOf[TaskDataWrapper]) + .index("stage") + .first(key) + .last(key) + .asScala + + tasks.foreach { t => + kvstore.delete(t.getClass(), t.info.taskId) + } + + // Check whether there are remaining attempts for the same stage. If there aren't, then + // also delete the RDD graph data. + val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) + .index("stageId") + .first(s.stageId) + .last(s.stageId) + .closeableIterator() + + val hasMoreAttempts = try { + remainingAttempts.asScala.exists { other => + other.info.attemptId != s.info.attemptId + } + } finally { + remainingAttempts.close() + } + + if (!hasMoreAttempts) { + kvstore.delete(classOf[RDDOperationGraphWrapper], s.stageId) + } + } + } + + private def cleanupTasks(stage: LiveStage): Unit = { + val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt + if (countToDelete > 0) { + val stageKey = Array(stage.info.stageId, stage.info.attemptId) + val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) + .last(stageKey) + + // Try to delete finished tasks only. + val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => + !live || t.info.status != TaskState.RUNNING.toString() + } + toDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + stage.savedTasks.addAndGet(-toDelete.size) + + // If there are more running tasks than the configured limit, delete running tasks. This + // should be extremely rare since the limit should generally far exceed the number of tasks + // that can run in parallel. + val remaining = countToDelete - toDelete.size + if (remaining > 0) { + val runningTasksToDelete = view.max(remaining).iterator().asScala.toList + runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + stage.savedTasks.addAndGet(-remaining) + } + } + stage.cleaning = false + } + + /** + * Remove at least (retainedSize / 10) items to reduce friction. 
Because tracking may be done + * asynchronously, this method may return 0 in case enough items have been deleted already. + */ + private def calculateNumberToRemove(dataSize: Long, retainedSize: Long): Long = { + if (dataSize > retainedSize) { + math.max(retainedSize / 10L, dataSize - retainedSize) + } else { + 0L + } + } + } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala b/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala index 69ca02ec76293..4cada5c7b0de4 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala @@ -48,7 +48,7 @@ private[spark] trait AppStatusPlugin { */ def setupListeners( conf: SparkConf, - store: KVStore, + store: ElementTrackingStore, addListenerFn: SparkListener => Unit, live: Boolean): Unit diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 22d768b3cb990..9987419b170f6 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -330,6 +330,10 @@ private[spark] class AppStatusStore( store.read(classOf[PoolData], name) } + def appSummary(): AppSummary = { + store.read(classOf[AppSummary], classOf[AppSummary].getName()) + } + def close(): Unit = { store.close() } @@ -347,7 +351,7 @@ private[spark] object AppStatusStore { * @param addListenerFn Function to register a listener with a bus. */ def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { - val store = new InMemoryStore() + val store = new ElementTrackingStore(new InMemoryStore(), conf) val listener = new AppStatusListener(store, conf, true) addListenerFn(listener) AppStatusPlugin.loadPlugins().foreach { p => diff --git a/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala b/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala new file mode 100644 index 0000000000000..863b0967f765e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.{HashMap, ListBuffer} + +import com.google.common.util.concurrent.MoreExecutors + +import org.apache.spark.SparkConf +import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.kvstore._ + +/** + * A KVStore wrapper that allows tracking the number of elements of specific types, and triggering + * actions once they reach a threshold. 
This allows writers, for example, to control how much data + * is stored by potentially deleting old data as new data is added. + * + * This store is used when populating data either from a live UI or an event log. On top of firing + * triggers when elements reach a certain threshold, it provides two extra bits of functionality: + * + * - a generic worker thread that can be used to run expensive tasks asynchronously; the tasks can + * be configured to run on the calling thread when more determinism is desired (e.g. unit tests). + * - a generic flush mechanism so that listeners can be notified about when they should flush + * internal state to the store (e.g. after the SHS finishes parsing an event log). + * + * The configured triggers are run on a separate thread by default; they can be forced to run on + * the calling thread by setting the `ASYNC_TRACKING_ENABLED` configuration to `false`. + */ +private[spark] class ElementTrackingStore(store: KVStore, conf: SparkConf) extends KVStore { + + import config._ + + private val triggers = new HashMap[Class[_], Seq[Trigger[_]]]() + private val flushTriggers = new ListBuffer[() => Unit]() + private val executor = if (conf.get(ASYNC_TRACKING_ENABLED)) { + ThreadUtils.newDaemonSingleThreadExecutor("element-tracking-store-worker") + } else { + MoreExecutors.sameThreadExecutor() + } + + @volatile private var stopped = false + + /** + * Register a trigger that will be fired once the number of elements of a given type reaches + * the given threshold. + * + * @param klass The type to monitor. + * @param threshold The number of elements that should trigger the action. + * @param action Action to run when the threshold is reached; takes as a parameter the number + * of elements of the registered type currently known to be in the store. + */ + def addTrigger(klass: Class[_], threshold: Long)(action: Long => Unit): Unit = { + val existing = triggers.getOrElse(klass, Seq()) + triggers(klass) = existing :+ Trigger(threshold, action) + } + + /** + * Adds a trigger to be executed before the store is flushed. This normally happens before + * closing, and is useful for flushing intermediate state to the store, e.g. when replaying + * in-progress applications through the SHS. + * + * Flush triggers are called synchronously in the same thread that is closing the store. + */ + def onFlush(action: => Unit): Unit = { + flushTriggers += { () => action } + } + + /** + * Enqueues an action to be executed asynchronously. The task will run on the calling thread if + * `ASYNC_TRACKING_ENABLED` is `false`. + */ + def doAsync(fn: => Unit): Unit = { + executor.submit(new Runnable() { + override def run(): Unit = Utils.tryLog { fn } + }) + } + + override def read[T](klass: Class[T], naturalKey: Any): T = store.read(klass, naturalKey) + + override def write(value: Any): Unit = store.write(value) + + /** Write an element to the store, optionally checking for whether to fire triggers. 
*/ + def write(value: Any, checkTriggers: Boolean): Unit = { + write(value) + + if (checkTriggers && !stopped) { + triggers.get(value.getClass()).foreach { list => + doAsync { + val count = store.count(value.getClass()) + list.foreach { t => + if (count > t.threshold) { + t.action(count) + } + } + } + } + } + } + + override def delete(klass: Class[_], naturalKey: Any): Unit = store.delete(klass, naturalKey) + + override def getMetadata[T](klass: Class[T]): T = store.getMetadata(klass) + + override def setMetadata(value: Any): Unit = store.setMetadata(value) + + override def view[T](klass: Class[T]): KVStoreView[T] = store.view(klass) + + override def count(klass: Class[_]): Long = store.count(klass) + + override def count(klass: Class[_], index: String, indexedValue: Any): Long = { + store.count(klass, index, indexedValue) + } + + override def close(): Unit = { + close(true) + } + + /** A close() method that optionally leaves the parent store open. */ + def close(closeParent: Boolean): Unit = synchronized { + if (stopped) { + return + } + + stopped = true + executor.shutdown() + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + + flushTriggers.foreach { trigger => + Utils.tryLog(trigger()) + } + + if (closeParent) { + store.close() + } + } + + private case class Trigger[T]( + threshold: Long, + action: Long => Unit) + +} diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala index 4638511944c61..99b1843d8e1c0 100644 --- a/core/src/main/scala/org/apache/spark/status/KVUtils.scala +++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.status import java.io.File import scala.annotation.meta.getter +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} @@ -68,6 +69,19 @@ private[spark] object KVUtils extends Logging { db } + /** Turns a KVStoreView into a Scala sequence, applying a filter. */ + def viewToSeq[T]( + view: KVStoreView[T], + max: Int) + (filter: T => Boolean): Seq[T] = { + val iter = view.closeableIterator() + try { + iter.asScala.filter(filter).take(max).toList + } finally { + iter.close() + } + } + private[spark] class MetadataMismatchException extends Exception } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 983c58a607aa8..52e83f250d34e 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -18,6 +18,7 @@ package org.apache.spark.status import java.util.Date +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.HashMap @@ -38,10 +39,12 @@ import org.apache.spark.util.kvstore.KVStore */ private[spark] abstract class LiveEntity { - var lastWriteTime = 0L + var lastWriteTime = -1L - def write(store: KVStore, now: Long): Unit = { - store.write(doUpdate()) + def write(store: ElementTrackingStore, now: Long, checkTriggers: Boolean = false): Unit = { + // Always check triggers on the first write, since adding an element to the store may + // cause the maximum count for the element type to be exceeded. 
+ store.write(doUpdate(), checkTriggers || lastWriteTime == -1L) lastWriteTime = now } @@ -403,6 +406,10 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + // Used for cleanup of tasks after they reach the configured limit. Not written to the store. + @volatile var cleaning = false + var savedTasks = new AtomicInteger(0) + def executorSummary(executorId: String): LiveExecutorStageSummary = { executorSummaries.getOrElseUpdate(executorId, new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index b3561109bc636..3b879545b3d2e 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -59,7 +59,15 @@ private[v1] class StagesResource extends BaseAppResource { ui.store.stageAttempt(stageId, stageAttemptId, details = details) } catch { case _: NoSuchElementException => - throw new NotFoundException(s"unknown attempt $stageAttemptId for stage $stageId.") + // Change the message depending on whether there are any attempts for the requested stage. + val all = ui.store.stageData(stageId) + val msg = if (all.nonEmpty) { + val ids = all.map(_.attemptId) + s"unknown attempt for stage $stageId. Found attempts: [${ids.mkString(",")}]" + } else { + s"unknown stage: $stageId" + } + throw new NotFoundException(msg) } } diff --git a/core/src/main/scala/org/apache/spark/status/config.scala b/core/src/main/scala/org/apache/spark/status/config.scala index 7af9dff977a86..67801b8f046f4 100644 --- a/core/src/main/scala/org/apache/spark/status/config.scala +++ b/core/src/main/scala/org/apache/spark/status/config.scala @@ -23,10 +23,30 @@ import org.apache.spark.internal.config._ private[spark] object config { + val ASYNC_TRACKING_ENABLED = ConfigBuilder("spark.appStateStore.asyncTracking.enable") + .booleanConf + .createWithDefault(true) + val LIVE_ENTITY_UPDATE_PERIOD = ConfigBuilder("spark.ui.liveUpdate.period") .timeConf(TimeUnit.NANOSECONDS) .createWithDefaultString("100ms") + val MAX_RETAINED_JOBS = ConfigBuilder("spark.ui.retainedJobs") + .intConf + .createWithDefault(1000) + + val MAX_RETAINED_STAGES = ConfigBuilder("spark.ui.retainedStages") + .intConf + .createWithDefault(1000) + + val MAX_RETAINED_TASKS_PER_STAGE = ConfigBuilder("spark.ui.retainedTasks") + .intConf + .createWithDefault(100000) + + val MAX_RETAINED_DEAD_EXECUTORS = ConfigBuilder("spark.ui.retainedDeadExecutors") + .intConf + .createWithDefault(100) + val MAX_RETAINED_ROOT_NODES = ConfigBuilder("spark.ui.dagGraph.retainedRootRDDs") .intConf .createWithDefault(Int.MaxValue) diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index c1ea87542d6cc..d9ead0071d3bf 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -112,6 +112,9 @@ private[spark] class TaskDataWrapper( Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) } + @JsonIgnore @KVIndex("active") + def active: Boolean = info.duration.isEmpty + } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { @@ -187,3 +190,16 @@ private[spark] class RDDOperationGraphWrapper( private[spark] class PoolData( @KVIndexParam val 
name: String, val stageIds: Set[Int]) + +/** + * A class with information about an app, to be used by the UI. There's only one instance of + * this summary per application, so its ID in the store is the class name. + */ +private[spark] class AppSummary( + val numCompletedJobs: Int, + val numCompletedStages: Int) { + + @KVIndex + def id: String = classOf[AppSummary].getName() + +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 35da3c3bfd1a2..b44ac0ea1febc 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -154,8 +154,6 @@ private[spark] object SparkUI { val DEFAULT_PORT = 4040 val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" val DEFAULT_POOL_NAME = "default" - val DEFAULT_RETAINED_STAGES = 1000 - val DEFAULT_RETAINED_JOBS = 1000 def getUIPort(conf: SparkConf): Int = { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index b60d39b21b4bf..37e3b3b304a63 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -300,7 +300,13 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We val shouldShowCompletedJobs = completedJobs.nonEmpty val shouldShowFailedJobs = failedJobs.nonEmpty - val completedJobNumStr = s"${completedJobs.size}" + val appSummary = store.appSummary() + val completedJobNumStr = if (completedJobs.size == appSummary.numCompletedJobs) { + s"${completedJobs.size}" + } else { + s"${appSummary.numCompletedJobs}, only showing ${completedJobs.size}" + } + val schedulingMode = store.environmentInfo().sparkProperties.toMap .get("spark.scheduler.mode") .map { mode => SchedulingMode.withName(mode).toString } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index e4cf99e7b9e04..b1e343451e28e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -39,7 +39,6 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - val numCompletedStages = completedStages.size val numFailedStages = failedStages.size val subPath = "stages" @@ -69,10 +68,11 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val shouldShowCompletedStages = completedStages.nonEmpty val shouldShowFailedStages = failedStages.nonEmpty - val completedStageNumStr = if (numCompletedStages == completedStages.size) { - s"$numCompletedStages" + val appSummary = parent.store.appSummary() + val completedStageNumStr = if (appSummary.numCompletedStages == completedStages.size) { + s"${appSummary.numCompletedStages}" } else { - s"$numCompletedStages, only showing ${completedStages.size}" + s"${appSummary.numCompletedStages}, only showing ${completedStages.size}" } val summary: NodeSeq = diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3a9790cd57270..3738f85da5831 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -264,7 +264,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val badStageAttemptId = getContentAndCode("applications/local-1422981780767/stages/1/1") badStageAttemptId._1 should be (HttpServletResponse.SC_NOT_FOUND) - badStageAttemptId._3 should be (Some("unknown attempt 1 for stage 1.")) + badStageAttemptId._3 should be (Some("unknown attempt for stage 1. Found attempts: [0]")) val badStageId2 = getContentAndCode("applications/local-1422981780767/stages/flimflam") badStageId2._1 should be (HttpServletResponse.SC_NOT_FOUND) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 88fe6bd70a14e..9cf4f7efb24a8 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -39,16 +39,20 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { import config._ - private val conf = new SparkConf().set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + private val conf = new SparkConf() + .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + .set(ASYNC_TRACKING_ENABLED, false) private var time: Long = _ private var testDir: File = _ - private var store: KVStore = _ + private var store: ElementTrackingStore = _ + private var taskIdTracker = -1L before { time = 0L testDir = Utils.createTempDir() - store = KVUtils.open(testDir, getClass().getName()) + store = new ElementTrackingStore(KVUtils.open(testDir, getClass().getName()), conf) + taskIdTracker = -1L } after { @@ -185,22 +189,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start tasks from stage 1 time += 1 - var _taskIdTracker = -1L - def nextTaskId(): Long = { - _taskIdTracker += 1 - _taskIdTracker - } - - def createTasks(count: Int, time: Long): Seq[TaskInfo] = { - (1 to count).map { id => - val exec = execIds(id.toInt % execIds.length) - val taskId = nextTaskId() - new TaskInfo(taskId, taskId.toInt, 1, time, exec, s"$exec.example.com", - TaskLocality.PROCESS_LOCAL, id % 2 == 0) - } - } - val s1Tasks = createTasks(4, time) + val s1Tasks = createTasks(4, execIds) s1Tasks.foreach { task => listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) } @@ -419,7 +409,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start and fail all tasks of stage 2. 
time += 1 - val s2Tasks = createTasks(4, time) + val s2Tasks = createTasks(4, execIds) s2Tasks.foreach { task => listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) } @@ -470,7 +460,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { listener.onStageSubmitted(SparkListenerStageSubmitted(newS2, jobProps)) assert(store.count(classOf[StageDataWrapper]) === 3) - val newS2Tasks = createTasks(4, time) + val newS2Tasks = createTasks(4, execIds) newS2Tasks.foreach { task => listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) @@ -526,7 +516,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(store.count(classOf[StageDataWrapper]) === 5) time += 1 - val j2s2Tasks = createTasks(4, time) + val j2s2Tasks = createTasks(4, execIds) j2s2Tasks.foreach { task => listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, @@ -587,8 +577,9 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } // Stop executors. - listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "1", "Test")) - listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "2", "Test")) + time += 1 + listener.onExecutorRemoved(SparkListenerExecutorRemoved(time, "1", "Test")) + listener.onExecutorRemoved(SparkListenerExecutorRemoved(time, "2", "Test")) Seq("1", "2").foreach { id => check[ExecutorSummaryWrapper](id) { exec => @@ -851,6 +842,103 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("eviction of old data") { + val testConf = conf.clone() + .set(MAX_RETAINED_JOBS, 2) + .set(MAX_RETAINED_STAGES, 2) + .set(MAX_RETAINED_TASKS_PER_STAGE, 2) + .set(MAX_RETAINED_DEAD_EXECUTORS, 1) + val listener = new AppStatusListener(store, testConf, true) + + // Start 3 jobs, all should be kept. Stop one, it should be evicted. + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Nil, null)) + listener.onJobStart(SparkListenerJobStart(2, time, Nil, null)) + listener.onJobStart(SparkListenerJobStart(3, time, Nil, null)) + assert(store.count(classOf[JobDataWrapper]) === 3) + + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + assert(store.count(classOf[JobDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[JobDataWrapper], 2) + } + + // Start 3 stages, all should be kept. Stop 2 of them, the stopped one with the lowest id should + // be deleted. Start a new attempt of the second stopped one, and verify that the stage graph + // data is not deleted. + time += 1 + val stages = Seq( + new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2"), + new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3")) + + // Graph data is generated by the job start event, so fire it. 
+ listener.onJobStart(SparkListenerJobStart(4, time, stages, null)) + + stages.foreach { s => + time += 1 + s.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(s, new Properties())) + } + + assert(store.count(classOf[StageDataWrapper]) === 3) + assert(store.count(classOf[RDDOperationGraphWrapper]) === 3) + + stages.drop(1).foreach { s => + time += 1 + s.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(s)) + } + + assert(store.count(classOf[StageDataWrapper]) === 2) + assert(store.count(classOf[RDDOperationGraphWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + + val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3") + time += 1 + attempt2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(attempt2, new Properties())) + + assert(store.count(classOf[StageDataWrapper]) === 2) + assert(store.count(classOf[RDDOperationGraphWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(3, 0)) + } + store.read(classOf[StageDataWrapper], Array(3, 1)) + + // Start 2 tasks. Finish the second one. + time += 1 + val tasks = createTasks(2, Array("1")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + } + assert(store.count(classOf[TaskDataWrapper]) === 2) + + // Start a 3rd task. The finished tasks should be deleted. + createTasks(1, Array("1")).foreach { task => + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + } + assert(store.count(classOf[TaskDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[TaskDataWrapper], tasks.last.id) + } + + // Start a 4th task. The first task should be deleted, even if it's still running. + createTasks(1, Array("1")).foreach { task => + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + } + assert(store.count(classOf[TaskDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[TaskDataWrapper], tasks.head.id) + } + } + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { @@ -864,6 +952,20 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) } + private def createTasks(count: Int, execs: Array[String]): Seq[TaskInfo] = { + (1 to count).map { id => + val exec = execs(id.toInt % execs.length) + val taskId = nextTaskId() + new TaskInfo(taskId, taskId.toInt, 1, time, exec, s"$exec.example.com", + TaskLocality.PROCESS_LOCAL, id % 2 == 0) + } + } + + private def nextTaskId(): Long = { + taskIdTracker += 1 + taskIdTracker + } + private case class RddBlock( rddId: Int, partId: Int, diff --git a/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala new file mode 100644 index 0000000000000..07a7b58404c29 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.mockito.Mockito._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.kvstore._ + +class ElementTrackingStoreSuite extends SparkFunSuite { + + import config._ + + test("tracking for multiple types") { + val store = mock(classOf[KVStore]) + val tracking = new ElementTrackingStore(store, new SparkConf() + .set(ASYNC_TRACKING_ENABLED, false)) + + var type1 = 0L + var type2 = 0L + var flushed = false + + tracking.addTrigger(classOf[Type1], 100) { count => + type1 = count + } + tracking.addTrigger(classOf[Type2], 1000) { count => + type2 = count + } + tracking.onFlush { + flushed = true + } + + when(store.count(classOf[Type1])).thenReturn(1L) + tracking.write(new Type1, true) + assert(type1 === 0L) + assert(type2 === 0L) + + when(store.count(classOf[Type1])).thenReturn(100L) + tracking.write(new Type1, true) + assert(type1 === 0L) + assert(type2 === 0L) + + when(store.count(classOf[Type1])).thenReturn(101L) + tracking.write(new Type1, true) + assert(type1 === 101L) + assert(type2 === 0L) + + when(store.count(classOf[Type1])).thenReturn(200L) + tracking.write(new Type1, true) + assert(type1 === 200L) + assert(type2 === 0L) + + when(store.count(classOf[Type2])).thenReturn(500L) + tracking.write(new Type2, true) + assert(type1 === 200L) + assert(type2 === 0L) + + when(store.count(classOf[Type2])).thenReturn(1000L) + tracking.write(new Type2, true) + assert(type1 === 200L) + assert(type2 === 0L) + + when(store.count(classOf[Type2])).thenReturn(2000L) + tracking.write(new Type2, true) + assert(type1 === 200L) + assert(type2 === 2000L) + + tracking.close(false) + assert(flushed) + verify(store, never()).close() + } + + private class Type1 + private class Type2 + +} diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index df5f0b5335e82..326546787ab6c 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} +import org.apache.spark.status.config._ private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { @@ -525,14 +526,15 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } - ignore("stage & job retention") { + test("stage & job retention") { val conf = new SparkConf() .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") - .set("spark.ui.retainedStages", "3") - 
.set("spark.ui.retainedJobs", "2") + .set(MAX_RETAINED_STAGES, 3) + .set(MAX_RETAINED_JOBS, 2) + .set(ASYNC_TRACKING_ENABLED, false) val sc = new SparkContext(conf) assert(sc.ui.isDefined) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index c018fc8a332fa..fe0ad39c29025 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -95,4 +95,11 @@ object StaticSQLConf { .stringConf .toSequence .createOptional + + val UI_RETAINED_EXECUTIONS = + buildStaticConf("spark.sql.ui.retainedExecutions") + .doc("Number of executions to retain in the Spark UI.") + .intConf + .createWithDefault(1000) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 43cec4807ae4d..cf0000c6393a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -27,14 +27,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric._ -import org.apache.spark.status.LiveEntity +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity} import org.apache.spark.status.config._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.kvstore.KVStore private[sql] class SQLAppStatusListener( conf: SparkConf, - kvstore: KVStore, + kvstore: ElementTrackingStore, live: Boolean, ui: Option[SparkUI] = None) extends SparkListener with Logging { @@ -51,6 +52,23 @@ private[sql] class SQLAppStatusListener( private var uiInitialized = false + kvstore.addTrigger(classOf[SQLExecutionUIData], conf.get(UI_RETAINED_EXECUTIONS)) { count => + cleanupExecutions(count) + } + + kvstore.onFlush { + if (!live) { + val now = System.nanoTime() + liveExecutions.values.asScala.foreach { exec => + // This saves the partial aggregated metrics to the store; this works currently because + // when the SHS sees an updated event log, all old data for the application is thrown + // away. 
+ exec.metricsValues = aggregateMetrics(exec) + exec.write(kvstore, now) + } + } + } + override def onJobStart(event: SparkListenerJobStart): Unit = { val executionIdString = event.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) if (executionIdString == null) { @@ -317,6 +335,17 @@ private[sql] class SQLAppStatusListener( } } + private def cleanupExecutions(count: Long): Unit = { + val countToDelete = count - conf.get(UI_RETAINED_EXECUTIONS) + if (countToDelete <= 0) { + return + } + + val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[SQLExecutionUIData]), + countToDelete.toInt) { e => e.completionTime.isDefined } + toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } + } + } private class LiveExecutionData(val executionId: Long) extends LiveEntity { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 586d3ae411c74..7fd5f7395cdf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -27,7 +27,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.scheduler.SparkListener -import org.apache.spark.status.AppStatusPlugin +import org.apache.spark.status.{AppStatusPlugin, ElementTrackingStore} import org.apache.spark.status.KVUtils.KVIndexParam import org.apache.spark.ui.SparkUI import org.apache.spark.util.Utils @@ -84,7 +84,7 @@ private[sql] class SQLAppStatusPlugin extends AppStatusPlugin { override def setupListeners( conf: SparkConf, - store: KVStore, + store: ElementTrackingStore, addListenerFn: SparkListener => Unit, live: Boolean): Unit = { // For live applications, the listener is installed in [[setupUI]]. This also avoids adding @@ -100,7 +100,8 @@ private[sql] class SQLAppStatusPlugin extends AppStatusPlugin { case Some(sc) => // If this is a live application, then install a listener that will enable the SQL // tab as soon as there's a SQL event posted to the bus. 
- val listener = new SQLAppStatusListener(sc.conf, ui.store.store, true, Some(ui)) + val listener = new SQLAppStatusListener(sc.conf, + ui.store.store.asInstanceOf[ElementTrackingStore], true, Some(ui)) sc.listenerBus.addToStatusQueue(listener) case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index eba8d55daad58..932950687942c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.status.ElementTrackingStore import org.apache.spark.status.config._ import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} import org.apache.spark.util.kvstore.InMemoryStore @@ -43,7 +44,9 @@ import org.apache.spark.util.kvstore.InMemoryStore class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { import testImplicits._ - override protected def sparkConf = super.sparkConf.set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + override protected def sparkConf = { + super.sparkConf.set(LIVE_ENTITY_UPDATE_PERIOD, 0L).set(ASYNC_TRACKING_ENABLED, false) + } private def createTestDataFrame: DataFrame = { Seq( @@ -107,10 +110,12 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest private def sqlStoreTest(name: String) (fn: (SQLAppStatusStore, SparkListenerBus) => Unit): Unit = { test(name) { - val store = new InMemoryStore() + val conf = sparkConf + val store = new ElementTrackingStore(new InMemoryStore(), conf) val bus = new ReplayListenerBus() - val listener = new SQLAppStatusListener(sparkConf, store, true) + val listener = new SQLAppStatusListener(conf, store, true) bus.addListener(listener) + store.close(false) val sqlStore = new SQLAppStatusStore(store, Some(listener)) fn(sqlStore, bus) } @@ -491,15 +496,15 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe class SQLListenerMemoryLeakSuite extends SparkFunSuite { - // TODO: this feature is not yet available in SQLAppStatusStore. - ignore("no memory leak") { - quietly { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly - .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - withSpark(new SparkContext(conf)) { sc => + test("no memory leak") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + .set(ASYNC_TRACKING_ENABLED, false) + withSpark(new SparkContext(conf)) { sc => + quietly { val spark = new SparkSession(sc) import spark.implicits._ // Run 100 successful executions and 100 failed executions. 
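As an editorial illustration of the SPARK-20653 patch above (not part of the patch itself), the sketch below shows how a consumer might wire up the new ElementTrackingStore. It is modeled on the patch's own ElementTrackingStoreSuite and AppStatusListener wiring, and assumes it is compiled as part of Spark core (the store and its config keys are private[spark]); the trigger body here only logs where a real listener would evict old rows.

```
package org.apache.spark.status

import org.apache.spark.SparkConf
import org.apache.spark.status.config._
import org.apache.spark.util.kvstore.InMemoryStore

object ElementTrackingStoreExample {
  def main(args: Array[String]): Unit = {
    // Run triggers on the calling thread so the example is deterministic.
    val conf = new SparkConf().set(ASYNC_TRACKING_ENABLED, false)
    val store = new ElementTrackingStore(new InMemoryStore(), conf)

    // Fires once more than 100 JobDataWrapper elements are in the store; the current
    // element count is passed to the action.
    store.addTrigger(classOf[JobDataWrapper], 100L) { count =>
      println(s"$count jobs stored, a real listener would evict the oldest ones here")
    }

    // Called synchronously when the store is flushed/closed, e.g. after the SHS
    // finishes replaying an event log.
    store.onFlush {
      println("flushing partial state to the store")
    }

    // Writers opt in to trigger checks per write:
    //   store.write(someJobDataWrapper, checkTriggers = true)

    store.close() // runs the flush triggers, then closes the wrapped InMemoryStore
  }
}
```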
From fbfa9be7e0df0b4489571422c45d0d64d05d3050 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 19 Dec 2017 07:30:29 +0900 Subject: [PATCH 446/712] Revert "Revert "[SPARK-22496][SQL] thrift server adds operation logs"" This reverts commit e58f275678fb4f904124a4a2a1762f04c835eb0e. --- .../cli/operation/ExecuteStatementOperation.java | 13 +++++++++++++ .../hive/service/cli/operation/SQLOperation.java | 12 ------------ .../SparkExecuteStatementOperation.scala | 1 + 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java index 3f2de108f069a..6740d3bb59dc3 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ExecuteStatementOperation.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessor; import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory; +import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hive.service.cli.HiveSQLException; import org.apache.hive.service.cli.OperationType; import org.apache.hive.service.cli.session.HiveSession; @@ -67,4 +68,16 @@ protected void setConfOverlay(Map confOverlay) { this.confOverlay = confOverlay; } } + + protected void registerCurrentOperationLog() { + if (isOperationLogEnabled) { + if (operationLog == null) { + LOG.warn("Failed to get current OperationLog object of Operation: " + + getHandle().getHandleIdentifier()); + isOperationLogEnabled = false; + return; + } + OperationLog.setCurrentOperationLog(operationLog); + } + } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index 5014cedd870b6..fd9108eb53ca9 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -274,18 +274,6 @@ private Hive getSessionHive() throws HiveSQLException { } } - private void registerCurrentOperationLog() { - if (isOperationLogEnabled) { - if (operationLog == null) { - LOG.warn("Failed to get current OperationLog object of Operation: " + - getHandle().getHandleIdentifier()); - isOperationLogEnabled = false; - return; - } - OperationLog.setCurrentOperationLog(operationLog); - } - } - private void cleanup(OperationState state) throws HiveSQLException { setState(state); if (shouldRunAsync()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index f5191fa9132bd..664bc20601eaa 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -170,6 +170,7 @@ private[hive] class SparkExecuteStatementOperation( override def run(): Unit = { val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { + registerCurrentOperationLog() try { execute() } catch { From 
3a07eff5af601511e97a05e6fea0e3d48f74c4f0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 19 Dec 2017 07:35:03 +0900 Subject: [PATCH 447/712] [SPARK-22813][BUILD] Use lsof or /usr/sbin/lsof in run-tests.py ## What changes were proposed in this pull request? In [the environment where `/usr/sbin/lsof` does not exist](https://github.com/apache/spark/pull/19695#issuecomment-342865001), `./dev/run-tests.py` for `maven` causes the following error. This is because the current `./dev/run-tests.py` checks existence of only `/usr/sbin/lsof` and aborts immediately if it does not exist. This PR changes to check whether `lsof` or `/usr/sbin/lsof` exists. ``` /bin/sh: 1: /usr/sbin/lsof: not found Usage: kill [options] [...] Options: [...] send signal to every listed -, -s, --signal specify the to be sent -l, --list=[] list all signal names, or convert one to a name -L, --table list all signal names in a nice table -h, --help display this help and exit -V, --version output version information and exit For more details see kill(1). Traceback (most recent call last): File "./dev/run-tests.py", line 626, in main() File "./dev/run-tests.py", line 597, in main build_apache_spark(build_tool, hadoop_version) File "./dev/run-tests.py", line 389, in build_apache_spark build_spark_maven(hadoop_version) File "./dev/run-tests.py", line 329, in build_spark_maven exec_maven(profiles_and_goals) File "./dev/run-tests.py", line 270, in exec_maven kill_zinc_on_port(zinc_port) File "./dev/run-tests.py", line 258, in kill_zinc_on_port subprocess.check_call(cmd, shell=True) File "/usr/lib/python2.7/subprocess.py", line 541, in check_call raise CalledProcessError(retcode, cmd) subprocess.CalledProcessError: Command '/usr/sbin/lsof -P |grep 3156 | grep LISTEN | awk '{ print $2; }' | xargs kill' returned non-zero exit status 123 ``` ## How was this patch tested? manually tested Author: Kazuaki Ishizaki Closes #19998 from kiszk/SPARK-22813. --- dev/run-tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index ef0e788a91606..7e6f7ff060351 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -253,9 +253,9 @@ def kill_zinc_on_port(zinc_port): """ Kill the Zinc process running on the given port, if one exists. """ - cmd = ("/usr/sbin/lsof -P |grep %s | grep LISTEN " - "| awk '{ print $2; }' | xargs kill") % zinc_port - subprocess.check_call(cmd, shell=True) + cmd = "%s -P |grep %s | grep LISTEN | awk '{ print $2; }' | xargs kill" + lsof_exe = which("lsof") + subprocess.check_call(cmd % (lsof_exe if lsof_exe else "/usr/sbin/lsof", zinc_port), shell=True) def exec_maven(mvn_args=()): From 0609dcc0382227f65e8a3340e33db438f7f1e6e7 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 18 Dec 2017 15:31:47 -0800 Subject: [PATCH 448/712] [SPARK-22777][SCHEDULER] Kubernetes mode dockerfile permission and distribution # What changes were proposed in this pull request? 1. entrypoint.sh for Kubernetes spark-base image is marked as executable (644 -> 755) 2. make-distribution script will now create kubernetes/dockerfiles directory when Kubernetes support is compiled. ## How was this patch tested? Manual testing cc/ ueshin jiangxb1987 mridulm vanzin rxin liyinan926 Author: foxish Closes #20007 from foxish/fix-dockerfiles. 
--- dev/make-distribution.sh | 8 +++++++- .../docker/src/main/dockerfiles/spark-base/entrypoint.sh | 0 2 files changed, 7 insertions(+), 1 deletion(-) mode change 100644 => 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 48a824499acb9..7245163ea2a51 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -168,12 +168,18 @@ echo "Build flags: $@" >> "$DISTDIR/RELEASE" # Copy jars cp "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/" -# Only create the yarn directory if the yarn artifacts were build. +# Only create the yarn directory if the yarn artifacts were built. if [ -f "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar ]; then mkdir "$DISTDIR/yarn" cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/yarn" fi +# Only create and copy the dockerfiles directory if the kubernetes artifacts were built. +if [ -d "$SPARK_HOME"/resource-managers/kubernetes/core/target/ ]; then + mkdir -p "$DISTDIR/kubernetes/" + cp -a "$SPARK_HOME"/resource-managers/kubernetes/docker/src/main/dockerfiles "$DISTDIR/kubernetes/" +fi + # Copy examples and dependencies mkdir -p "$DISTDIR/examples/jars" cp "$SPARK_HOME"/examples/target/scala*/jars/* "$DISTDIR/examples/jars" diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh old mode 100644 new mode 100755 From d4e69595dd17e4c31757410f91274122497dddbd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 19 Dec 2017 09:48:31 +0800 Subject: [PATCH 449/712] [MINOR][SQL] Remove Useless zipWithIndex from ResolveAliases ## What changes were proposed in this pull request? Remove useless `zipWithIndex` from `ResolveAliases `. ## How was this patch tested? The existing tests Author: gatorsmile Closes #20009 from gatorsmile/try22. 
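As a tiny standalone illustration of the refactor described above (plain Scala, not Catalyst code): when the index produced by zipWithIndex is never read, a plain map expresses the same transformation with less noise, which is all this change does.

```
object DropUnusedZipWithIndex {
  def main(args: Array[String]): Unit = {
    val exprs = Seq("a", "b", "c")

    // Before: an index is produced but never used.
    val before = exprs.zipWithIndex.map { case (e, _) => e.toUpperCase }

    // After: equivalent, without the useless zipWithIndex.
    val after = exprs.map(_.toUpperCase)

    assert(before == after)
  }
}
```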
--- .../sql/catalyst/analysis/Analyzer.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0d5e866c0683e..10b237fb22b96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -222,22 +222,20 @@ class Analyzer( */ object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[NamedExpression]) = { - exprs.zipWithIndex.map { - case (expr, i) => - expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => - child match { - case ne: NamedExpression => ne - case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) - case e if !e.resolved => u - case g: Generator => MultiAlias(g, Nil) - case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)() - case e: ExtractValue => Alias(e, toPrettySQL(e))() - case e if optGenAliasFunc.isDefined => - Alias(child, optGenAliasFunc.get.apply(e))() - case e => Alias(e, toPrettySQL(e))() - } + exprs.map(_.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => + child match { + case ne: NamedExpression => ne + case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) + case e if !e.resolved => u + case g: Generator => MultiAlias(g, Nil) + case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)() + case e: ExtractValue => Alias(e, toPrettySQL(e))() + case e if optGenAliasFunc.isDefined => + Alias(child, optGenAliasFunc.get.apply(e))() + case e => Alias(e, toPrettySQL(e))() } - }.asInstanceOf[Seq[NamedExpression]] + } + ).asInstanceOf[Seq[NamedExpression]] } private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = From ab7346f20c515c5af63b9400a9fe6e3a72138ef3 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 19 Dec 2017 21:51:56 +0800 Subject: [PATCH 450/712] [SPARK-22673][SQL] InMemoryRelation should utilize existing stats whenever possible ## What changes were proposed in this pull request? The current implementation of InMemoryRelation always uses the most expensive execution plan when writing cache With CBO enabled, we can actually have a more exact estimation of the underlying table size... ## How was this patch tested? existing test Author: CodingCat Author: Nan Zhu Author: Nan Zhu Closes #19864 from CodingCat/SPARK-22673. 
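As an editorial aside (toy names, not the actual Spark classes), the following self-contained sketch mirrors the new computeStats() decision shown in the InMemoryRelation diff below: until the cached RDD is materialized, the relation now answers with the statistics of the plan being cached instead of a blanket default size, which is what lets CBO-derived sizes flow through.

```
// Toy model of the patched computeStats() logic; `Statistics` and `CachedRelation`
// are stand-ins defined here, not Spark's classes.
case class Statistics(sizeInBytes: BigInt)

class CachedRelation(statsOfPlanToCache: Statistics) {
  // Plays the role of the batchStats accumulator: stays 0 until materialization.
  var materializedBytes: Long = 0L

  def computeStats(): Statistics = {
    if (materializedBytes == 0L) {
      // Previously: a fixed defaultSizeInBytes. Now: reuse the cached plan's stats.
      statsOfPlanToCache
    } else {
      Statistics(sizeInBytes = materializedBytes)
    }
  }
}

object CachedRelationExample {
  def main(args: Array[String]): Unit = {
    val rel = new CachedRelation(Statistics(sizeInBytes = 1024))
    assert(rel.computeStats().sizeInBytes == 1024) // stats of the plan being cached
    rel.materializedBytes = 4096L
    assert(rel.computeStats().sizeInBytes == 4096) // stats of the materialized batches
  }
}
```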
--- .../spark/sql/execution/CacheManager.scala | 20 +++--- .../execution/columnar/InMemoryRelation.scala | 21 ++++--- .../columnar/InMemoryColumnarQuerySuite.scala | 61 ++++++++++++++++--- 3 files changed, 75 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 5a1d680c99f66..b05fe49a6ac3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -94,14 +94,13 @@ class CacheManager extends Logging { logWarning("Asked to cache already cached data.") } else { val sparkSession = query.sparkSession - cachedData.add(CachedData( - planToCache, - InMemoryRelation( - sparkSession.sessionState.conf.useCompression, - sparkSession.sessionState.conf.columnBatchSize, - storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, - tableName))) + val inMemoryRelation = InMemoryRelation( + sparkSession.sessionState.conf.useCompression, + sparkSession.sessionState.conf.columnBatchSize, storageLevel, + sparkSession.sessionState.executePlan(planToCache).executedPlan, + tableName, + planToCache.stats) + cachedData.add(CachedData(planToCache, inMemoryRelation)) } } @@ -148,7 +147,8 @@ class CacheManager extends Logging { batchSize = cd.cachedRepresentation.batchSize, storageLevel = cd.cachedRepresentation.storageLevel, child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName) + tableName = cd.cachedRepresentation.tableName, + statsOfPlanToCache = cd.plan.stats) needToRecache += cd.copy(cachedRepresentation = newCache) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index a1c62a729900e..51928d914841e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -37,8 +37,10 @@ object InMemoryRelation { batchSize: Int, storageLevel: StorageLevel, child: SparkPlan, - tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() + tableName: Option[String], + statsOfPlanToCache: Statistics): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( + statsOfPlanToCache = statsOfPlanToCache) } @@ -60,7 +62,8 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) + val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, + 
statsOfPlanToCache: Statistics = null) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -71,9 +74,8 @@ case class InMemoryRelation( override def computeStats(): Statistics = { if (batchStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, no useful statistics information - // available, return the default statistics. - Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache + statsOfPlanToCache } else { Statistics(sizeInBytes = batchStats.value.longValue) } @@ -142,7 +144,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, batchStats) + _cachedColumnBuffers, batchStats, statsOfPlanToCache) } override def newInstance(): this.type = { @@ -154,11 +156,12 @@ case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - batchStats).asInstanceOf[this.type] + batchStats, + statsOfPlanToCache).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, batchStats) + Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index e662e294228db..ff7c5e58e9863 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.Utils class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -40,7 +41,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { data.createOrReplaceTempView(s"testData$dataType") val storageLevel = MEMORY_ONLY val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None) + val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, + data.logicalPlan.stats) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) inMemoryRelation.cachedColumnBuffers.collect().head match { @@ -116,7 +118,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("simple columnar query") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan - val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, + testData.logicalPlan.stats) checkAnswer(scan, testData.collect().toSeq) } @@ -132,8 +135,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("projection") { - val plan = spark.sessionState.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan - val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) + val logicalPlan = 
testData.select('value, 'key).logicalPlan + val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, + logicalPlan.stats) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -149,7 +154,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan - val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, + testData.logicalPlan.stats) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) @@ -323,7 +329,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None) + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats) // Materialize the data. val expectedAnswer = data.collect() @@ -448,8 +454,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { val attribute = AttributeReference("a", IntegerType)() - val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, - LocalTableScanExec(Seq(attribute), Nil), None) + val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil) + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null) val tableScanExec = InMemoryTableScanExec(Seq(attribute), Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) @@ -479,4 +485,43 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") { + withSQLConf("spark.sql.cbo.enabled" -> "true") { + withTempPath { workDir => + withTable("table1") { + val workDirPath = workDir.getAbsolutePath + val data = Seq(100, 200, 300, 400).toDF("count") + data.write.parquet(workDirPath) + val dfFromFile = spark.read.parquet(workDirPath).cache() + val inMemoryRelation = dfFromFile.queryExecution.optimizedPlan.collect { + case plan: InMemoryRelation => plan + }.head + // InMemoryRelation's stats is file size before the underlying RDD is materialized + assert(inMemoryRelation.computeStats().sizeInBytes === 740) + + // InMemoryRelation's stats is updated after materializing RDD + dfFromFile.collect() + assert(inMemoryRelation.computeStats().sizeInBytes === 16) + + // test of catalog table + val dfFromTable = spark.catalog.createTable("table1", workDirPath).cache() + val inMemoryRelation2 = dfFromTable.queryExecution.optimizedPlan. 
+ collect { case plan: InMemoryRelation => plan }.head + + // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats + // is calculated + assert(inMemoryRelation2.computeStats().sizeInBytes === 740) + + // InMemoryRelation's stats should be updated after calculating stats of the table + // clear cache to simulate a fresh environment + dfFromTable.unpersist(blocking = true) + spark.sql("ANALYZE TABLE table1 COMPUTE STATISTICS") + val inMemoryRelation3 = spark.read.table("table1").cache().queryExecution.optimizedPlan. + collect { case plan: InMemoryRelation => plan }.head + assert(inMemoryRelation3.computeStats().sizeInBytes === 48) + } + } + } + } } From 571aa275541d71dbef8f0c86eab4ef04d56e4394 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Tue, 19 Dec 2017 21:55:21 +0800 Subject: [PATCH 451/712] [SPARK-21984][SQL] Join estimation based on equi-height histogram ## What changes were proposed in this pull request? Equi-height histogram is one of the state-of-the-art statistics for cardinality estimation, which can provide better estimation accuracy, and good at cases with skew data. This PR is to improve join estimation based on equi-height histogram. The difference from basic estimation (based on ndv) is the logic for computing join cardinality and the new ndv after join. The main idea is as follows: 1. find overlapped ranges between two histograms from two join keys; 2. apply the formula `T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1))` in each overlapped range. ## How was this patch tested? Added new test cases. Author: Zhenhua Wang Closes #19594 from wzhfy/join_estimation_histogram. --- .../statsEstimation/EstimationUtils.scala | 169 ++++++++++++++ .../statsEstimation/JoinEstimation.scala | 54 ++++- .../statsEstimation/JoinEstimationSuite.scala | 209 +++++++++++++++++- 3 files changed, 428 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 6f868cbd072c8..71e852afe0659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} @@ -212,4 +213,172 @@ object EstimationUtils { } } + /** + * Returns overlapped ranges between two histograms, in the given value range + * [lowerBound, upperBound]. + */ + def getOverlappedRanges( + leftHistogram: Histogram, + rightHistogram: Histogram, + lowerBound: Double, + upperBound: Double): Seq[OverlappedRange] = { + val overlappedRanges = new ArrayBuffer[OverlappedRange]() + // Only bins whose range intersect [lowerBound, upperBound] have join possibility. 
+ val leftBins = leftHistogram.bins + .filter(b => b.lo <= upperBound && b.hi >= lowerBound) + val rightBins = rightHistogram.bins + .filter(b => b.lo <= upperBound && b.hi >= lowerBound) + + leftBins.foreach { lb => + rightBins.foreach { rb => + val (left, leftHeight) = trimBin(lb, leftHistogram.height, lowerBound, upperBound) + val (right, rightHeight) = trimBin(rb, rightHistogram.height, lowerBound, upperBound) + // Only collect overlapped ranges. + if (left.lo <= right.hi && left.hi >= right.lo) { + // Collect overlapped ranges. + val range = if (right.lo >= left.lo && right.hi >= left.hi) { + // Case1: the left bin is "smaller" than the right bin + // left.lo right.lo left.hi right.hi + // --------+------------------+------------+----------------+-------> + if (left.hi == right.lo) { + // The overlapped range has only one value. + OverlappedRange( + lo = right.lo, + hi = right.lo, + leftNdv = 1, + rightNdv = 1, + leftNumRows = leftHeight / left.ndv, + rightNumRows = rightHeight / right.ndv + ) + } else { + val leftRatio = (left.hi - right.lo) / (left.hi - left.lo) + val rightRatio = (left.hi - right.lo) / (right.hi - right.lo) + OverlappedRange( + lo = right.lo, + hi = left.hi, + leftNdv = left.ndv * leftRatio, + rightNdv = right.ndv * rightRatio, + leftNumRows = leftHeight * leftRatio, + rightNumRows = rightHeight * rightRatio + ) + } + } else if (right.lo <= left.lo && right.hi <= left.hi) { + // Case2: the left bin is "larger" than the right bin + // right.lo left.lo right.hi left.hi + // --------+------------------+------------+----------------+-------> + if (right.hi == left.lo) { + // The overlapped range has only one value. + OverlappedRange( + lo = right.hi, + hi = right.hi, + leftNdv = 1, + rightNdv = 1, + leftNumRows = leftHeight / left.ndv, + rightNumRows = rightHeight / right.ndv + ) + } else { + val leftRatio = (right.hi - left.lo) / (left.hi - left.lo) + val rightRatio = (right.hi - left.lo) / (right.hi - right.lo) + OverlappedRange( + lo = left.lo, + hi = right.hi, + leftNdv = left.ndv * leftRatio, + rightNdv = right.ndv * rightRatio, + leftNumRows = leftHeight * leftRatio, + rightNumRows = rightHeight * rightRatio + ) + } + } else if (right.lo >= left.lo && right.hi <= left.hi) { + // Case3: the left bin contains the right bin + // left.lo right.lo right.hi left.hi + // --------+------------------+------------+----------------+-------> + val leftRatio = (right.hi - right.lo) / (left.hi - left.lo) + OverlappedRange( + lo = right.lo, + hi = right.hi, + leftNdv = left.ndv * leftRatio, + rightNdv = right.ndv, + leftNumRows = leftHeight * leftRatio, + rightNumRows = rightHeight + ) + } else { + assert(right.lo <= left.lo && right.hi >= left.hi) + // Case4: the right bin contains the left bin + // right.lo left.lo left.hi right.hi + // --------+------------------+------------+----------------+-------> + val rightRatio = (left.hi - left.lo) / (right.hi - right.lo) + OverlappedRange( + lo = left.lo, + hi = left.hi, + leftNdv = left.ndv, + rightNdv = right.ndv * rightRatio, + leftNumRows = leftHeight, + rightNumRows = rightHeight * rightRatio + ) + } + overlappedRanges += range + } + } + } + overlappedRanges + } + + /** + * Given an original bin and a value range [lowerBound, upperBound], returns the trimmed part + * of the bin in that range and its number of rows. + * @param bin the input histogram bin. + * @param height the number of rows of the given histogram bin inside an equi-height histogram. + * @param lowerBound lower bound of the given range. 
+ * @param upperBound upper bound of the given range. + * @return trimmed part of the given bin and its number of rows. + */ + def trimBin(bin: HistogramBin, height: Double, lowerBound: Double, upperBound: Double) + : (HistogramBin, Double) = { + val (lo, hi) = if (bin.lo <= lowerBound && bin.hi >= upperBound) { + // bin.lo lowerBound upperBound bin.hi + // --------+------------------+------------+-------------+-------> + (lowerBound, upperBound) + } else if (bin.lo <= lowerBound && bin.hi >= lowerBound) { + // bin.lo lowerBound bin.hi upperBound + // --------+------------------+------------+-------------+-------> + (lowerBound, bin.hi) + } else if (bin.lo <= upperBound && bin.hi >= upperBound) { + // lowerBound bin.lo upperBound bin.hi + // --------+------------------+------------+-------------+-------> + (bin.lo, upperBound) + } else { + // lowerBound bin.lo bin.hi upperBound + // --------+------------------+------------+-------------+-------> + assert(bin.lo >= lowerBound && bin.hi <= upperBound) + (bin.lo, bin.hi) + } + + if (hi == lo) { + // Note that bin.hi == bin.lo also falls into this branch. + (HistogramBin(lo, hi, 1), height / bin.ndv) + } else { + assert(bin.hi != bin.lo) + val ratio = (hi - lo) / (bin.hi - bin.lo) + (HistogramBin(lo, hi, math.ceil(bin.ndv * ratio).toLong), height * ratio) + } + } + + /** + * A join between two equi-height histograms may produce multiple overlapped ranges. + * Each overlapped range is produced by a part of one bin in the left histogram and a part of + * one bin in the right histogram. + * @param lo lower bound of this overlapped range. + * @param hi higher bound of this overlapped range. + * @param leftNdv ndv in the left part. + * @param rightNdv ndv in the right part. + * @param leftNumRows number of rows in the left part. + * @param rightNumRows number of rows in the right part. 
+ */ + case class OverlappedRange( + lo: Double, + hi: Double, + leftNdv: Double, + rightNdv: Double, + leftNumRows: Double, + rightNumRows: Double) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index b073108c26ee5..f0294a4246703 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ @@ -191,8 +191,19 @@ case class JoinEstimation(join: Join) extends Logging { val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax) - keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat) + val (card, joinStat) = (leftKeyStat.histogram, rightKeyStat.histogram) match { + case (Some(l: Histogram), Some(r: Histogram)) => + computeByHistogram(leftKey, rightKey, l, r, newMin, newMax) + case _ => + computeByNdv(leftKey, rightKey, newMin, newMax) + } + keyStatsAfterJoin += ( + // Histograms are propagated as unchanged. During future estimation, they should be + // truncated by the updated max/min. In this way, only pointers of the histograms are + // propagated and thus reduce memory consumption. + leftKey -> joinStat.copy(histogram = leftKeyStat.histogram), + rightKey -> joinStat.copy(histogram = rightKeyStat.histogram) + ) // Return cardinality estimated from the most selective join keys. if (card < joinCard) joinCard = card } else { @@ -225,6 +236,43 @@ case class JoinEstimation(join: Join) extends Logging { (ceil(card), newStats) } + /** Compute join cardinality using equi-height histograms. */ + private def computeByHistogram( + leftKey: AttributeReference, + rightKey: AttributeReference, + leftHistogram: Histogram, + rightHistogram: Histogram, + newMin: Option[Any], + newMax: Option[Any]): (BigInt, ColumnStat) = { + val overlappedRanges = getOverlappedRanges( + leftHistogram = leftHistogram, + rightHistogram = rightHistogram, + // Only numeric values have equi-height histograms. + lowerBound = newMin.get.toString.toDouble, + upperBound = newMax.get.toString.toDouble) + + var card: BigDecimal = 0 + var totalNdv: Double = 0 + for (i <- overlappedRanges.indices) { + val range = overlappedRanges(i) + if (i == 0 || range.hi != overlappedRanges(i - 1).hi) { + // If range.hi == overlappedRanges(i - 1).hi, that means the current range has only one + // value, and this value is already counted in the previous range. So there is no need to + // count it in this range. + totalNdv += math.min(range.leftNdv, range.rightNdv) + } + // Apply the formula in this overlapped range. 
+ card += range.leftNumRows * range.rightNumRows / math.max(range.leftNdv, range.rightNdv) + } + + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) + val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen) + (ceil(card), newStats) + } + /** * Propagate or update column stats for output attributes. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 097c78eb27fca..26139d85d25fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{DateType, TimestampType, _} @@ -67,6 +67,213 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = 2, attributeStats = AttributeMap(Seq("key-1-2", "key-2-3").map(nameToColInfo))) + private def estimateByHistogram( + leftHistogram: Histogram, + rightHistogram: Histogram, + expectedMin: Double, + expectedMax: Double, + expectedNdv: Long, + expectedRows: Long): Unit = { + val col1 = attr("key1") + val col2 = attr("key2") + val c1 = generateJoinChild(col1, leftHistogram, expectedMin, expectedMax) + val c2 = generateJoinChild(col2, rightHistogram, expectedMin, expectedMax) + + val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2))) + val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1))) + val expectedStatsAfterJoin = Statistics( + sizeInBytes = expectedRows * (8 + 2 * 4), + rowCount = Some(expectedRows), + attributeStats = AttributeMap(Seq( + col1 -> c1.stats.attributeStats(col1).copy( + distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + col2 -> c2.stats.attributeStats(col2).copy( + distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + ) + + // Join order should not affect estimation result. + Seq(c1JoinC2, c2JoinC1).foreach { join => + assert(join.stats == expectedStatsAfterJoin) + } + } + + private def generateJoinChild( + col: Attribute, + histogram: Histogram, + expectedMin: Double, + expectedMax: Double): LogicalPlan = { + val colStat = inferColumnStat(histogram) + StatsTestPlan( + outputList = Seq(col), + rowCount = (histogram.height * histogram.bins.length).toLong, + attributeStats = AttributeMap(Seq(col -> colStat))) + } + + /** Column statistics should be consistent with histograms in tests. 
*/ + private def inferColumnStat(histogram: Histogram): ColumnStat = { + var ndv = 0L + for (i <- histogram.bins.indices) { + val bin = histogram.bins(i) + if (i == 0 || bin.hi != histogram.bins(i - 1).hi) { + ndv += bin.ndv + } + } + ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), + max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + histogram = Some(histogram)) + } + + test("equi-height histograms: a bin is contained by another one") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 10, hi = 30, ndv = 10), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 10, upperBound = 60) + assert(t0 == HistogramBin(lo = 10, hi = 50, ndv = 40) && h0 == 80) + val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 10, upperBound = 60) + assert(t1 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20) + + val expectedRanges = Seq( + // histogram1.bins(0) overlaps t0 + OverlappedRange(10, 30, 10, 40 * 1 / 2, 300, 80 * 1 / 2), + // histogram1.bins(1) overlaps t0 + OverlappedRange(30, 50, 30 * 2 / 3, 40 * 1 / 2, 300 * 2 / 3, 80 * 1 / 2), + // histogram1.bins(1) overlaps t1 + OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 10, upperBound = 60))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 10, + expectedMax = 60, + expectedNdv = 10 + 20 + 8, + expectedRows = 300 * 40 / 20 + 200 * 40 / 20 + 100 * 20 / 10) + } + + test("equi-height histograms: a bin has only one value after trimming") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 50, hi = 60, ndv = 10), HistogramBin(lo = 60, hi = 75, ndv = 3))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 50, upperBound = 75) + assert(t0 == HistogramBin(lo = 50, hi = 50, ndv = 1) && h0 == 2) + val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 50, upperBound = 75) + assert(t1 == HistogramBin(lo = 50, hi = 75, ndv = 20) && h1 == 50) + + val expectedRanges = Seq( + // histogram1.bins(0) overlaps t0 + OverlappedRange(50, 50, 1, 1, 300 / 10, 2), + // histogram1.bins(0) overlaps t1 + OverlappedRange(50, 60, 10, 20 * 10 / 25, 300, 50 * 10 / 25), + // histogram1.bins(1) overlaps t1 + OverlappedRange(60, 75, 3, 20 * 15 / 25, 300, 50 * 15 / 25) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 50, upperBound = 75))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 50, + expectedMax = 75, + expectedNdv = 1 + 8 + 3, + expectedRows = 30 * 2 / 1 + 300 * 20 / 10 + 300 * 30 / 12) + } + + test("equi-height histograms: skew distribution (some bins have only one value)") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 30, hi = 30, ndv = 1), + HistogramBin(lo = 30, hi = 30, ndv = 1), + HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + 
val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 60) + assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 40) + val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 30, upperBound = 60) + assert(t1 ==HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20) + + val expectedRanges = Seq( + OverlappedRange(30, 30, 1, 1, 300, 40 / 20), + OverlappedRange(30, 30, 1, 1, 300, 40 / 20), + OverlappedRange(30, 50, 30 * 2 / 3, 20, 300 * 2 / 3, 40), + OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 60))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 30, + expectedMax = 60, + expectedNdv = 1 + 20 + 8, + expectedRows = 300 * 2 / 1 + 300 * 2 / 1 + 200 * 40 / 20 + 100 * 20 / 10) + } + + test("equi-height histograms: skew distribution (histograms have different skewed values") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 50, ndv = 1))) + // test bin trimming + val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 50) + assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 200) + val (t1, h1) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 50) + assert(t1 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h1 == 40) + + val expectedRanges = Seq( + OverlappedRange(30, 30, 1, 1, 300, 40 / 20), + OverlappedRange(30, 50, 20, 20, 200, 40), + OverlappedRange(50, 50, 1, 1, 200 / 20, 100) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 50))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 30, + expectedMax = 50, + expectedNdv = 1 + 20, + expectedRows = 300 * 2 / 1 + 200 * 40 / 20 + 10 * 100 / 1) + } + + test("equi-height histograms: skew distribution (both histograms have the same skewed value") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 150, Array( + HistogramBin(lo = 0, hi = 30, ndv = 30), HistogramBin(lo = 30, hi = 30, ndv = 1))) + // test bin trimming + val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 30) + assert(t0 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h0 == 10) + val (t1, h1) = trimBin(histogram2.bins(0), height = 150, lowerBound = 30, upperBound = 30) + assert(t1 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h1 == 5) + + val expectedRanges = Seq( + OverlappedRange(30, 30, 1, 1, 300, 5), + OverlappedRange(30, 30, 1, 1, 300, 150), + OverlappedRange(30, 30, 1, 1, 10, 5), + OverlappedRange(30, 30, 1, 1, 10, 150) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 30))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 30, + expectedMax = 30, + // only one value: 30 + expectedNdv = 1, + expectedRows = 300 * 5 / 1 + 300 * 150 / 1 + 10 * 5 / 1 + 10 * 150 / 1) + } + test("cross join") { // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) // table2 (key-1-2 int, key-2-4 int): (1, 2), 
(2, 3), (2, 4) From 28315714ddef3ddcc192375e98dd5207cf4ecc98 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 19 Dec 2017 22:12:23 +0800 Subject: [PATCH 452/712] [SPARK-22791][SQL][SS] Redact Output of Explain ## What changes were proposed in this pull request? When calling explain on a query, the output can contain sensitive information. We should provide an admin/user to redact such information. Before this PR, the plan of SS is like this ``` == Physical Plan == *HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#12L]) +- StateStoreSave [value#6], state info [ checkpoint = file:/private/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-91c6fac0-609f-4bc8-ad57-52c189f06797/state, runId = 05a4b3af-f02c-40f8-9ff9-a3e18bae496f, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#18L]) +- StateStoreRestore [value#6], state info [ checkpoint = file:/private/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-91c6fac0-609f-4bc8-ad57-52c189f06797/state, runId = 05a4b3af-f02c-40f8-9ff9-a3e18bae496f, opId = 0, ver = 0, numPartitions = 5] +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#18L]) +- Exchange hashpartitioning(value#6, 5) +- *HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#18L]) +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *MapElements , obj#5: java.lang.String +- *DeserializeToObject value#30.toString, obj#4: java.lang.String +- LocalTableScan [value#30] ``` After this PR, we can get the following output if users set `spark.redaction.string.regex` to `file:/[\\w_]+` ``` == Physical Plan == *HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#12L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-e7da9b7d-3ec0-474d-8b8c-927f7d12ed72/state, runId = 8a9c3761-93d5-4896-ab82-14c06240dcea, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#32L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-e7da9b7d-3ec0-474d-8b8c-927f7d12ed72/state, runId = 8a9c3761-93d5-4896-ab82-14c06240dcea, opId = 0, ver = 0, numPartitions = 5] +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#32L]) +- Exchange hashpartitioning(value#6, 5) +- *HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#32L]) +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *MapElements , obj#5: java.lang.String +- *DeserializeToObject value#27.toString, obj#4: java.lang.String +- LocalTableScan [value#27] ``` ## How was this patch tested? Added a test case Author: gatorsmile Closes #19985 from gatorsmile/redactPlan. 
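Editor's note on the diff below: mechanically, the redaction is a regex substitution over the rendered plan string, driven by an optional pattern from configuration. A small, self-contained sketch under simplified names (the real helper and replacement text live in `org.apache.spark.util.Utils`, and the config plumbing in `SQLConf`):

```
import scala.util.matching.Regex

// Minimal sketch of regex-based redaction over an explain string; names and
// the replacement text are simplified, not the actual Spark constants.
object RedactionSketch {
  val ReplacementText = "*********(redacted)"

  def redact(regex: Option[Regex], text: String): String = regex match {
    case None => text
    case Some(_) if text == null || text.isEmpty => text
    case Some(r) => r.replaceAllIn(text, ReplacementText)
  }

  def main(args: Array[String]): Unit = {
    val plan = "StateStoreSave ... checkpoint = file:/private/var/folders/tmp/state"
    val pattern = Some("file:/[\\w_]+".r) // e.g. spark.redaction.string.regex
    println(redact(pattern, plan))
    // prints: StateStoreSave ... checkpoint = *********(redacted)/var/folders/tmp/state
  }
}
```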
--- .../scala/org/apache/spark/util/Utils.scala | 26 +++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 11 +++++++ .../sql/execution/DataSourceScanExec.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 13 ++++++-- .../DataSourceScanExecRedactionSuite.scala | 31 +++++++++++++++++ .../spark/sql/streaming/StreamSuite.scala | 33 ++++++++++++++++++- 6 files changed, 105 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8871870eb8681..5853302973140 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2650,15 +2650,29 @@ private[spark] object Utils extends Logging { redact(redactionPattern, kvs) } + /** + * Redact the sensitive values in the given map. If a map key matches the redaction pattern then + * its value is replaced with a dummy text. + */ + def redact(regex: Option[Regex], kvs: Seq[(String, String)]): Seq[(String, String)] = { + regex match { + case None => kvs + case Some(r) => redact(r, kvs) + } + } + /** * Redact the sensitive information in the given string. */ - def redact(conf: SparkConf, text: String): String = { - if (text == null || text.isEmpty || conf == null || !conf.contains(STRING_REDACTION_PATTERN)) { - text - } else { - val regex = conf.get(STRING_REDACTION_PATTERN).get - regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + def redact(regex: Option[Regex], text: String): String = { + regex match { + case None => text + case Some(r) => + if (text == null || text.isEmpty) { + text + } else { + r.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cf7e3ebce7411..bdc8d92e84079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import scala.collection.immutable +import scala.util.matching.Regex import org.apache.hadoop.fs.Path @@ -1035,6 +1036,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SQL_STRING_REDACTION_PATTERN = + ConfigBuilder("spark.sql.redaction.string.regex") + .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + + "information. When this regex matches a string part, that string part is replaced by a " + + "dummy value. This is currently used to redact the output of SQL explain commands. " + + "When this conf is not set, the value from `spark.redaction.string.regex` is used.") + .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1173,6 +1182,8 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 747749bc72e66..27c7dc3ee4533 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(SparkSession.getActiveSession.map(_.sparkContext.conf).orNull, text) + Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 946475a1e9751..8bfe3eff0c3b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -194,13 +194,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } } - def simpleString: String = { + def simpleString: String = withRedaction { s"""== Physical Plan == |${stringOrError(executedPlan.treeString(verbose = false))} """.stripMargin.trim } - override def toString: String = { + override def toString: String = withRedaction { def output = Utils.truncatedString( analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ") val analyzedPlan = Seq( @@ -219,7 +219,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { """.stripMargin.trim } - def stringWithStats: String = { + def stringWithStats: String = withRedaction { // trigger to compute stats for logical plans optimizedPlan.stats @@ -231,6 +231,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { """.stripMargin.trim } + /** + * Redact the sensitive information in the given string. + */ + private def withRedaction(message: String): String = { + Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message) + } + /** A special namespace for commands that can be used to debug query execution. 
*/ // scalastyle:off object debug { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 423e1288e8dcb..c8d045a32d73c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -20,6 +20,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext /** @@ -52,4 +53,34 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.simpleString.contains(replacement)) } } + + private def isIncluded(queryExecution: QueryExecution, msg: String): Boolean = { + queryExecution.toString.contains(msg) || + queryExecution.simpleString.contains(msg) || + queryExecution.stringWithStats.contains(msg) + } + + test("explain is redacted using SQLConf") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + val df = spark.read.parquet(basePath) + val replacement = "*********" + + // Respect SparkConf and replace file:/ + assert(isIncluded(df.queryExecution, replacement)) + + assert(isIncluded(df.queryExecution, "FileScan")) + assert(!isIncluded(df.queryExecution, "file:/")) + + withSQLConf(SQLConf.SQL_STRING_REDACTION_PATTERN.key -> "(?i)FileScan") { + // Respect SQLConf and replace FileScan + assert(isIncluded(df.queryExecution, replacement)) + + assert(!isIncluded(df.queryExecution, "FileScan")) + assert(isIncluded(df.queryExecution, "file:/")) + } + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9e696b2236b68..fa4b2dd6a6c9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -28,7 +28,7 @@ import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.Range @@ -418,6 +418,37 @@ class StreamSuite extends StreamTest { assert(OutputMode.Update === InternalOutputModes.Update) } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.string.regex", "file:/[\\w_]+") + + test("explain - redaction") { + val replacement = "*********" + + val inputData = MemoryStream[String] + val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*")) + // Test StreamingQuery.display + val q = df.writeStream.queryName("memory_explain").outputMode("complete").format("memory") + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + try { + inputData.addData("abc") + q.processAllAvailable() + + val explainWithoutExtended = q.explainInternal(false) + assert(explainWithoutExtended.contains(replacement)) + assert(explainWithoutExtended.contains("StateStoreRestore")) + assert(!explainWithoutExtended.contains("file:/")) + + val 
explainWithExtended = q.explainInternal(true) + assert(explainWithExtended.contains(replacement)) + assert(explainWithExtended.contains("StateStoreRestore")) + assert(!explainWithoutExtended.contains("file:/")) + } finally { + q.stop() + } + } + test("explain") { val inputData = MemoryStream[String] val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*")) From b779c93518bd46850d6576bd34ea11f78fc4e01a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 19 Dec 2017 22:17:04 +0800 Subject: [PATCH 453/712] [SPARK-22815][SQL] Keep PromotePrecision in Optimized Plans ## What changes were proposed in this pull request? We could get incorrect results by running DecimalPrecision twice. This PR resolves the original found in https://github.com/apache/spark/pull/15048 and https://github.com/apache/spark/pull/14797. After this PR, it becomes easier to change it back using `children` instead of using `innerChildren`. ## How was this patch tested? The existing test. Author: gatorsmile Closes #20000 from gatorsmile/keepPromotePrecision. --- .../spark/sql/catalyst/analysis/StreamingJoinHelper.scala | 4 +++- .../spark/sql/catalyst/expressions/decimalExpressions.scala | 2 ++ .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 1 - 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 072dc954879ca..7a0aa08289efa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ @@ -238,6 +238,8 @@ object StreamingJoinHelper extends PredicateHelper with Logging { collect(child, !negate) case CheckOverflow(child, _) => collect(child, negate) + case PromotePrecision(child) => + collect(child, negate) case Cast(child, dataType, _) => dataType match { case _: NumericType | _: TimestampType => collect(child, negate) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 752dea23e1f7a..db1579ba28671 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -70,10 +70,12 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) + /** Just 
a simple pass-through for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") override def prettyName: String = "promote_precision" override def sql: String = child.sql + override lazy val canonicalized: Expression = child.canonicalized } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6305b6c84bae3..85295aff19808 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -614,7 +614,6 @@ object SimplifyCasts extends Rule[LogicalPlan] { object RemoveDispensableExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case UnaryPositive(child) => child - case PromotePrecision(child) => child } } From ee56fc3432e7e328c29b88a255f7e2c2a4739754 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 20 Dec 2017 00:10:54 +0800 Subject: [PATCH 454/712] [SPARK-18016][SQL] Code Generation: Constant Pool Limit - reduce entries for mutable state ## What changes were proposed in this pull request? This PR is follow-on of #19518. This PR tries to reduce the number of constant pool entries used for accessing mutable state. There are two directions: 1. Primitive type variables should be allocated at the outer class due to better performance. Otherwise, this PR allocates an array. 2. The length of allocated array is up to 32768 due to avoiding usage of constant pool entry at access (e.g. `mutableStateArray[32767]`). Here are some discussions to determine these directions. 1. [[1]](https://github.com/apache/spark/pull/19518#issuecomment-346690464), [[2]](https://github.com/apache/spark/pull/19518#issuecomment-346690642), [[3]](https://github.com/apache/spark/pull/19518#issuecomment-346828180), [[4]](https://github.com/apache/spark/pull/19518#issuecomment-346831544), [[5]](https://github.com/apache/spark/pull/19518#issuecomment-346857340) 2. [[6]](https://github.com/apache/spark/pull/19518#issuecomment-346729172), [[7]](https://github.com/apache/spark/pull/19518#issuecomment-346798358), [[8]](https://github.com/apache/spark/pull/19518#issuecomment-346870408) This PR modifies `addMutableState` function in the `CodeGenerator` to check if the declared state can be easily initialized compacted into an array. We identify three types of states that cannot compacted: - Primitive type state (ints, booleans, etc) if the number of them does not exceed threshold - Multiple-dimensional array type - `inline = true` When `useFreshName = false`, the given name is used. Many codes were ported from #19518. Many efforts were put here. I think this PR should credit to bdrillard With this PR, the following code is generated: ``` /* 005 */ class SpecificMutableProjection extends org.apache.spark.sql.catalyst.expressions.codegen.BaseMutableProjection { /* 006 */ /* 007 */ private Object[] references; /* 008 */ private InternalRow mutableRow; /* 009 */ private boolean isNull_0; /* 010 */ private boolean isNull_1; /* 011 */ private boolean isNull_2; /* 012 */ private int value_2; /* 013 */ private boolean isNull_3; ... 
/* 10006 */ private int value_4999; /* 10007 */ private boolean isNull_5000; /* 10008 */ private int value_5000; /* 10009 */ private InternalRow[] mutableStateArray = new InternalRow[2]; /* 10010 */ private boolean[] mutableStateArray1 = new boolean[7001]; /* 10011 */ private int[] mutableStateArray2 = new int[1001]; /* 10012 */ private UTF8String[] mutableStateArray3 = new UTF8String[6000]; /* 10013 */ ... /* 107956 */ private void init_176() { /* 107957 */ isNull_4986 = true; /* 107958 */ value_4986 = -1; ... /* 108004 */ } ... ``` ## How was this patch tested? Added a new test case to `GeneratedProjectionSuite` Author: Kazuaki Ishizaki Closes #19811 from kiszk/SPARK-18016. --- .../sql/catalyst/expressions/Expression.scala | 3 +- .../MonotonicallyIncreasingID.scala | 6 +- .../expressions/SparkPartitionID.scala | 3 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 157 +++++++++++++++--- .../codegen/GenerateMutableProjection.scala | 40 ++--- .../codegen/GenerateUnsafeProjection.scala | 19 +-- .../expressions/conditionalExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 37 ++--- .../sql/catalyst/expressions/generators.scala | 6 +- .../expressions/nullExpressions.scala | 3 +- .../expressions/objects/objects.scala | 60 ++++--- .../expressions/randomExpressions.scala | 6 +- .../expressions/regexpExpressions.scala | 34 ++-- .../expressions/stringExpressions.scala | 23 +-- .../ArithmeticExpressionSuite.scala | 4 +- .../sql/catalyst/expressions/CastSuite.scala | 2 +- .../expressions/CodeGenerationSuite.scala | 29 +++- .../expressions/ComplexTypeSuite.scala | 2 +- .../ConditionalExpressionSuite.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 4 +- .../expressions/RegexpExpressionsSuite.scala | 5 +- .../catalyst/expressions/ScalaUDFSuite.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 27 +++ .../optimizer/complexTypesSuite.scala | 2 +- .../sql/execution/ColumnarBatchScan.scala | 26 ++- .../sql/execution/DataSourceScanExec.scala | 6 +- .../apache/spark/sql/execution/SortExec.scala | 18 +- .../sql/execution/WholeStageCodegenExec.scala | 5 +- .../aggregate/HashAggregateExec.scala | 44 +++-- .../aggregate/HashMapGenerator.scala | 6 +- .../execution/basicPhysicalOperators.scala | 71 ++++---- .../columnar/GenerateColumnAccessor.scala | 3 +- .../joins/BroadcastHashJoinExec.scala | 15 +- .../execution/joins/SortMergeJoinExec.scala | 21 ++- .../apache/spark/sql/execution/limit.scala | 6 +- 37 files changed, 404 insertions(+), 304 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 743782a6453e9..4568714933095 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,8 +119,7 @@ abstract class Expression extends TreeNode[Expression] { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.freshName("globalIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) + val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = globalIsNull 
s"$globalIsNull = $localIsNull;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 821d784a01342..784eaf8195194 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -65,10 +65,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.freshName("count") - val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm) - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm) + val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG, "partitionMask") ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 4fa18d6b3209b..736ca37c6d54a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -43,8 +43,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm) + val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId") ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1893eec22b65d..d3a8cb5804717 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,8 +602,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.freshName("leastTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull") val evals = evalChildren.map(eval => s""" |${eval.code} @@ -683,8 +682,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.freshName("greatestTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull") val evals = evalChildren.map(eval => s""" |${eval.code} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3a03a65e1af92..41a920ba3d677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -137,22 +137,63 @@ class CodegenContext { var currentVars: Seq[ExprCode] = null /** - * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a - * 3-tuple: java type, variable name, code to init it. - * As an example, ("int", "count", "count = 0;") will produce code: + * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a + * 2-tuple: java type, variable name. + * As an example, ("int", "count") will produce code: * {{{ * private int count; * }}} - * as a member variable, and add - * {{{ - * count = 0; - * }}} - * to the constructor. + * as a member variable * * They will be kept as member variables in generated classes like `SpecificProjection`. */ - val mutableStates: mutable.ArrayBuffer[(String, String, String)] = - mutable.ArrayBuffer.empty[(String, String, String)] + val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = + mutable.ArrayBuffer.empty[(String, String)] + + /** + * The mapping between mutable state types and corrseponding compacted arrays. + * The keys are java type string. The values are [[MutableStateArrays]] which encapsulates + * the compacted arrays for the mutable states with the same java type. + */ + val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = + mutable.Map.empty[String, MutableStateArrays] + + // An array holds the code that will initialize each state + val mutableStateInitCode: mutable.ArrayBuffer[String] = + mutable.ArrayBuffer.empty[String] + + /** + * This class holds a set of names of mutableStateArrays that is used for compacting mutable + * states for a certain type, and holds the next available slot of the current compacted array. + */ + class MutableStateArrays { + val arrayNames = mutable.ListBuffer.empty[String] + createNewArray() + + private[this] var currentIndex = 0 + + private def createNewArray() = arrayNames.append(freshName("mutableStateArray")) + + def getCurrentIndex: Int = currentIndex + + /** + * Returns the reference of next available slot in current compacted array. The size of each + * compacted array is controlled by the config `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * Once reaching the threshold, new compacted array is created. + */ + def getNextSlot(): String = { + if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + val res = s"${arrayNames.last}[$currentIndex]" + currentIndex += 1 + res + } else { + createNewArray() + currentIndex = 1 + s"${arrayNames.last}[0]" + } + } + + } /** * Add a mutable state as a field to the generated class. c.f. the comments above. @@ -163,11 +204,52 @@ class CodegenContext { * the list of default imports available. * Also, generic type arguments are accepted but ignored. * @param variableName Name of the field. - * @param initCode The statement(s) to put into the init() method to initialize this field. + * @param initFunc Function includes statement(s) to put into the init() method to initialize + * this field. The argument is the name of the mutable state variable. * If left blank, the field will be default-initialized. 
+ * @param forceInline whether the declaration and initialization code may be inlined rather than + * compacted. Please set `true` into forceInline for one of the followings: + * 1. use the original name of the status + * 2. expect to non-frequently generate the status + * (e.g. not much sort operators in one stage) + * @param useFreshName If this is false and the mutable state ends up inlining in the outer + * class, the name is not changed + * @return the name of the mutable state variable, which is the original name or fresh name if + * the variable is inlined to the outer class, or an array access if the variable is to + * be stored in an array of variables of the same type. + * A variable will be inlined into the outer class when one of the following conditions + * are satisfied: + * 1. forceInline is true + * 2. its type is primitive type and the total number of the inlined mutable variables + * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * 3. its type is multi-dimensional array + * When a variable is compacted into an array, the max size of the array for compaction + * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. */ - def addMutableState(javaType: String, variableName: String, initCode: String = ""): Unit = { - mutableStates += ((javaType, variableName, initCode)) + def addMutableState( + javaType: String, + variableName: String, + initFunc: String => String = _ => "", + forceInline: Boolean = false, + useFreshName: Boolean = true): String = { + + // want to put a primitive type variable at outerClass for performance + val canInlinePrimitive = isPrimitiveType(javaType) && + (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { + val varName = if (useFreshName) freshName(variableName) else variableName + val initCode = initFunc(varName) + inlinedMutableStates += ((javaType, varName)) + mutableStateInitCode += initCode + varName + } else { + val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays) + val element = arrays.getNextSlot() + + val initCode = initFunc(element) + mutableStateInitCode += initCode + element + } } /** @@ -176,8 +258,7 @@ class CodegenContext { * data types like: UTF8String, ArrayData, MapData & InternalRow. */ def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { - val value = freshName(variableName) - addMutableState(javaType(dataType), value, "") + val value = addMutableState(javaType(dataType), variableName) val code = dataType match { case StringType => s"$value = $initCode.clone();" case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" @@ -189,15 +270,37 @@ class CodegenContext { def declareMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. 
- mutableStates.distinct.map { case (javaType, variableName, _) => + val inlinedStates = inlinedMutableStates.distinct.map { case (javaType, variableName) => s"private $javaType $variableName;" - }.mkString("\n") + } + + val arrayStates = arrayCompactedMutableStates.flatMap { case (javaType, mutableStateArrays) => + val numArrays = mutableStateArrays.arrayNames.size + mutableStateArrays.arrayNames.zipWithIndex.map { case (arrayName, index) => + val length = if (index + 1 == numArrays) { + mutableStateArrays.getCurrentIndex + } else { + CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + } + if (javaType.contains("[]")) { + // initializer had an one-dimensional array variable + val baseType = javaType.substring(0, javaType.length - 2) + s"private $javaType[] $arrayName = new $baseType[$length][];" + } else { + // initializer had a scalar variable + s"private $javaType[] $arrayName = new $javaType[$length];" + } + } + } + + (inlinedStates ++ arrayStates).mkString("\n") } def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStates.distinct.map(_._3 + "\n") + val initCodes = mutableStateInitCode.distinct + // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) @@ -1011,9 +1114,9 @@ class CodegenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach { e => val expr = e.head - val fnName = freshName("evalExpr") - val isNull = s"${fnName}IsNull" - val value = s"${fnName}Value" + val fnName = freshName("subExpr") + val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val value = addMutableState(javaType(expr.dataType), "subExprValue") // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) @@ -1039,9 +1142,6 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") - addMutableState(javaType(expr.dataType), value, - s"$value = ${defaultValue(expr.dataType)};") subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) @@ -1165,6 +1265,15 @@ object CodeGenerator extends Logging { // class. val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + // This is the threshold for the number of global variables, whose types are primitive type or + // complex type (e.g. more than one-dimensional array), that will be placed at the outer class + val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 + + // This is the maximum number of array elements to keep global variables in one Java array + // 32767 is the maximum integer value that does not require a constant pool entry in a Java + // bytecode instruction + val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + /** * Compile the Java source code into a Java class, using Janino. 
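
(Editorial aside, not part of the patch: the compaction policy the reworked `addMutableState` encodes — inline primitive states up to a threshold, otherwise hand out slots in per-type compacted arrays, starting a fresh array once a slot limit is reached — can be illustrated with the small self-contained Scala sketch below. `SimpleCodegenContext` and the tiny threshold values are hypothetical stand-ins, not Spark's `CodegenContext` or its actual constants of 10000 and 32768.)

```scala
import scala.collection.mutable

object StateCompactionSketch {
  // Tiny stand-ins for OUTER_CLASS_VARIABLES_THRESHOLD and MUTABLESTATEARRAY_SIZE_LIMIT,
  // shrunk so the behavior is visible in a few calls.
  val InlineThreshold = 4
  val ArraySizeLimit = 3

  class SimpleCodegenContext {
    private var freshId = 0
    private def freshName(name: String): String = { freshId += 1; s"${name}_$freshId" }

    // Inlined states as (java type, field name); compacted states as type -> (array names, used slots).
    val inlined = mutable.ArrayBuffer.empty[(String, String)]
    private val arrays = mutable.Map.empty[String, (mutable.ListBuffer[String], Int)]
    val initCode = mutable.ArrayBuffer.empty[String]

    private def isPrimitive(javaType: String): Boolean =
      Set("boolean", "byte", "short", "int", "long", "float", "double")(javaType)

    /** Returns either a dedicated field name or an array slot reference such as `arr[2]`. */
    def addMutableState(
        javaType: String,
        name: String,
        initFunc: String => String = _ => "",
        forceInline: Boolean = false): String = {
      val canInlinePrimitive = isPrimitive(javaType) && inlined.size < InlineThreshold
      val ref =
        if (forceInline || canInlinePrimitive || javaType.contains("[][]")) {
          // Inlined: declare a dedicated member field in the outer class.
          val v = freshName(name)
          inlined += ((javaType, v))
          v
        } else {
          // Compacted: reuse a per-type array, opening a new one when the slot limit is hit.
          val (names, used) =
            arrays.getOrElseUpdate(javaType, (mutable.ListBuffer(freshName("mutableStateArray")), 0))
          val (arrayName, slot) =
            if (used < ArraySizeLimit) (names.last, used)
            else { val a = freshName("mutableStateArray"); names += a; (a, 0) }
          arrays(javaType) = (names, slot + 1)
          s"$arrayName[$slot]"
        }
      // The caller's init statement is built from whichever reference was handed out.
      initCode += initFunc(ref)
      ref
    }
  }

  def main(args: Array[String]): Unit = {
    val ctx = new SimpleCodegenContext
    val refs = (1 to 6).map(i => ctx.addMutableState("int", "count", v => s"$v = $i;"))
    refs.foreach(println)                // count_1 .. count_4, then mutableStateArray_5[0], [1]
    println(ctx.initCode.mkString("\n")) // init statements written against those references
  }
}
```

Under this policy only the first few primitive states become dedicated fields of the outer class; everything else shares per-type arrays, which is how the patch keeps both the number of generated fields and the constant-pool pressure bounded for very wide schemas.
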
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index bd8312eb8b7fe..b53c0087e7e2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -57,41 +57,37 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case _ => true }.unzip val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) - val projectionCodes = exprVals.zip(index).map { + + // 4-tuples: (code for projection, isNull variable name, value variable name, column index) + val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) + val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") if (e.nullable) { - val isNull = s"isNull_$i" - val value = s"value_$i" - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull, s"$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${ev.code} - $isNull = ${ev.isNull}; - $value = ${ev.value}; - """ + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + (s""" + |${ev.code} + |$isNull = ${ev.isNull}; + |$value = ${ev.value}; + """.stripMargin, isNull, value, i) } else { - val value = s"value_$i" - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${ev.code} - $value = ${ev.value}; - """ + (s""" + |${ev.code} + |$value = ${ev.value}; + """.stripMargin, ev.isNull, value, i) } } // Evaluate all the subexpressions. 
val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(index).map { - case (e, i) => - val ev = ExprCode("", s"isNull_$i", s"value_$i") + val updates = validExpr.zip(projectionCodes).map { + case (e, (_, isNull, value, i)) => + val ev = ExprCode("", isNull, value) ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes) + val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) val codeBody = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b022457865d50..36ffa8dcdd2b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -73,9 +73,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro bufferHolder: String, isTopLevel: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, - s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -186,9 +185,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
val tmpInput = ctx.freshName("tmpInput") val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.freshName("arrayWriter") - ctx.addMutableState(arrayWriterClass, arrayWriter, - s"$arrayWriter = new $arrayWriterClass();") + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -318,13 +316,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") + val result = ctx.addMutableState("UnsafeRow", "result", + v => s"$v = new UnsafeRow(${expressions.length});") - val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, holder, - s"$holder = new $holderClass($result, ${numVarLenFields * 32});") + val holder = ctx.addMutableState(holderClass, "holder", + v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 53c3b226895ec..1a9b68222a7f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -190,8 +190,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. 
val resultState = ctx.freshName("caseWhenResultState") - val tmpResult = ctx.freshName("caseWhenTmpResult") - ctx.addMutableState(ctx.javaType(dataType), tmpResult) + val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult") // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 44d54a20844a3..cfec7f82951a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -442,9 +442,9 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName - val c = ctx.freshName("cal") val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""") + val c = ctx.addMutableState(cal, "cal", + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); ${ev.value} = $c.get($cal.DAY_OF_WEEK); @@ -484,18 +484,17 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName - val c = ctx.freshName("cal") val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(cal, c, + val c = ctx.addMutableState(cal, "cal", v => s""" - $c = $cal.getInstance($dtu.getTimeZone("UTC")); - $c.setFirstDayOfWeek($cal.MONDAY); - $c.setMinimalDaysInFirstWeek(4); - """) + |$v = $cal.getInstance($dtu.getTimeZone("UTC")); + |$v.setFirstDayOfWeek($cal.MONDAY); + |$v.setMinimalDaysInFirstWeek(4); + """.stripMargin) s""" - $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.WEEK_OF_YEAR); - """ + |$c.setTimeInMillis($time * 1000L * 3600L * 24L); + |${ev.value} = $c.get($cal.WEEK_OF_YEAR); + """.stripMargin }) } } @@ -1014,12 +1013,12 @@ case class FromUTCTimestamp(left: Expression, right: Expression) |long ${ev.value} = 0; """.stripMargin) } else { - val tzTerm = ctx.freshName("tz") - val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTerm = ctx.addMutableState(tzClass, "tz", + v => s"""$v = $dtu.getTimeZone("$tz");""") + val utcTerm = ctx.addMutableState(tzClass, "utc", + v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} @@ -1190,12 +1189,12 @@ case class ToUTCTimestamp(left: Expression, right: Expression) |long ${ev.value} = 0; """.stripMargin) } else { - val tzTerm = ctx.freshName("tz") - val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTerm = ctx.addMutableState(tzClass, "tz", + v => 
s"""$v = $dtu.getTimeZone("$tz");""") + val utcTerm = ctx.addMutableState(tzClass, "utc", + v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index cd38783a731ad..1cd73a92a8635 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -199,8 +199,8 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. - val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") + val rowData = ctx.addMutableState("InternalRow[]", "rows", + v => s"$v = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row => @@ -217,7 +217,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState( s"$wrapperClass", ev.value, - s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);") + v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", useFreshName = false) ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 294cdcb2e9546..b4f895fffda38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tmpIsNull = ctx.freshName("coalesceTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) + val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull") // all the evals are meant to be in a do { ... 
} while (false); loop val evals = children.map { e => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4bd395eadcf19..a59aad5be8715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -62,15 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression { def prepareArguments(ctx: CodegenContext): (String, String, String) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.freshName("resultIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, resultIsNull) + val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") resultIsNull } else { "false" } val argValues = arguments.map { e => - val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue) + val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") argValue } @@ -548,7 +546,7 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState(elementJavaType, loopValue) + ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -644,7 +642,7 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -808,10 +806,11 @@ case class CatalystToExternalMap private( val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] val keyElementJavaType = ctx.javaType(mapType.keyType) - ctx.addMutableState(keyElementJavaType, keyLoopValue) + ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) val valueElementJavaType = ctx.javaType(mapType.valueType) - ctx.addMutableState(valueElementJavaType, valueLoopValue) + ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, + useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -844,7 +843,8 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { "" @@ -994,8 +994,8 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) - ctx.addMutableState(keyElementJavaType, key) - ctx.addMutableState(valueElementJavaType, value) + ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) + ctx.addMutableState(valueElementJavaType, 
value, forceInline = true, useFreshName = false) val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => @@ -1031,14 +1031,14 @@ case class ExternalMapToCatalyst private( } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull) + ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1148,7 +1148,6 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. - val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { if (kryo) { (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) @@ -1159,14 +1158,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializerInit = s""" - if ($env == null) { - $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - } else { - $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - } - """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", v => + s""" + |if ($env == null) { + | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + |} else { + | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + |} + """.stripMargin) // Code to serialize. val input = child.genCode(ctx) @@ -1194,7 +1193,6 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. - val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { if (kryo) { (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) @@ -1205,14 +1203,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializerInit = s""" - if ($env == null) { - $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - } else { - $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - } - """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v => + s""" + |if ($env == null) { + | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + |} else { + | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + |} + """.stripMargin) // Code to deserialize. 
val input = child.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index b4aefe6cff73e..8bc936fcbfc31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -77,9 +77,8 @@ case class Rand(child: Expression) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm) + val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" @@ -112,9 +111,8 @@ case class Randn(child: Expression) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm) + val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 53d7096dd87d3..fa5425c77ebba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -112,15 +112,15 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - val pattern = ctx.freshName("pattern") if (right.foldable) { val rVal = right.eval() if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") + // inline mutable state since not many Like operations in a task + val pattern = ctx.addMutableState(patternClass, "patternLike", + v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
val eval = left.genCode(ctx) @@ -139,6 +139,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi """) } } else { + val pattern = ctx.freshName("pattern") val rightStr = ctx.freshName("rightStr") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" @@ -187,15 +188,15 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName - val pattern = ctx.freshName("pattern") if (right.foldable) { val rVal = right.eval() if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") + // inline mutable state since not many RLike operations in a task + val pattern = ctx.addMutableState(patternClass, "patternRLike", + v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -215,6 +216,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } } else { val rightStr = ctx.freshName("rightStr") + val pattern = ctx.freshName("pattern") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = ${eval2}.toString(); @@ -316,11 +318,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def prettyName: String = "regexp_replace" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - - val termLastReplacement = ctx.freshName("lastReplacement") - val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") val termResult = ctx.freshName("termResult") val classNamePattern = classOf[Pattern].getCanonicalName @@ -328,11 +325,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val matcher = ctx.freshName("matcher") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", - termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") + val termLastReplacement = ctx.addMutableState("String", "lastReplacement") + val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -414,14 +410,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val 
termPattern = ctx.addMutableState(classNamePattern, "pattern") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8c4d2fd686be5..c02c41db1668e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -291,8 +291,7 @@ case class Elt(children: Seq[Expression]) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.freshName("stringVal") - ctx.addMutableState(ctx.javaType(dataType), stringVal) + val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") val assignStringValue = strings.zipWithIndex.map { case (eval, index) => s""" @@ -532,14 +531,11 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val termLastMatching = ctx.freshName("lastMatching") - val termLastReplace = ctx.freshName("lastReplace") - val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") - ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") + val termLastMatching = ctx.addMutableState("UTF8String", "lastMatching") + val termLastReplace = ctx.addMutableState("UTF8String", "lastReplace") + val termDict = ctx.addMutableState(classNameDict, "dict") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { @@ -2065,15 +2061,12 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
val usLocale = "US" - val lastDValue = ctx.freshName("lastDValue") - val pattern = ctx.freshName("pattern") - val numberFormat = ctx.freshName("numberFormat") val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - ctx.addMutableState(ctx.JAVA_INT, lastDValue, s"$lastDValue = -100;") - ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, - s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") + val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val numberFormat = ctx.addMutableState(df, "numberFormat", + v => s"""$v = new $df("", new $dfs($l.$usLocale));""") s""" if ($d >= 0) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index be638d80e45d8..6edb4348f8309 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -348,10 +348,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("SPARK-22704: Least and greatest use less global variables") { val ctx1 = new CodegenContext() Least(Seq(Literal(1), Literal(1))).genCode(ctx1) - assert(ctx1.mutableStates.size == 1) + assert(ctx1.inlinedMutableStates.size == 1) val ctx2 = new CodegenContext() Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) - assert(ctx2.mutableStates.size == 1) + assert(ctx2.inlinedMutableStates.size == 1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 65617be05a434..1dd040e4696a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -851,6 +851,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext cast("1", IntegerType).genCode(ctx) cast("2", LongType).genCode(ctx) - assert(ctx.mutableStates.length == 0) + assert(ctx.inlinedMutableStates.length == 0) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index a969811019161..b1a44528e64d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -385,20 +385,43 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext val schema = new StructType().add("a", IntegerType).add("b", StringType) CreateExternalRow(Seq(Literal(1), Literal("x")), schema).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("SPARK-22696: InitializeJavaBean should not use global variables") { val ctx = new CodegenContext InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), Map("add" -> Literal(1))).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("SPARK-22716: 
addReferenceObj should not add mutable states") { val ctx = new CodegenContext val foo = new Object() ctx.addReferenceObj("foo", foo) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) + } + + test("SPARK-18016: define mutable states by using an array") { + val ctx1 = new CodegenContext + for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { + ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + } + assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + // When the number of primitive type mutable states is over the threshold, others are + // allocated into an array + assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) + + val ctx2 = new CodegenContext + for (i <- 1 to CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) { + ctx2.addMutableState("InternalRow[]", "r", v => s"$v = new InternalRow[$i];") + } + // When the number of non-primitive type mutable states is over the threshold, others are + // allocated into a new array + assert(ctx2.inlinedMutableStates.isEmpty) + assert(ctx2.arrayCompactedMutableStates.get("InternalRow[]").get.arrayNames.size == 2) + assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10) + assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 6dfca7d73a3df..84190f0bd5f7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -304,6 +304,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: CreateNamedStruct should not use global variables") { val ctx = new CodegenContext CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 60d84aae1fa3f..a099119732e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -150,6 +150,6 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("SPARK-22705: case when should use less global variables") { val ctx = new CodegenContext() CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx) - assert(ctx.mutableStates.size == 1) + assert(ctx.inlinedMutableStates.size == 1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index a23cd95632770..cc6c15cb2c909 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ 
-159,7 +159,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22705: Coalesce should use less global variables") { val ctx = new CodegenContext() Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx) - assert(ctx.mutableStates.size == 1) + assert(ctx.inlinedMutableStates.size == 1) } test("AtLeastNNonNulls should not throw 64kb exception") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 15cb0bea08f17..8a8f8e10225fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -249,7 +249,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22705: In should use less global variables") { val ctx = new CodegenContext() In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } test("INSET") { @@ -440,6 +440,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: InSet should not use global variables") { val ctx = new CodegenContext InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 4fa61fbaf66c2..2a0a42c65b086 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -183,8 +183,9 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx) // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8) - // are always required - assert(ctx.mutableStates.length == 4) + // are always required, which are allocated in type-based global array + assert(ctx.inlinedMutableStates.length == 0) + assert(ctx.mutableStateInitCode.length == 4) } test("RegexExtract") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 70dea4b39d55d..10e3ffd0dff97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -51,6 +51,6 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) - assert(ctx.mutableStates.isEmpty) + assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 6031bdf19e957..2c45b3b0c73d1 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -219,4 +219,31 @@ class GeneratedProjectionSuite extends SparkFunSuite { // - one is the mutableRow assert(globalVariables.length == 3) } + + test("SPARK-18016: generated projections on wider table requiring state compaction") { + val N = 6000 + val wideRow1 = new GenericInternalRow(new Array[Any](N)) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + Array.tabulate[Any](N)(i => UTF8String.fromString(i.toString))) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(nested) + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs) + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e3675367d78e4..0d11958876ce9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -168,7 +168,7 @@ class ComplexTypesSuite extends PlanTest{ test("SPARK-22570: CreateArray should not create a lot of global variables") { val ctx = new CodegenContext CreateArray(Seq(Literal(1))).genCode(ctx) - assert(ctx.mutableStates.length == 0) + assert(ctx.inlinedMutableStates.length == 0) } test("simplify map ops") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index a9bfb634fbdea..782cec5e292ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -68,30 +68,26 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { */ // TODO: return ColumnarBatch.Rows instead override protected def doProduce(ctx: CodegenContext): String = { - val input = ctx.freshName("input") // PhysicalRDD always just has one input - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val input = ctx.addMutableState("scala.collection.Iterator", "input", + v => s"$v = inputs[0];") // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.addMutableState(ctx.JAVA_LONG, scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName - val batch = ctx.freshName("batch") - 
ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.freshName("batchIdx") - ctx.addMutableState(ctx.JAVA_INT, idx, s"$idx = 0;") - val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) + val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( - Seq.fill(colVars.size)(classOf[ColumnVector].getName)) - val columnAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map { - case ((name, columnVectorClz), i) => - ctx.addMutableState(columnVectorClz, name, s"$name = null;") - s"$name = ($columnVectorClz) $batch.column($i);" - } + Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) + val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { + case (columnVectorClz, i) => + val name = ctx.addMutableState(columnVectorClz, s"colInstance$i") + (name, s"$name = ($columnVectorClz) $batch.column($i);") + }.unzip val nextBatch = ctx.freshName("nextBatch") val nextBatchFuncName = ctx.addNewFunction(nextBatch, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 27c7dc3ee4533..d1ff82c7c06bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -110,8 +110,7 @@ case class RowDataSourceScanExec( override protected def doProduce(ctx: CodegenContext): String = { val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input - val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -353,8 +352,7 @@ case class FileSourceScanExec( } val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input - val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") val row = ctx.freshName("row") ctx.INPUT_ROW = row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index c0e21343ae623..daff3c49e7517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -133,20 +133,18 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.freshName("needToSort") - ctx.addMutableState(ctx.JAVA_BOOLEAN, needToSort, s"$needToSort = true;") + val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. 
val thisPlan = ctx.addReferenceObj("plan", this) - sorterVariable = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, - s"$sorterVariable = $thisPlan.createSorter();") - val metrics = ctx.freshName("metrics") - ctx.addMutableState(classOf[TaskMetrics].getName, metrics, - s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") - val sortedIterator = ctx.freshName("sortedIter") - ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + // inline mutable state since not many Sort operations in a task + sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter", + v => s"$v = $thisPlan.createSorter();", forceInline = true) + val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", + v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true) + val sortedIterator = ctx.addMutableState("scala.collection.Iterator", "sortedIter", + forceInline = true) val addToSorter = ctx.freshName("addToSorter") val addToSorterFuncName = ctx.addNewFunction(addToSorter, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 7166b7771e4db..9e7008d1e0c31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -282,9 +282,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp } override def doProduce(ctx: CodegenContext): String = { - val input = ctx.freshName("input") // Right now, InputAdapter is only used when there is one input RDD. - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + // inline mutable state since an inputAdaptor in a task + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", + forceInline = true) val row = ctx.freshName("row") s""" | while ($input.hasNext() && !stopEarly()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9cadd13999e72..b1af360d85095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,8 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") + val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. 
ctx.INPUT_ROW = null @@ -187,10 +186,8 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.freshName("bufIsNull") - val value = ctx.freshName("bufValue") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) - ctx.addMutableState(ctx.javaType(e.dataType), value) + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -568,8 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;") + val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -583,42 +579,41 @@ case class HashAggregateExec( val thisPlan = ctx.addReferenceObj("plan", this) // Create a name for the iterator from the fast hash map. - val iterTermForFastHashMap = ctx.freshName("fastHashMapIter") - if (isFastHashMapEnabled) { + val iterTermForFastHashMap = if (isFastHashMapEnabled) { // Generates the fast hash map class and creates the fash hash map term. - fastHashMapTerm = ctx.freshName("fastHashMap") val fastHashMapClassName = ctx.freshName("FastHashMap") if (isVectorizedHashMapEnabled) { val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) - ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, - s"$fastHashMapTerm = new $fastHashMapClassName();") - ctx.addMutableState(s"java.util.Iterator", iterTermForFastHashMap) + fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedHastHashMap", + v => s"$v = new $fastHashMapClassName();") + ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter") } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) - ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, - s"$fastHashMapTerm = new $fastHashMapClassName(" + + fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap", + v => s"$v = new $fastHashMapClassName(" + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - iterTermForFastHashMap) + "fastHashMapIter") } } // Create a name for the iterator from the regular hash map. 
- val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm) + // inline mutable state since not many aggregation operations in a task + val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, + "mapIter", forceInline = true) // create hashMap - hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") - sorterTerm = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm) + hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap", + v => s"$v = $thisPlan.createHashMap();") + sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter", + forceInline = true) val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") @@ -758,8 +753,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.freshName("fallbackCounter") - ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 85b4529501ea8..1c613b19c4ab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -46,10 +46,8 @@ abstract class HashMapGenerator( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.freshName("bufIsNull") - val value = ctx.freshName("bufValue") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) - ctx.addMutableState(ctx.javaType(e.dataType), value) + val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index c9a15147e30d0..78137d3f97cfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -279,29 +279,30 @@ case class SampleExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val sampler = ctx.freshName("sampler") if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") - val initSamplerFuncName = ctx.addNewFunction(initSampler, - s""" - | private void $initSampler() { - | $sampler = new $samplerClass($upperBound - $lowerBound, false); - | java.util.Random random = new java.util.Random(${seed}L); - | long randomSeed = random.nextLong(); - | 
int loopCount = 0; - | while (loopCount < partitionIndex) { - | randomSeed = random.nextLong(); - | loopCount += 1; - | } - | $sampler.setSeed(randomSeed); - | } - """.stripMargin.trim) - - ctx.addMutableState(s"$samplerClass", sampler, - s"$initSamplerFuncName();") + // inline mutable state since not many Sample operations in a task + val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace", + v => { + val initSamplerFuncName = ctx.addNewFunction(initSampler, + s""" + | private void $initSampler() { + | $v = new $samplerClass($upperBound - $lowerBound, false); + | java.util.Random random = new java.util.Random(${seed}L); + | long randomSeed = random.nextLong(); + | int loopCount = 0; + | while (loopCount < partitionIndex) { + | randomSeed = random.nextLong(); + | loopCount += 1; + | } + | $v.setSeed(randomSeed); + | } + """.stripMargin.trim) + s"$initSamplerFuncName();" + }, forceInline = true) val samplingCount = ctx.freshName("samplingCount") s""" @@ -313,10 +314,10 @@ case class SampleExec( """.stripMargin.trim } else { val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName - ctx.addMutableState(s"$samplerClass", sampler, - s""" - | $sampler = new $samplerClass($lowerBound, $upperBound, false); - | $sampler.setSeed(${seed}L + partitionIndex); + val sampler = ctx.addMutableState(s"$samplerClass", "sampler", + v => s""" + | $v = new $samplerClass($lowerBound, $upperBound, false); + | $v.setSeed(${seed}L + partitionIndex); """.stripMargin.trim) s""" @@ -363,20 +364,18 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.freshName("initRange") - ctx.addMutableState(ctx.JAVA_BOOLEAN, initTerm, s"$initTerm = false;") - val number = ctx.freshName("number") - ctx.addMutableState(ctx.JAVA_LONG, number, s"$number = 0L;") + val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(ctx.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName - val taskContext = ctx.freshName("taskContext") - ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();") - val inputMetrics = ctx.freshName("inputMetrics") - ctx.addMutableState("InputMetrics", inputMetrics, - s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();") + // inline mutable state since not many Range operations in a task + val taskContext = ctx.addMutableState("TaskContext", "taskContext", + v => s"$v = TaskContext.get();", forceInline = true) + val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", + v => s"$v = $taskContext.taskMetrics().inputMetrics();", forceInline = true) // In order to periodically update the metrics without inflicting performance penalty, this // operator produces elements in batches. After a batch is complete, the metrics are updated @@ -386,12 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.freshName("batchEnd") - ctx.addMutableState(ctx.JAVA_LONG, batchEnd, s"$batchEnd = 0;") + val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. 
- val numElementsTodo = ctx.freshName("numElementsTodo") - ctx.addMutableState(ctx.JAVA_LONG, numElementsTodo, s"$numElementsTodo = 0L;") + val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") @@ -440,10 +437,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } """.stripMargin) - val input = ctx.freshName("input") - // Right now, Range is only used when there is one upstream. - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val range = ctx.freshName("range") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ff5dd707f0b38..4f28eeb725cbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -70,7 +70,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val ctx = newCodeGenContext() val numFields = columnTypes.size val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => - val accessorName = ctx.freshName("accessor") val accessorCls = dt match { case NullType => classOf[NullColumnAccessor].getName case BooleanType => classOf[BooleanColumnAccessor].getName @@ -89,7 +88,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName) + val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { case t if ctx.isPrimitiveType(dt) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index c96ed6ef41016..ee763e23415cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -134,19 +134,18 @@ case class BroadcastHashJoinExec( // create a name for HashedRelation val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName // At the end of the task, we update the avg hash probe. 
val avgHashProbe = metricTerm(ctx, "avgHashProbe") - val addTaskListener = genTaskListener(avgHashProbe, relationTerm) - ctx.addMutableState(clsName, relationTerm, - s""" - | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.estimatedSize()); - | $addTaskListener - """.stripMargin) + // inline mutable state since not many join operations in a task + val relationTerm = ctx.addMutableState(clsName, "relation", + v => s""" + | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); + | incPeakExecutionMemory($v.estimatedSize()); + | ${genTaskListener(avgHashProbe, v)} + """.stripMargin, forceInline = true) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 554b73181116c..073730462a75f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -422,10 +422,9 @@ case class SortMergeJoinExec( */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. - val leftRow = ctx.freshName("leftRow") - ctx.addMutableState("InternalRow", leftRow) - val rightRow = ctx.freshName("rightRow") - ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") + // inline mutable state since not many join operations in a task + val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) + val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) // Create variables for join keys from both sides. val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) @@ -436,14 +435,13 @@ case class SortMergeJoinExec( val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) // A list to hold all matched rows from right side. - val matches = ctx.freshName("matches") val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold - ctx.addMutableState(clsName, matches, - s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);") + val matches = ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);") // Copy the left keys as class members so they could be used in next function call. 
val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -578,10 +576,11 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { - val leftInput = ctx.freshName("leftInput") - ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") - val rightInput = ctx.freshName("rightInput") - ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + // inline mutable state since not many join operations in a task + val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", + v => s"$v = inputs[0];", forceInline = true) + val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", + v => s"$v = inputs[1];", forceInline = true) val (leftRow, matches) = genScanner(ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index a8556f6ba107a..cccee63bc0680 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -71,8 +71,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.freshName("stopEarly") - ctx.addMutableState(ctx.JAVA_BOOLEAN, stopEarly, s"$stopEarly = false;") + val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -80,8 +79,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.freshName("count") - ctx.addMutableState(ctx.JAVA_INT, countTerm, s"$countTerm = 0;") + val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1; From ef10f452e62c77d0434e80f7266f6685eb1bcb2c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 19 Dec 2017 09:05:47 -0800 Subject: [PATCH 455/712] [SPARK-21652][SQL][FOLLOW-UP] Fix rule conflict caused by InferFiltersFromConstraints ## What changes were proposed in this pull request? The optimizer rule `InferFiltersFromConstraints` could trigger our batch `Operator Optimizations` exceeds the max iteration limit (i.e., 100) so that the final plan might not be properly optimized. The rule `InferFiltersFromConstraints` could conflict with the other Filter/Join predicate reduction rules. Thus, we need to separate `InferFiltersFromConstraints` from the other rules. This PR is to separate `InferFiltersFromConstraints ` from the main batch `Operator Optimizations` . ## How was this patch tested? The existing test cases. Author: gatorsmile Closes #19149 from gatorsmile/inferFilterRule. 
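The batch layout this change introduces can be pictured with a small standalone sketch (plain Scala with toy `Plan`/`Rule` types and invented rule behaviors, not Catalyst's real `RuleExecutor` API): the rules that previously shared one fixed-point batch with `InferFiltersFromConstraints` first run to a fixed point, the inference rule is then applied exactly once, and the remaining rules run to a fixed point again.

```scala
// Minimal sketch of the new batch layout; Plan, Rule and the toy rules below are
// illustrative stand-ins, not Catalyst's actual RuleExecutor API.
object InferFiltersBatchSketch {

  type Plan = Set[String]        // pretend a plan is just the set of its filter predicates
  type Rule = Plan => Plan

  sealed trait Strategy
  case object Once extends Strategy
  final case class FixedPoint(maxIterations: Int) extends Strategy

  final case class Batch(name: String, strategy: Strategy, rules: Rule*)

  // Apply each batch in order: a FixedPoint batch loops until the plan stops changing
  // or the iteration budget is exhausted; a Once batch applies its rules a single time.
  def execute(batches: Seq[Batch])(plan: Plan): Plan =
    batches.foldLeft(plan) { (current, batch) =>
      batch.strategy match {
        case Once =>
          batch.rules.foldLeft(current)((p, rule) => rule(p))
        case FixedPoint(maxIterations) =>
          var p = current
          var iteration = 0
          var changed = true
          while (changed && iteration < maxIterations) {
            val next = batch.rules.foldLeft(p)((acc, rule) => rule(acc))
            changed = next != p
            p = next
            iteration += 1
          }
          p
      }
    }

  def main(args: Array[String]): Unit = {
    // Toy rules: one infers a new filter from an equality constraint, the other prunes
    // trivially true filters. Keeping them in separate batches avoids any ping-pong.
    val inferFiltersFromConstraints: Rule = p =>
      if (p.contains("a = b") && p.contains("a > 1")) p + "b > 1" else p
    val pruneFilters: Rule = p => p - "1 = 1"

    val fixedPoint = FixedPoint(100)
    val batches = Seq(
      Batch("Operator Optimization before Inferring Filters", fixedPoint, pruneFilters),
      Batch("Infer Filters", Once, inferFiltersFromConstraints),
      Batch("Operator Optimization after Inferring Filters", fixedPoint, pruneFilters))

    // Prints something like: Set(a = b, a > 1, b > 1) (element order may vary)
    println(execute(batches)(Set("a = b", "a > 1", "1 = 1")))
  }
}
```

Because a `Once` batch cannot loop, the inference step and the filter-reduction rules can no longer keep generating work for each other, so the fixed-point iteration budget is spent only on batches that are expected to converge.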
--- .../sql/catalyst/optimizer/Optimizer.scala | 115 ++++++++++-------- 1 file changed, 64 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5acadf8cf330e..6a4d1e997c3c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -47,7 +47,62 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) def batches: Seq[Batch] = { - Batch("Eliminate Distinct", Once, EliminateDistinct) :: + val operatorOptimizationRuleSet = + Seq( + // Operator push down + PushProjectionThroughUnion, + ReorderJoin, + EliminateOuterJoin, + PushPredicateThroughJoin, + PushDownPredicate, + LimitPushDown, + ColumnPruning, + InferFiltersFromConstraints, + // Operator combine + CollapseRepartition, + CollapseProject, + CollapseWindow, + CombineFilters, + CombineLimits, + CombineUnions, + // Constant folding and strength reduction + NullPropagation, + ConstantPropagation, + FoldablePropagation, + OptimizeIn, + ConstantFolding, + ReorderAssociativeOperator, + LikeSimplification, + BooleanSimplification, + SimplifyConditionals, + RemoveDispensableExpressions, + SimplifyBinaryComparison, + PruneFilters, + EliminateSorts, + SimplifyCasts, + SimplifyCaseConversionExpressions, + RewriteCorrelatedScalarSubquery, + EliminateSerialization, + RemoveRedundantAliases, + RemoveRedundantProject, + SimplifyCreateStructOps, + SimplifyCreateArrayOps, + SimplifyCreateMapOps, + CombineConcats) ++ + extendedOperatorOptimizationRules + + val operatorOptimizationBatch: Seq[Batch] = { + val rulesWithoutInferFiltersFromConstraints = + operatorOptimizationRuleSet.filterNot(_ == InferFiltersFromConstraints) + Batch("Operator Optimization before Inferring Filters", fixedPoint, + rulesWithoutInferFiltersFromConstraints: _*) :: + Batch("Infer Filters", Once, + InferFiltersFromConstraints) :: + Batch("Operator Optimization after Inferring Filters", fixedPoint, + rulesWithoutInferFiltersFromConstraints: _*) :: Nil + } + + (Batch("Eliminate Distinct", Once, EliminateDistinct) :: // Technically some of the rules in Finish Analysis are not optimizer rules and belong more // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). 
// However, because we also use the analyzer to canonicalized queries (for view definition), @@ -81,68 +136,26 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReplaceDistinctWithAggregate) :: Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions, - RemoveRepetitionFromGroupExpressions) :: - Batch("Operator Optimizations", fixedPoint, Seq( - // Operator push down - PushProjectionThroughUnion, - ReorderJoin, - EliminateOuterJoin, - InferFiltersFromConstraints, - BooleanSimplification, - PushPredicateThroughJoin, - PushDownPredicate, - LimitPushDown, - ColumnPruning, - // Operator combine - CollapseRepartition, - CollapseProject, - CollapseWindow, - CombineFilters, - CombineLimits, - CombineUnions, - // Constant folding and strength reduction - NullPropagation, - ConstantPropagation, - FoldablePropagation, - OptimizeIn, - ConstantFolding, - ReorderAssociativeOperator, - LikeSimplification, - BooleanSimplification, - SimplifyConditionals, - RemoveDispensableExpressions, - SimplifyBinaryComparison, - PruneFilters, - EliminateSorts, - SimplifyCasts, - SimplifyCaseConversionExpressions, - RewriteCorrelatedScalarSubquery, - EliminateSerialization, - RemoveRedundantAliases, - RemoveRedundantProject, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps, - CombineConcats) ++ - extendedOperatorOptimizationRules: _*) :: + RemoveRepetitionFromGroupExpressions) :: Nil ++ + operatorOptimizationBatch) :+ Batch("Join Reorder", Once, - CostBasedJoinReorder) :: + CostBasedJoinReorder) :+ Batch("Decimal Optimizations", fixedPoint, - DecimalAggregates) :: + DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, EliminateMapObjects, - CombineTypedFilters) :: + CombineTypedFilters) :+ Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, - PropagateEmptyRelation) :: + PropagateEmptyRelation) :+ // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, - CheckCartesianProducts) :: + CheckCartesianProducts) :+ Batch("RewriteSubquery", Once, RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) :: Nil + RemoveRedundantProject) } /** From 6129ffa11ea62437a25844455e87a1e4c21b030f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 19 Dec 2017 11:56:22 -0800 Subject: [PATCH 456/712] [SPARK-22821][TEST] Basic tests for WidenSetOperationTypes, BooleanEquality, StackCoercion and Division ## What changes were proposed in this pull request? Test Coverage for `WidenSetOperationTypes`, `BooleanEquality`, `StackCoercion` and `Division`, this is a Sub-tasks for [SPARK-22722](https://issues.apache.org/jira/browse/SPARK-22722). ## How was this patch tested? N/A Author: Yuming Wang Closes #20006 from wangyum/SPARK-22821. 
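The golden files added below are generated and checked by `SQLQueryTestSuite`, but the same behavior is easy to spot-check interactively. The following is a small, hypothetical driver (it assumes a local `SparkSession` and is not part of this patch) that runs a few of the new queries and prints the result type each coercion rule settles on.

```scala
// Hypothetical spot-check driver (not part of this patch); it assumes a local Spark build
// and prints the coerced result type and value for a few of the new test queries.
import org.apache.spark.sql.SparkSession

object TypeCoercionSpotCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("type-coercion-spot-check")
      .getOrCreate()

    // The temporary view used throughout the .sql files.
    spark.sql("CREATE TEMPORARY VIEW t AS SELECT 1")

    val queries = Seq(
      "SELECT true = cast(1 as tinyint) FROM t",                        // BooleanEquality
      "SELECT cast(1 as int) / cast(1 as bigint) FROM t",               // Division
      "SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t"  // WidenSetOperationTypes
    )

    queries.foreach { q =>
      val df = spark.sql(q)
      // schema.simpleString shows the type the coercion rules resolved for the expression.
      println(s"$q\n  => ${df.schema.simpleString}: ${df.collect().mkString(", ")}\n")
    }

    spark.stop()
  }
}
```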
--- .../typeCoercion/native/booleanEquality.sql | 122 ++ .../inputs/typeCoercion/native/division.sql | 174 +++ .../native/widenSetOperationTypes.sql | 175 +++ .../native/booleanEquality.sql.out | 802 ++++++++++ .../typeCoercion/native/division.sql.out | 1242 ++++++++++++++++ .../native/widenSetOperationTypes.sql.out | 1305 +++++++++++++++++ 6 files changed, 3820 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/booleanEquality.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/division.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/widenSetOperationTypes.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/booleanEquality.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/division.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/booleanEquality.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/booleanEquality.sql new file mode 100644 index 0000000000000..442f2355f8e3a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/booleanEquality.sql @@ -0,0 +1,122 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT true = cast(1 as tinyint) FROM t; +SELECT true = cast(1 as smallint) FROM t; +SELECT true = cast(1 as int) FROM t; +SELECT true = cast(1 as bigint) FROM t; +SELECT true = cast(1 as float) FROM t; +SELECT true = cast(1 as double) FROM t; +SELECT true = cast(1 as decimal(10, 0)) FROM t; +SELECT true = cast(1 as string) FROM t; +SELECT true = cast('1' as binary) FROM t; +SELECT true = cast(1 as boolean) FROM t; +SELECT true = cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT true = cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT true <=> cast(1 as tinyint) FROM t; +SELECT true <=> cast(1 as smallint) FROM t; +SELECT true <=> cast(1 as int) FROM t; +SELECT true <=> cast(1 as bigint) FROM t; +SELECT true <=> cast(1 as float) FROM t; +SELECT true <=> cast(1 as double) FROM t; +SELECT true <=> cast(1 as decimal(10, 0)) FROM t; +SELECT true <=> cast(1 as string) FROM t; +SELECT true <=> cast('1' as binary) FROM t; +SELECT true <=> cast(1 as boolean) FROM t; +SELECT true <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT true <=> cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) = true FROM t; +SELECT cast(1 as smallint) = true FROM t; +SELECT cast(1 as int) = true FROM t; +SELECT cast(1 as bigint) = true FROM t; +SELECT cast(1 as float) = true FROM t; +SELECT cast(1 as double) = true FROM t; +SELECT cast(1 as decimal(10, 0)) = true FROM t; +SELECT cast(1 as string) = true FROM t; +SELECT cast('1' as binary) = true FROM t; +SELECT cast(1 as boolean) = true FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = true FROM t; +SELECT cast('2017-12-11 09:30:00' as date) = true FROM t; + +SELECT cast(1 as tinyint) <=> true FROM t; +SELECT cast(1 as smallint) <=> true FROM t; +SELECT cast(1 as int) <=> true FROM t; +SELECT cast(1 as bigint) <=> true FROM t; +SELECT cast(1 as float) <=> true FROM t; +SELECT cast(1 as double) <=> true FROM t; +SELECT cast(1 as decimal(10, 0)) <=> true FROM t; +SELECT cast(1 as string) <=> true FROM t; +SELECT cast('1' as binary) <=> true FROM t; +SELECT cast(1 as boolean) <=> true FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> true FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <=> true FROM t; + +SELECT false = cast(0 as tinyint) FROM t; +SELECT false = cast(0 as smallint) FROM t; +SELECT false = cast(0 as int) FROM t; +SELECT false = cast(0 as bigint) FROM t; +SELECT false = cast(0 as float) FROM t; +SELECT false = cast(0 as double) FROM t; +SELECT false = cast(0 as decimal(10, 0)) FROM t; +SELECT false = cast(0 as string) FROM t; +SELECT false = cast('0' as binary) FROM t; +SELECT false = cast(0 as boolean) FROM t; +SELECT false = cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT false = cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT false <=> cast(0 as tinyint) FROM t; +SELECT false <=> cast(0 as smallint) FROM t; +SELECT false <=> cast(0 as int) FROM t; +SELECT false <=> cast(0 as bigint) FROM t; +SELECT false <=> cast(0 as float) FROM t; +SELECT false <=> cast(0 as double) FROM t; +SELECT false <=> cast(0 as decimal(10, 0)) FROM t; +SELECT false <=> cast(0 as string) FROM t; +SELECT false <=> cast('0' as binary) FROM t; +SELECT false <=> cast(0 as boolean) FROM t; +SELECT false <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT false <=> cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(0 as tinyint) = false FROM t; +SELECT cast(0 as smallint) = false FROM t; +SELECT cast(0 as int) = false FROM t; 
+SELECT cast(0 as bigint) = false FROM t; +SELECT cast(0 as float) = false FROM t; +SELECT cast(0 as double) = false FROM t; +SELECT cast(0 as decimal(10, 0)) = false FROM t; +SELECT cast(0 as string) = false FROM t; +SELECT cast('0' as binary) = false FROM t; +SELECT cast(0 as boolean) = false FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = false FROM t; +SELECT cast('2017-12-11 09:30:00' as date) = false FROM t; + +SELECT cast(0 as tinyint) <=> false FROM t; +SELECT cast(0 as smallint) <=> false FROM t; +SELECT cast(0 as int) <=> false FROM t; +SELECT cast(0 as bigint) <=> false FROM t; +SELECT cast(0 as float) <=> false FROM t; +SELECT cast(0 as double) <=> false FROM t; +SELECT cast(0 as decimal(10, 0)) <=> false FROM t; +SELECT cast(0 as string) <=> false FROM t; +SELECT cast('0' as binary) <=> false FROM t; +SELECT cast(0 as boolean) <=> false FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> false FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <=> false FROM t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/division.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/division.sql new file mode 100644 index 0000000000000..d669740ddd9ca --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/division.sql @@ -0,0 +1,174 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT cast(1 as tinyint) / cast(1 as tinyint) FROM t; +SELECT cast(1 as tinyint) / cast(1 as smallint) FROM t; +SELECT cast(1 as tinyint) / cast(1 as int) FROM t; +SELECT cast(1 as tinyint) / cast(1 as bigint) FROM t; +SELECT cast(1 as tinyint) / cast(1 as float) FROM t; +SELECT cast(1 as tinyint) / cast(1 as double) FROM t; +SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) / cast(1 as string) FROM t; +SELECT cast(1 as tinyint) / cast('1' as binary) FROM t; +SELECT cast(1 as tinyint) / cast(1 as boolean) FROM t; +SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as smallint) / cast(1 as tinyint) FROM t; +SELECT cast(1 as smallint) / cast(1 as smallint) FROM t; +SELECT cast(1 as smallint) / cast(1 as int) FROM t; +SELECT cast(1 as smallint) / cast(1 as bigint) FROM t; +SELECT cast(1 as smallint) / cast(1 as float) FROM t; +SELECT cast(1 as smallint) / cast(1 as double) FROM t; +SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) / cast(1 as string) FROM t; +SELECT cast(1 as smallint) / cast('1' as binary) FROM t; +SELECT cast(1 as smallint) / cast(1 as boolean) FROM t; +SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as int) / cast(1 as tinyint) FROM t; +SELECT cast(1 as int) / cast(1 as smallint) FROM t; +SELECT cast(1 as int) / cast(1 as int) FROM t; +SELECT cast(1 as int) / cast(1 as bigint) FROM t; +SELECT cast(1 as int) / cast(1 as float) FROM t; +SELECT cast(1 as int) / cast(1 as double) FROM t; +SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) / cast(1 as string) FROM t; +SELECT cast(1 as int) / cast('1' as binary) FROM t; +SELECT cast(1 as int) / cast(1 as boolean) FROM t; +SELECT cast(1 as int) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as int) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as bigint) / cast(1 as tinyint) FROM t; +SELECT cast(1 as bigint) / cast(1 as smallint) FROM t; +SELECT cast(1 as bigint) / cast(1 as int) FROM t; +SELECT cast(1 as bigint) / cast(1 as bigint) FROM t; +SELECT cast(1 as bigint) / cast(1 as float) FROM t; +SELECT cast(1 as bigint) / cast(1 as double) FROM t; +SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) / cast(1 as string) FROM t; +SELECT cast(1 as bigint) / cast('1' as binary) FROM t; +SELECT cast(1 as bigint) / cast(1 as boolean) FROM t; +SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as float) / cast(1 as tinyint) FROM t; +SELECT cast(1 as float) / cast(1 as smallint) FROM t; +SELECT cast(1 as float) / cast(1 as int) FROM t; +SELECT cast(1 as float) / cast(1 as bigint) FROM t; +SELECT cast(1 as float) / cast(1 as float) FROM t; +SELECT cast(1 as float) / cast(1 as double) FROM t; +SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) / cast(1 as string) FROM t; +SELECT cast(1 as float) / cast('1' as binary) FROM t; +SELECT cast(1 as float) / cast(1 as boolean) FROM t; +SELECT cast(1 as float) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as float) / cast('2017-12-11 09:30:00' as date) FROM t; + 
+SELECT cast(1 as double) / cast(1 as tinyint) FROM t; +SELECT cast(1 as double) / cast(1 as smallint) FROM t; +SELECT cast(1 as double) / cast(1 as int) FROM t; +SELECT cast(1 as double) / cast(1 as bigint) FROM t; +SELECT cast(1 as double) / cast(1 as float) FROM t; +SELECT cast(1 as double) / cast(1 as double) FROM t; +SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) / cast(1 as string) FROM t; +SELECT cast(1 as double) / cast('1' as binary) FROM t; +SELECT cast(1 as double) / cast(1 as boolean) FROM t; +SELECT cast(1 as double) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as double) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as string) / cast(1 as tinyint) FROM t; +SELECT cast(1 as string) / cast(1 as smallint) FROM t; +SELECT cast(1 as string) / cast(1 as int) FROM t; +SELECT cast(1 as string) / cast(1 as bigint) FROM t; +SELECT cast(1 as string) / cast(1 as float) FROM t; +SELECT cast(1 as string) / cast(1 as double) FROM t; +SELECT cast(1 as string) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as string) / cast(1 as string) FROM t; +SELECT cast(1 as string) / cast('1' as binary) FROM t; +SELECT cast(1 as string) / cast(1 as boolean) FROM t; +SELECT cast(1 as string) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as string) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast('1' as binary) / cast(1 as tinyint) FROM t; +SELECT cast('1' as binary) / cast(1 as smallint) FROM t; +SELECT cast('1' as binary) / cast(1 as int) FROM t; +SELECT cast('1' as binary) / cast(1 as bigint) FROM t; +SELECT cast('1' as binary) / cast(1 as float) FROM t; +SELECT cast('1' as binary) / cast(1 as double) FROM t; +SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) / cast(1 as string) FROM t; +SELECT cast('1' as binary) / cast('1' as binary) FROM t; +SELECT cast('1' as binary) / cast(1 as boolean) FROM t; +SELECT cast('1' as binary) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast('1' as binary) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as boolean) / cast(1 as tinyint) FROM t; +SELECT cast(1 as boolean) / cast(1 as smallint) FROM t; +SELECT cast(1 as boolean) / cast(1 as int) FROM t; +SELECT cast(1 as boolean) / cast(1 as bigint) FROM t; +SELECT cast(1 as boolean) / cast(1 as float) FROM t; +SELECT cast(1 as boolean) / cast(1 as double) FROM t; +SELECT cast(1 as boolean) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as boolean) / cast(1 as string) FROM t; +SELECT cast(1 as boolean) / cast('1' as binary) FROM t; +SELECT cast(1 as boolean) / cast(1 as boolean) FROM t; +SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00.0' as 
timestamp) FROM t; +SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as tinyint) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as smallint) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as int) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as bigint) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as float) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as double) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as string) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('1' as binary) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as boolean) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as tinyint) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as smallint) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as int) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as bigint) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as float) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as double) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as string) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast('1' as binary) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as boolean) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00' as date) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/widenSetOperationTypes.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/widenSetOperationTypes.sql new file mode 100644 index 0000000000000..66e9689850d93 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/widenSetOperationTypes.sql @@ -0,0 +1,175 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +-- UNION +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT 
cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as 
string) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as string) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as tinyint) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 
as tinyint) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as smallint) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as int) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as bigint) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as float) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as double) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as string) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2' as binary) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as boolean) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/booleanEquality.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/booleanEquality.sql.out new file mode 100644 index 0000000000000..46775d79ff4a2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/booleanEquality.sql.out @@ -0,0 +1,802 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 97 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT true = cast(1 as tinyint) FROM t +-- !query 1 schema +struct<(CAST(true AS TINYINT) = CAST(1 AS TINYINT)):boolean> +-- !query 1 output +true + + +-- !query 2 +SELECT true = cast(1 as smallint) FROM t +-- !query 2 schema +struct<(CAST(true AS SMALLINT) = CAST(1 AS SMALLINT)):boolean> +-- !query 2 output +true + + +-- !query 3 +SELECT true = cast(1 as int) FROM t +-- !query 3 schema +struct<(CAST(true AS INT) = CAST(1 AS INT)):boolean> +-- !query 3 output +true + + +-- !query 4 +SELECT true = cast(1 as bigint) FROM t +-- !query 4 schema +struct<(CAST(true AS BIGINT) = CAST(1 AS BIGINT)):boolean> +-- !query 4 output +true + + +-- !query 5 +SELECT true = cast(1 as float) FROM t +-- !query 5 schema +struct<(CAST(true AS FLOAT) = CAST(1 AS FLOAT)):boolean> +-- !query 5 output +true + + +-- !query 6 +SELECT true = cast(1 as double) FROM t +-- !query 6 schema +struct<(CAST(true AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 6 output +true + + +-- !query 7 +SELECT true = cast(1 as decimal(10, 0)) FROM t +-- !query 7 schema +struct<(CAST(true AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 7 output +true + + +-- !query 8 +SELECT true = cast(1 as string) FROM t +-- !query 8 schema +struct<(true = CAST(CAST(1 AS STRING) AS BOOLEAN)):boolean> +-- !query 8 output +true + + +-- !query 9 +SELECT true = cast('1' as binary) FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(true = CAST('1' AS BINARY))' (boolean and binary).; line 1 pos 7 + + +-- !query 10 +SELECT true = cast(1 as boolean) FROM t +-- !query 10 schema +struct<(true = CAST(1 AS BOOLEAN)):boolean> +-- !query 10 output +true + + +-- !query 11 +SELECT true = cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 11 schema +struct<> +-- !query 11 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(true = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(true = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7 + + +-- !query 12 +SELECT true = cast('2017-12-11 09:30:00' as date) FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(true = CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7 + + +-- !query 13 +SELECT true <=> cast(1 as tinyint) FROM t +-- !query 13 schema +struct<(CAST(true AS TINYINT) <=> CAST(1 AS TINYINT)):boolean> +-- !query 13 output +true + + +-- !query 14 +SELECT true <=> cast(1 as smallint) FROM t +-- !query 14 schema +struct<(CAST(true AS SMALLINT) <=> CAST(1 AS SMALLINT)):boolean> +-- !query 14 output +true + + +-- !query 15 +SELECT true <=> cast(1 as int) FROM t +-- !query 15 schema +struct<(CAST(true AS INT) <=> CAST(1 AS INT)):boolean> +-- !query 15 output +true + + +-- !query 16 +SELECT true <=> cast(1 as bigint) FROM t +-- !query 16 schema +struct<(CAST(true AS BIGINT) <=> CAST(1 AS BIGINT)):boolean> +-- !query 16 output +true + + +-- !query 17 +SELECT true <=> cast(1 as float) FROM t +-- !query 17 schema +struct<(CAST(true AS FLOAT) <=> CAST(1 AS FLOAT)):boolean> +-- !query 17 output +true + + +-- !query 18 +SELECT true <=> cast(1 as double) FROM t +-- !query 18 schema +struct<(CAST(true AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 18 output +true + + +-- !query 19 +SELECT true <=> cast(1 as decimal(10, 0)) FROM t +-- !query 19 schema +struct<(CAST(true AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 19 output +true + + +-- !query 20 +SELECT true <=> cast(1 as string) FROM t +-- !query 20 schema +struct<(true <=> CAST(CAST(1 AS STRING) AS BOOLEAN)):boolean> +-- !query 20 output +true + + +-- !query 21 +SELECT true <=> cast('1' as binary) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(true <=> CAST('1' AS BINARY))' (boolean and binary).; line 1 pos 7 + + +-- !query 22 +SELECT true <=> cast(1 as boolean) FROM t +-- !query 22 schema +struct<(true <=> CAST(1 AS BOOLEAN)):boolean> +-- !query 22 output +true + + +-- !query 23 +SELECT true <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(true <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7 + + +-- !query 24 +SELECT true <=> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(true <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(true <=> CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7 + + +-- !query 25 +SELECT cast(1 as tinyint) = true FROM t +-- !query 25 schema +struct<(CAST(1 AS TINYINT) = CAST(true AS TINYINT)):boolean> +-- !query 25 output +true + + +-- !query 26 +SELECT cast(1 as smallint) = true FROM t +-- !query 26 schema +struct<(CAST(1 AS SMALLINT) = CAST(true AS SMALLINT)):boolean> +-- !query 26 output 
+true + + +-- !query 27 +SELECT cast(1 as int) = true FROM t +-- !query 27 schema +struct<(CAST(1 AS INT) = CAST(true AS INT)):boolean> +-- !query 27 output +true + + +-- !query 28 +SELECT cast(1 as bigint) = true FROM t +-- !query 28 schema +struct<(CAST(1 AS BIGINT) = CAST(true AS BIGINT)):boolean> +-- !query 28 output +true + + +-- !query 29 +SELECT cast(1 as float) = true FROM t +-- !query 29 schema +struct<(CAST(1 AS FLOAT) = CAST(true AS FLOAT)):boolean> +-- !query 29 output +true + + +-- !query 30 +SELECT cast(1 as double) = true FROM t +-- !query 30 schema +struct<(CAST(1 AS DOUBLE) = CAST(true AS DOUBLE)):boolean> +-- !query 30 output +true + + +-- !query 31 +SELECT cast(1 as decimal(10, 0)) = true FROM t +-- !query 31 schema +struct<(CAST(1 AS DECIMAL(10,0)) = CAST(true AS DECIMAL(10,0))):boolean> +-- !query 31 output +true + + +-- !query 32 +SELECT cast(1 as string) = true FROM t +-- !query 32 schema +struct<(CAST(CAST(1 AS STRING) AS BOOLEAN) = true):boolean> +-- !query 32 output +true + + +-- !query 33 +SELECT cast('1' as binary) = true FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = true)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = true)' (binary and boolean).; line 1 pos 7 + + +-- !query 34 +SELECT cast(1 as boolean) = true FROM t +-- !query 34 schema +struct<(CAST(1 AS BOOLEAN) = true):boolean> +-- !query 34 output +true + + +-- !query 35 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = true FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = true)' (timestamp and boolean).; line 1 pos 7 + + +-- !query 36 +SELECT cast('2017-12-11 09:30:00' as date) = true FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = true)' (date and boolean).; line 1 pos 7 + + +-- !query 37 +SELECT cast(1 as tinyint) <=> true FROM t +-- !query 37 schema +struct<(CAST(1 AS TINYINT) <=> CAST(true AS TINYINT)):boolean> +-- !query 37 output +true + + +-- !query 38 +SELECT cast(1 as smallint) <=> true FROM t +-- !query 38 schema +struct<(CAST(1 AS SMALLINT) <=> CAST(true AS SMALLINT)):boolean> +-- !query 38 output +true + + +-- !query 39 +SELECT cast(1 as int) <=> true FROM t +-- !query 39 schema +struct<(CAST(1 AS INT) <=> CAST(true AS INT)):boolean> +-- !query 39 output +true + + +-- !query 40 +SELECT cast(1 as bigint) <=> true FROM t +-- !query 40 schema +struct<(CAST(1 AS BIGINT) <=> CAST(true AS BIGINT)):boolean> +-- !query 40 output +true + + +-- !query 41 +SELECT cast(1 as float) <=> true FROM t +-- !query 41 schema +struct<(CAST(1 AS FLOAT) <=> CAST(true AS FLOAT)):boolean> +-- !query 41 output +true + + +-- !query 42 +SELECT cast(1 as double) <=> true FROM t +-- !query 42 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(true AS DOUBLE)):boolean> +-- !query 42 output +true + + +-- !query 43 +SELECT cast(1 as decimal(10, 0)) <=> true FROM t +-- !query 43 schema +struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(true AS DECIMAL(10,0))):boolean> +-- !query 43 output +true + + +-- !query 44 +SELECT cast(1 as string) <=> true FROM t +-- !query 44 schema +struct<(CAST(CAST(1 AS STRING) AS BOOLEAN) 
<=> true):boolean> +-- !query 44 output +true + + +-- !query 45 +SELECT cast('1' as binary) <=> true FROM t +-- !query 45 schema +struct<> +-- !query 45 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <=> true)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> true)' (binary and boolean).; line 1 pos 7 + + +-- !query 46 +SELECT cast(1 as boolean) <=> true FROM t +-- !query 46 schema +struct<(CAST(1 AS BOOLEAN) <=> true):boolean> +-- !query 46 output +true + + +-- !query 47 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> true FROM t +-- !query 47 schema +struct<> +-- !query 47 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> true)' (timestamp and boolean).; line 1 pos 7 + + +-- !query 48 +SELECT cast('2017-12-11 09:30:00' as date) <=> true FROM t +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> true)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> true)' (date and boolean).; line 1 pos 7 + + +-- !query 49 +SELECT false = cast(0 as tinyint) FROM t +-- !query 49 schema +struct<(CAST(false AS TINYINT) = CAST(0 AS TINYINT)):boolean> +-- !query 49 output +true + + +-- !query 50 +SELECT false = cast(0 as smallint) FROM t +-- !query 50 schema +struct<(CAST(false AS SMALLINT) = CAST(0 AS SMALLINT)):boolean> +-- !query 50 output +true + + +-- !query 51 +SELECT false = cast(0 as int) FROM t +-- !query 51 schema +struct<(CAST(false AS INT) = CAST(0 AS INT)):boolean> +-- !query 51 output +true + + +-- !query 52 +SELECT false = cast(0 as bigint) FROM t +-- !query 52 schema +struct<(CAST(false AS BIGINT) = CAST(0 AS BIGINT)):boolean> +-- !query 52 output +true + + +-- !query 53 +SELECT false = cast(0 as float) FROM t +-- !query 53 schema +struct<(CAST(false AS FLOAT) = CAST(0 AS FLOAT)):boolean> +-- !query 53 output +true + + +-- !query 54 +SELECT false = cast(0 as double) FROM t +-- !query 54 schema +struct<(CAST(false AS DOUBLE) = CAST(0 AS DOUBLE)):boolean> +-- !query 54 output +true + + +-- !query 55 +SELECT false = cast(0 as decimal(10, 0)) FROM t +-- !query 55 schema +struct<(CAST(false AS DECIMAL(10,0)) = CAST(0 AS DECIMAL(10,0))):boolean> +-- !query 55 output +true + + +-- !query 56 +SELECT false = cast(0 as string) FROM t +-- !query 56 schema +struct<(false = CAST(CAST(0 AS STRING) AS BOOLEAN)):boolean> +-- !query 56 output +true + + +-- !query 57 +SELECT false = cast('0' as binary) FROM t +-- !query 57 schema +struct<> +-- !query 57 output +org.apache.spark.sql.AnalysisException +cannot resolve '(false = CAST('0' AS BINARY))' due to data type mismatch: differing types in '(false = CAST('0' AS BINARY))' (boolean and binary).; line 1 pos 7 + + +-- !query 58 +SELECT false = cast(0 as boolean) FROM t +-- !query 58 schema +struct<(false = CAST(0 AS BOOLEAN)):boolean> +-- !query 58 output +true + + +-- !query 59 +SELECT false = cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +cannot resolve '(false = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(false = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7 + + +-- !query 60 +SELECT false = cast('2017-12-11 09:30:00' 
as date) FROM t +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +cannot resolve '(false = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(false = CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7 + + +-- !query 61 +SELECT false <=> cast(0 as tinyint) FROM t +-- !query 61 schema +struct<(CAST(false AS TINYINT) <=> CAST(0 AS TINYINT)):boolean> +-- !query 61 output +true + + +-- !query 62 +SELECT false <=> cast(0 as smallint) FROM t +-- !query 62 schema +struct<(CAST(false AS SMALLINT) <=> CAST(0 AS SMALLINT)):boolean> +-- !query 62 output +true + + +-- !query 63 +SELECT false <=> cast(0 as int) FROM t +-- !query 63 schema +struct<(CAST(false AS INT) <=> CAST(0 AS INT)):boolean> +-- !query 63 output +true + + +-- !query 64 +SELECT false <=> cast(0 as bigint) FROM t +-- !query 64 schema +struct<(CAST(false AS BIGINT) <=> CAST(0 AS BIGINT)):boolean> +-- !query 64 output +true + + +-- !query 65 +SELECT false <=> cast(0 as float) FROM t +-- !query 65 schema +struct<(CAST(false AS FLOAT) <=> CAST(0 AS FLOAT)):boolean> +-- !query 65 output +true + + +-- !query 66 +SELECT false <=> cast(0 as double) FROM t +-- !query 66 schema +struct<(CAST(false AS DOUBLE) <=> CAST(0 AS DOUBLE)):boolean> +-- !query 66 output +true + + +-- !query 67 +SELECT false <=> cast(0 as decimal(10, 0)) FROM t +-- !query 67 schema +struct<(CAST(false AS DECIMAL(10,0)) <=> CAST(0 AS DECIMAL(10,0))):boolean> +-- !query 67 output +true + + +-- !query 68 +SELECT false <=> cast(0 as string) FROM t +-- !query 68 schema +struct<(false <=> CAST(CAST(0 AS STRING) AS BOOLEAN)):boolean> +-- !query 68 output +true + + +-- !query 69 +SELECT false <=> cast('0' as binary) FROM t +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.sql.AnalysisException +cannot resolve '(false <=> CAST('0' AS BINARY))' due to data type mismatch: differing types in '(false <=> CAST('0' AS BINARY))' (boolean and binary).; line 1 pos 7 + + +-- !query 70 +SELECT false <=> cast(0 as boolean) FROM t +-- !query 70 schema +struct<(false <=> CAST(0 AS BOOLEAN)):boolean> +-- !query 70 output +true + + +-- !query 71 +SELECT false <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.sql.AnalysisException +cannot resolve '(false <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(false <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7 + + +-- !query 72 +SELECT false <=> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 72 schema +struct<> +-- !query 72 output +org.apache.spark.sql.AnalysisException +cannot resolve '(false <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(false <=> CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7 + + +-- !query 73 +SELECT cast(0 as tinyint) = false FROM t +-- !query 73 schema +struct<(CAST(0 AS TINYINT) = CAST(false AS TINYINT)):boolean> +-- !query 73 output +true + + +-- !query 74 +SELECT cast(0 as smallint) = false FROM t +-- !query 74 schema +struct<(CAST(0 AS SMALLINT) = CAST(false AS SMALLINT)):boolean> +-- !query 74 output +true + + +-- !query 75 +SELECT cast(0 as int) = false FROM t +-- !query 75 schema +struct<(CAST(0 AS INT) = CAST(false AS INT)):boolean> +-- !query 75 output +true + + +-- !query 76 +SELECT cast(0 as bigint) = false FROM t +-- !query 76 schema +struct<(CAST(0 AS BIGINT) = 
CAST(false AS BIGINT)):boolean> +-- !query 76 output +true + + +-- !query 77 +SELECT cast(0 as float) = false FROM t +-- !query 77 schema +struct<(CAST(0 AS FLOAT) = CAST(false AS FLOAT)):boolean> +-- !query 77 output +true + + +-- !query 78 +SELECT cast(0 as double) = false FROM t +-- !query 78 schema +struct<(CAST(0 AS DOUBLE) = CAST(false AS DOUBLE)):boolean> +-- !query 78 output +true + + +-- !query 79 +SELECT cast(0 as decimal(10, 0)) = false FROM t +-- !query 79 schema +struct<(CAST(0 AS DECIMAL(10,0)) = CAST(false AS DECIMAL(10,0))):boolean> +-- !query 79 output +true + + +-- !query 80 +SELECT cast(0 as string) = false FROM t +-- !query 80 schema +struct<(CAST(CAST(0 AS STRING) AS BOOLEAN) = false):boolean> +-- !query 80 output +true + + +-- !query 81 +SELECT cast('0' as binary) = false FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('0' AS BINARY) = false)' due to data type mismatch: differing types in '(CAST('0' AS BINARY) = false)' (binary and boolean).; line 1 pos 7 + + +-- !query 82 +SELECT cast(0 as boolean) = false FROM t +-- !query 82 schema +struct<(CAST(0 AS BOOLEAN) = false):boolean> +-- !query 82 output +true + + +-- !query 83 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = false FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = false)' (timestamp and boolean).; line 1 pos 7 + + +-- !query 84 +SELECT cast('2017-12-11 09:30:00' as date) = false FROM t +-- !query 84 schema +struct<> +-- !query 84 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = false)' (date and boolean).; line 1 pos 7 + + +-- !query 85 +SELECT cast(0 as tinyint) <=> false FROM t +-- !query 85 schema +struct<(CAST(0 AS TINYINT) <=> CAST(false AS TINYINT)):boolean> +-- !query 85 output +true + + +-- !query 86 +SELECT cast(0 as smallint) <=> false FROM t +-- !query 86 schema +struct<(CAST(0 AS SMALLINT) <=> CAST(false AS SMALLINT)):boolean> +-- !query 86 output +true + + +-- !query 87 +SELECT cast(0 as int) <=> false FROM t +-- !query 87 schema +struct<(CAST(0 AS INT) <=> CAST(false AS INT)):boolean> +-- !query 87 output +true + + +-- !query 88 +SELECT cast(0 as bigint) <=> false FROM t +-- !query 88 schema +struct<(CAST(0 AS BIGINT) <=> CAST(false AS BIGINT)):boolean> +-- !query 88 output +true + + +-- !query 89 +SELECT cast(0 as float) <=> false FROM t +-- !query 89 schema +struct<(CAST(0 AS FLOAT) <=> CAST(false AS FLOAT)):boolean> +-- !query 89 output +true + + +-- !query 90 +SELECT cast(0 as double) <=> false FROM t +-- !query 90 schema +struct<(CAST(0 AS DOUBLE) <=> CAST(false AS DOUBLE)):boolean> +-- !query 90 output +true + + +-- !query 91 +SELECT cast(0 as decimal(10, 0)) <=> false FROM t +-- !query 91 schema +struct<(CAST(0 AS DECIMAL(10,0)) <=> CAST(false AS DECIMAL(10,0))):boolean> +-- !query 91 output +true + + +-- !query 92 +SELECT cast(0 as string) <=> false FROM t +-- !query 92 schema +struct<(CAST(CAST(0 AS STRING) AS BOOLEAN) <=> false):boolean> +-- !query 92 output +true + + +-- !query 93 +SELECT cast('0' as binary) <=> false FROM t +-- !query 93 schema +struct<> +-- !query 93 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('0' AS 
BINARY) <=> false)' due to data type mismatch: differing types in '(CAST('0' AS BINARY) <=> false)' (binary and boolean).; line 1 pos 7 + + +-- !query 94 +SELECT cast(0 as boolean) <=> false FROM t +-- !query 94 schema +struct<(CAST(0 AS BOOLEAN) <=> false):boolean> +-- !query 94 output +true + + +-- !query 95 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> false FROM t +-- !query 95 schema +struct<> +-- !query 95 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> false)' (timestamp and boolean).; line 1 pos 7 + + +-- !query 96 +SELECT cast('2017-12-11 09:30:00' as date) <=> false FROM t +-- !query 96 schema +struct<> +-- !query 96 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> false)' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> false)' (date and boolean).; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/division.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/division.sql.out new file mode 100644 index 0000000000000..017e0fea30e90 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/division.sql.out @@ -0,0 +1,1242 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 145 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT cast(1 as tinyint) / cast(1 as tinyint) FROM t +-- !query 1 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 1 output +1.0 + + +-- !query 2 +SELECT cast(1 as tinyint) / cast(1 as smallint) FROM t +-- !query 2 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 2 output +1.0 + + +-- !query 3 +SELECT cast(1 as tinyint) / cast(1 as int) FROM t +-- !query 3 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 3 output +1.0 + + +-- !query 4 +SELECT cast(1 as tinyint) / cast(1 as bigint) FROM t +-- !query 4 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 4 output +1.0 + + +-- !query 5 +SELECT cast(1 as tinyint) / cast(1 as float) FROM t +-- !query 5 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 5 output +1.0 + + +-- !query 6 +SELECT cast(1 as tinyint) / cast(1 as double) FROM t +-- !query 6 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 6 output +1.0 + + +-- !query 7 +SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t +-- !query 7 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)> +-- !query 7 output +1 + + +-- !query 8 +SELECT cast(1 as tinyint) / cast(1 as string) FROM t +-- !query 8 schema +struct<(CAST(CAST(1 AS TINYINT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double> +-- !query 8 output +1.0 + + +-- !query 9 +SELECT cast(1 as tinyint) / cast('1' as binary) FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) / CAST('1' AS BINARY))' due to data type 
mismatch: differing types in '(CAST(1 AS TINYINT) / CAST('1' AS BINARY))' (tinyint and binary).; line 1 pos 7 + + +-- !query 10 +SELECT cast(1 as tinyint) / cast(1 as boolean) FROM t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) / CAST(1 AS BOOLEAN))' (tinyint and boolean).; line 1 pos 7 + + +-- !query 11 +SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (tinyint and timestamp).; line 1 pos 7 + + +-- !query 12 +SELECT cast(1 as tinyint) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) / CAST('2017-12-11 09:30:00' AS DATE))' (tinyint and date).; line 1 pos 7 + + +-- !query 13 +SELECT cast(1 as smallint) / cast(1 as tinyint) FROM t +-- !query 13 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 13 output +1.0 + + +-- !query 14 +SELECT cast(1 as smallint) / cast(1 as smallint) FROM t +-- !query 14 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 14 output +1.0 + + +-- !query 15 +SELECT cast(1 as smallint) / cast(1 as int) FROM t +-- !query 15 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 15 output +1.0 + + +-- !query 16 +SELECT cast(1 as smallint) / cast(1 as bigint) FROM t +-- !query 16 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 16 output +1.0 + + +-- !query 17 +SELECT cast(1 as smallint) / cast(1 as float) FROM t +-- !query 17 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 17 output +1.0 + + +-- !query 18 +SELECT cast(1 as smallint) / cast(1 as double) FROM t +-- !query 18 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 18 output +1.0 + + +-- !query 19 +SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t +-- !query 19 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)> +-- !query 19 output +1 + + +-- !query 20 +SELECT cast(1 as smallint) / cast(1 as string) FROM t +-- !query 20 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double> +-- !query 20 output +1.0 + + +-- !query 21 +SELECT cast(1 as smallint) / cast('1' as binary) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST('1' AS BINARY))' (smallint and binary).; line 1 pos 7 + + +-- !query 22 +SELECT cast(1 as smallint) / cast(1 as boolean) FROM t +-- !query 22 schema +struct<> +-- !query 22 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST(1 AS BOOLEAN))' (smallint and boolean).; line 1 pos 7 + + +-- !query 23 +SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (smallint and timestamp).; line 1 pos 7 + + +-- !query 24 +SELECT cast(1 as smallint) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) / CAST('2017-12-11 09:30:00' AS DATE))' (smallint and date).; line 1 pos 7 + + +-- !query 25 +SELECT cast(1 as int) / cast(1 as tinyint) FROM t +-- !query 25 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 25 output +1.0 + + +-- !query 26 +SELECT cast(1 as int) / cast(1 as smallint) FROM t +-- !query 26 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 26 output +1.0 + + +-- !query 27 +SELECT cast(1 as int) / cast(1 as int) FROM t +-- !query 27 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 27 output +1.0 + + +-- !query 28 +SELECT cast(1 as int) / cast(1 as bigint) FROM t +-- !query 28 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 28 output +1.0 + + +-- !query 29 +SELECT cast(1 as int) / cast(1 as float) FROM t +-- !query 29 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 29 output +1.0 + + +-- !query 30 +SELECT cast(1 as int) / cast(1 as double) FROM t +-- !query 30 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 30 output +1.0 + + +-- !query 31 +SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t +-- !query 31 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)> +-- !query 31 output +1 + + +-- !query 32 +SELECT cast(1 as int) / cast(1 as string) FROM t +-- !query 32 schema +struct<(CAST(CAST(1 AS INT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double> +-- !query 32 output +1.0 + + +-- !query 33 +SELECT cast(1 as int) / cast('1' as binary) FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST('1' AS BINARY))' (int and binary).; line 1 pos 7 + + +-- !query 34 +SELECT cast(1 as int) / cast(1 as boolean) FROM t +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST(1 AS BOOLEAN))' (int and boolean).; line 1 pos 7 + + +-- !query 35 +SELECT cast(1 as int) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException 
+cannot resolve '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (int and timestamp).; line 1 pos 7 + + +-- !query 36 +SELECT cast(1 as int) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS INT) / CAST('2017-12-11 09:30:00' AS DATE))' (int and date).; line 1 pos 7 + + +-- !query 37 +SELECT cast(1 as bigint) / cast(1 as tinyint) FROM t +-- !query 37 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 37 output +1.0 + + +-- !query 38 +SELECT cast(1 as bigint) / cast(1 as smallint) FROM t +-- !query 38 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 38 output +1.0 + + +-- !query 39 +SELECT cast(1 as bigint) / cast(1 as int) FROM t +-- !query 39 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 39 output +1.0 + + +-- !query 40 +SELECT cast(1 as bigint) / cast(1 as bigint) FROM t +-- !query 40 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 40 output +1.0 + + +-- !query 41 +SELECT cast(1 as bigint) / cast(1 as float) FROM t +-- !query 41 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 41 output +1.0 + + +-- !query 42 +SELECT cast(1 as bigint) / cast(1 as double) FROM t +-- !query 42 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 42 output +1.0 + + +-- !query 43 +SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t +-- !query 43 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,11)> +-- !query 43 output +1 + + +-- !query 44 +SELECT cast(1 as bigint) / cast(1 as string) FROM t +-- !query 44 schema +struct<(CAST(CAST(1 AS BIGINT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double> +-- !query 44 output +1.0 + + +-- !query 45 +SELECT cast(1 as bigint) / cast('1' as binary) FROM t +-- !query 45 schema +struct<> +-- !query 45 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST('1' AS BINARY))' (bigint and binary).; line 1 pos 7 + + +-- !query 46 +SELECT cast(1 as bigint) / cast(1 as boolean) FROM t +-- !query 46 schema +struct<> +-- !query 46 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST(1 AS BOOLEAN))' (bigint and boolean).; line 1 pos 7 + + +-- !query 47 +SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 47 schema +struct<> +-- !query 47 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (bigint and timestamp).; line 1 pos 7 + + +-- !query 48 +SELECT cast(1 as bigint) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 48 
schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) / CAST('2017-12-11 09:30:00' AS DATE))' (bigint and date).; line 1 pos 7 + + +-- !query 49 +SELECT cast(1 as float) / cast(1 as tinyint) FROM t +-- !query 49 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 49 output +1.0 + + +-- !query 50 +SELECT cast(1 as float) / cast(1 as smallint) FROM t +-- !query 50 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 50 output +1.0 + + +-- !query 51 +SELECT cast(1 as float) / cast(1 as int) FROM t +-- !query 51 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 51 output +1.0 + + +-- !query 52 +SELECT cast(1 as float) / cast(1 as bigint) FROM t +-- !query 52 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 52 output +1.0 + + +-- !query 53 +SELECT cast(1 as float) / cast(1 as float) FROM t +-- !query 53 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 53 output +1.0 + + +-- !query 54 +SELECT cast(1 as float) / cast(1 as double) FROM t +-- !query 54 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(1 AS DOUBLE) AS DOUBLE)):double> +-- !query 54 output +1.0 + + +-- !query 55 +SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t +-- !query 55 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) AS DOUBLE)):double> +-- !query 55 output +1.0 + + +-- !query 56 +SELECT cast(1 as float) / cast(1 as string) FROM t +-- !query 56 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS STRING) AS DOUBLE) AS DOUBLE)):double> +-- !query 56 output +1.0 + + +-- !query 57 +SELECT cast(1 as float) / cast('1' as binary) FROM t +-- !query 57 schema +struct<> +-- !query 57 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST('1' AS BINARY))' (float and binary).; line 1 pos 7 + + +-- !query 58 +SELECT cast(1 as float) / cast(1 as boolean) FROM t +-- !query 58 schema +struct<> +-- !query 58 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST(1 AS BOOLEAN))' (float and boolean).; line 1 pos 7 + + +-- !query 59 +SELECT cast(1 as float) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (float and timestamp).; line 1 pos 7 + + +-- !query 60 +SELECT cast(1 as float) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) / CAST('2017-12-11 09:30:00' AS DATE))' (float and date).; line 1 pos 7 + + +-- !query 61 +SELECT cast(1 as double) / cast(1 as tinyint) FROM t +-- !query 
61 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 61 output +1.0 + + +-- !query 62 +SELECT cast(1 as double) / cast(1 as smallint) FROM t +-- !query 62 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 62 output +1.0 + + +-- !query 63 +SELECT cast(1 as double) / cast(1 as int) FROM t +-- !query 63 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 63 output +1.0 + + +-- !query 64 +SELECT cast(1 as double) / cast(1 as bigint) FROM t +-- !query 64 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 64 output +1.0 + + +-- !query 65 +SELECT cast(1 as double) / cast(1 as float) FROM t +-- !query 65 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 65 output +1.0 + + +-- !query 66 +SELECT cast(1 as double) / cast(1 as double) FROM t +-- !query 66 schema +struct<(CAST(1 AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 66 output +1.0 + + +-- !query 67 +SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t +-- !query 67 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 67 output +1.0 + + +-- !query 68 +SELECT cast(1 as double) / cast(1 as string) FROM t +-- !query 68 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 68 output +1.0 + + +-- !query 69 +SELECT cast(1 as double) / cast('1' as binary) FROM t +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 70 +SELECT cast(1 as double) / cast(1 as boolean) FROM t +-- !query 70 schema +struct<> +-- !query 70 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 71 +SELECT cast(1 as double) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 72 +SELECT cast(1 as double) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 72 schema +struct<> +-- !query 72 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 73 +SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t +-- !query 73 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 73 output +1 + + +-- !query 74 +SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t +-- !query 74 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 74 output +1 + + +-- !query 75 +SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM 
t +-- !query 75 schema +struct<(CAST(1 AS DECIMAL(10,0)) / CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(21,11)> +-- !query 75 output +1 + + +-- !query 76 +SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t +-- !query 76 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)> +-- !query 76 output +1 + + +-- !query 77 +SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t +-- !query 77 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 77 output +1.0 + + +-- !query 78 +SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t +-- !query 78 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 78 output +1.0 + + +-- !query 79 +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t +-- !query 79 schema +struct<(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)> +-- !query 79 output +1 + + +-- !query 80 +SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t +-- !query 80 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 80 output +1.0 + + +-- !query 81 +SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 82 +SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t +-- !query 82 schema +struct<> +-- !query 82 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 83 +SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 84 +SELECT cast(1 as decimal(10, 0)) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 84 schema +struct<> +-- !query 84 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 85 +SELECT cast(1 as string) / cast(1 as tinyint) FROM t +-- !query 85 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS TINYINT) AS DOUBLE)):double> +-- !query 85 output +1.0 + + +-- !query 86 +SELECT cast(1 as string) / cast(1 as smallint) FROM t +-- !query 86 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS SMALLINT) AS DOUBLE)):double> +-- !query 86 output +1.0 + + +-- !query 87 +SELECT cast(1 as string) / cast(1 as int) FROM t +-- !query 87 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS INT) AS DOUBLE)):double> +-- !query 87 output +1.0 + + +-- !query 88 +SELECT cast(1 as string) / 
cast(1 as bigint) FROM t +-- !query 88 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS BIGINT) AS DOUBLE)):double> +-- !query 88 output +1.0 + + +-- !query 89 +SELECT cast(1 as string) / cast(1 as float) FROM t +-- !query 89 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 89 output +1.0 + + +-- !query 90 +SELECT cast(1 as string) / cast(1 as double) FROM t +-- !query 90 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 90 output +1.0 + + +-- !query 91 +SELECT cast(1 as string) / cast(1 as decimal(10, 0)) FROM t +-- !query 91 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 91 output +1.0 + + +-- !query 92 +SELECT cast(1 as string) / cast(1 as string) FROM t +-- !query 92 schema +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 92 output +1.0 + + +-- !query 93 +SELECT cast(1 as string) / cast('1' as binary) FROM t +-- !query 93 schema +struct<> +-- !query 93 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('1' AS BINARY))' (double and binary).; line 1 pos 7 + + +-- !query 94 +SELECT cast(1 as string) / cast(1 as boolean) FROM t +-- !query 94 schema +struct<> +-- !query 94 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST(1 AS BOOLEAN))' (double and boolean).; line 1 pos 7 + + +-- !query 95 +SELECT cast(1 as string) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 95 schema +struct<> +-- !query 95 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (double and timestamp).; line 1 pos 7 + + +-- !query 96 +SELECT cast(1 as string) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 96 schema +struct<> +-- !query 96 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(CAST(1 AS STRING) AS DOUBLE) / CAST('2017-12-11 09:30:00' AS DATE))' (double and date).; line 1 pos 7 + + +-- !query 97 +SELECT cast('1' as binary) / cast(1 as tinyint) FROM t +-- !query 97 schema +struct<> +-- !query 97 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS TINYINT))' (binary and tinyint).; line 1 pos 7 + + +-- !query 98 +SELECT cast('1' as binary) / cast(1 as smallint) FROM t +-- !query 98 schema +struct<> +-- !query 98 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS SMALLINT))' (binary and smallint).; line 1 pos 7 + + +-- !query 99 +SELECT cast('1' as binary) / cast(1 as int) FROM t +-- !query 99 schema +struct<> +-- !query 99 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS 
INT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS INT))' (binary and int).; line 1 pos 7 + + +-- !query 100 +SELECT cast('1' as binary) / cast(1 as bigint) FROM t +-- !query 100 schema +struct<> +-- !query 100 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS BIGINT))' (binary and bigint).; line 1 pos 7 + + +-- !query 101 +SELECT cast('1' as binary) / cast(1 as float) FROM t +-- !query 101 schema +struct<> +-- !query 101 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS FLOAT))' (binary and float).; line 1 pos 7 + + +-- !query 102 +SELECT cast('1' as binary) / cast(1 as double) FROM t +-- !query 102 schema +struct<> +-- !query 102 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 103 +SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t +-- !query 103 schema +struct<> +-- !query 103 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 104 +SELECT cast('1' as binary) / cast(1 as string) FROM t +-- !query 104 schema +struct<> +-- !query 104 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(CAST(1 AS STRING) AS DOUBLE))' (binary and double).; line 1 pos 7 + + +-- !query 105 +SELECT cast('1' as binary) / cast('1' as binary) FROM t +-- !query 105 schema +struct<> +-- !query 105 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST('1' AS BINARY))' due to data type mismatch: '(CAST('1' AS BINARY) / CAST('1' AS BINARY))' requires (double or decimal) type, not binary; line 1 pos 7 + + +-- !query 106 +SELECT cast('1' as binary) / cast(1 as boolean) FROM t +-- !query 106 schema +struct<> +-- !query 106 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS BOOLEAN))' (binary and boolean).; line 1 pos 7 + + +-- !query 107 +SELECT cast('1' as binary) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 107 schema +struct<> +-- !query 107 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (binary and timestamp).; line 1 pos 7 + + +-- !query 108 +SELECT cast('1' as binary) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 108 schema +struct<> +-- !query 108 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST('2017-12-11 09:30:00' AS DATE))' (binary and date).; line 1 pos 7 + + +-- !query 109 +SELECT cast(1 as boolean) / 
cast(1 as tinyint) FROM t +-- !query 109 schema +struct<> +-- !query 109 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS TINYINT))' (boolean and tinyint).; line 1 pos 7 + + +-- !query 110 +SELECT cast(1 as boolean) / cast(1 as smallint) FROM t +-- !query 110 schema +struct<> +-- !query 110 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS SMALLINT))' (boolean and smallint).; line 1 pos 7 + + +-- !query 111 +SELECT cast(1 as boolean) / cast(1 as int) FROM t +-- !query 111 schema +struct<> +-- !query 111 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS INT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS INT))' (boolean and int).; line 1 pos 7 + + +-- !query 112 +SELECT cast(1 as boolean) / cast(1 as bigint) FROM t +-- !query 112 schema +struct<> +-- !query 112 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS BIGINT))' (boolean and bigint).; line 1 pos 7 + + +-- !query 113 +SELECT cast(1 as boolean) / cast(1 as float) FROM t +-- !query 113 schema +struct<> +-- !query 113 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS FLOAT))' (boolean and float).; line 1 pos 7 + + +-- !query 114 +SELECT cast(1 as boolean) / cast(1 as double) FROM t +-- !query 114 schema +struct<> +-- !query 114 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 115 +SELECT cast(1 as boolean) / cast(1 as decimal(10, 0)) FROM t +-- !query 115 schema +struct<> +-- !query 115 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(1 AS DECIMAL(10,0)))' (boolean and decimal(10,0)).; line 1 pos 7 + + +-- !query 116 +SELECT cast(1 as boolean) / cast(1 as string) FROM t +-- !query 116 schema +struct<> +-- !query 116 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST(CAST(1 AS STRING) AS DOUBLE))' (boolean and double).; line 1 pos 7 + + +-- !query 117 +SELECT cast(1 as boolean) / cast('1' as binary) FROM t +-- !query 117 schema +struct<> +-- !query 117 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST('1' AS BINARY))' (boolean and binary).; line 1 pos 7 + + +-- !query 118 +SELECT cast(1 as boolean) / cast(1 as boolean) FROM t +-- !query 118 schema +struct<> +-- !query 118 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST(1 AS BOOLEAN))' due to data type mismatch: '(CAST(1 AS BOOLEAN) / CAST(1 AS BOOLEAN))' requires (double or decimal) type, not boolean; line 1 pos 7 + + +-- !query 119 
+SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 119 schema +struct<> +-- !query 119 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (boolean and timestamp).; line 1 pos 7 + + +-- !query 120 +SELECT cast(1 as boolean) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 120 schema +struct<> +-- !query 120 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) / CAST('2017-12-11 09:30:00' AS DATE))' (boolean and date).; line 1 pos 7 + + +-- !query 121 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as tinyint) FROM t +-- !query 121 schema +struct<> +-- !query 121 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS TINYINT))' (timestamp and tinyint).; line 1 pos 7 + + +-- !query 122 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as smallint) FROM t +-- !query 122 schema +struct<> +-- !query 122 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS SMALLINT))' (timestamp and smallint).; line 1 pos 7 + + +-- !query 123 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as int) FROM t +-- !query 123 schema +struct<> +-- !query 123 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS INT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS INT))' (timestamp and int).; line 1 pos 7 + + +-- !query 124 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as bigint) FROM t +-- !query 124 schema +struct<> +-- !query 124 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BIGINT))' (timestamp and bigint).; line 1 pos 7 + + +-- !query 125 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as float) FROM t +-- !query 125 schema +struct<> +-- !query 125 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS FLOAT))' (timestamp and float).; line 1 pos 7 + + +-- !query 126 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as double) FROM t +-- !query 126 schema +struct<> +-- !query 126 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 127 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t +-- !query 127 schema +struct<> +-- !query 127 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 128 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as string) FROM t +-- !query 128 schema +struct<> +-- !query 128 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(CAST(1 AS STRING) AS DOUBLE))' (timestamp and double).; line 1 pos 7 + + +-- !query 129 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('1' as binary) FROM t +-- !query 129 schema +struct<> +-- !query 129 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('1' AS BINARY))' (timestamp and binary).; line 1 pos 7 + + +-- !query 130 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast(1 as boolean) FROM t +-- !query 130 schema +struct<> +-- !query 130 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS BOOLEAN))' (timestamp and boolean).; line 1 pos 7 + + +-- !query 131 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 131 schema +struct<> +-- !query 131 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' requires (double or decimal) type, not timestamp; line 1 pos 7 + + +-- !query 132 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 132 schema +struct<> +-- !query 132 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) / CAST('2017-12-11 09:30:00' AS DATE))' (timestamp and date).; line 1 pos 7 + + +-- !query 133 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as tinyint) FROM t +-- !query 133 schema +struct<> +-- !query 133 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS TINYINT))' (date and tinyint).; line 1 pos 7 + + +-- !query 134 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as smallint) FROM t +-- !query 134 schema +struct<> +-- !query 134 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS SMALLINT))' (date and smallint).; line 1 pos 7 + + +-- !query 135 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as int) FROM t +-- !query 135 schema +struct<> +-- !query 135 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS INT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS INT))' (date and int).; line 1 pos 7 + + +-- !query 136 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as bigint) FROM t +-- !query 136 schema +struct<> +-- !query 136 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BIGINT))' (date and bigint).; line 1 pos 7 + + +-- !query 137 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as float) FROM t +-- !query 137 schema +struct<> +-- !query 137 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS FLOAT))' (date and float).; line 1 pos 7 + + +-- !query 138 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as double) FROM t +-- !query 138 schema +struct<> +-- !query 138 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 139 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t +-- !query 139 schema +struct<> +-- !query 139 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 140 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as string) FROM t +-- !query 140 schema +struct<> +-- !query 140 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(CAST(1 AS STRING) AS DOUBLE))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(CAST(1 AS STRING) AS DOUBLE))' (date and double).; line 1 pos 7 + + +-- !query 141 +SELECT cast('2017-12-11 09:30:00' as date) / cast('1' as binary) FROM t +-- !query 141 schema +struct<> +-- !query 141 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('1' AS BINARY))' (date and binary).; line 1 pos 7 + + +-- !query 142 +SELECT cast('2017-12-11 09:30:00' as date) / cast(1 as boolean) FROM t +-- !query 142 schema +struct<> +-- !query 142 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST(1 AS BOOLEAN))' (date and boolean).; line 1 pos 7 + + +-- !query 143 +SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 143 schema +struct<> +-- !query 143 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (date 
and timestamp).; line 1 pos 7 + + +-- !query 144 +SELECT cast('2017-12-11 09:30:00' as date) / cast('2017-12-11 09:30:00' as date) FROM t +-- !query 144 schema +struct<> +-- !query 144 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: '(CAST('2017-12-11 09:30:00' AS DATE) / CAST('2017-12-11 09:30:00' AS DATE))' requires (double or decimal) type, not date; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out new file mode 100644 index 0000000000000..20a9e47217238 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out @@ -0,0 +1,1305 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 145 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 1 schema +struct +-- !query 1 output +1 +2 + + +-- !query 2 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 2 schema +struct +-- !query 2 output +1 +2 + + +-- !query 3 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 3 schema +struct +-- !query 3 output +1 +2 + + +-- !query 4 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 4 schema +struct +-- !query 4 output +1 +2 + + +-- !query 5 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 5 schema +struct +-- !query 5 output +1.0 +2.0 + + +-- !query 6 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 6 schema +struct +-- !query 6 output +1.0 +2.0 + + +-- !query 7 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 7 schema +struct +-- !query 7 output +1 +2 + + +-- !query 8 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 8 schema +struct +-- !query 8 output +1 +2 + + +-- !query 9 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> tinyint at the first column of the second table; + + +-- !query 10 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> tinyint at the first column of the second table; + + +-- !query 11 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> tinyint at the first column of the second table; + + +-- !query 12 +SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
date <> tinyint at the first column of the second table; + + +-- !query 13 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 13 schema +struct +-- !query 13 output +1 +2 + + +-- !query 14 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 14 schema +struct +-- !query 14 output +1 +2 + + +-- !query 15 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 15 schema +struct +-- !query 15 output +1 +2 + + +-- !query 16 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 16 schema +struct +-- !query 16 output +1 +2 + + +-- !query 17 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 17 schema +struct +-- !query 17 output +1.0 +2.0 + + +-- !query 18 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 18 schema +struct +-- !query 18 output +1.0 +2.0 + + +-- !query 19 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 19 schema +struct +-- !query 19 output +1 +2 + + +-- !query 20 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 20 schema +struct +-- !query 20 output +1 +2 + + +-- !query 21 +SELECT cast(1 as smallint) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> smallint at the first column of the second table; + + +-- !query 22 +SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> smallint at the first column of the second table; + + +-- !query 23 +SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> smallint at the first column of the second table; + + +-- !query 24 +SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
date <> smallint at the first column of the second table; + + +-- !query 25 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 25 schema +struct +-- !query 25 output +1 +2 + + +-- !query 26 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 26 schema +struct +-- !query 26 output +1 +2 + + +-- !query 27 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 27 schema +struct +-- !query 27 output +1 +2 + + +-- !query 28 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 28 schema +struct +-- !query 28 output +1 +2 + + +-- !query 29 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 29 schema +struct +-- !query 29 output +1.0 +2.0 + + +-- !query 30 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 30 schema +struct +-- !query 30 output +1.0 +2.0 + + +-- !query 31 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 31 schema +struct +-- !query 31 output +1 +2 + + +-- !query 32 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 32 schema +struct +-- !query 32 output +1 +2 + + +-- !query 33 +SELECT cast(1 as int) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> int at the first column of the second table; + + +-- !query 34 +SELECT cast(1 as int) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> int at the first column of the second table; + + +-- !query 35 +SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> int at the first column of the second table; + + +-- !query 36 +SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
date <> int at the first column of the second table; + + +-- !query 37 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 37 schema +struct +-- !query 37 output +1 +2 + + +-- !query 38 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 38 schema +struct +-- !query 38 output +1 +2 + + +-- !query 39 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 39 schema +struct +-- !query 39 output +1 +2 + + +-- !query 40 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 40 schema +struct +-- !query 40 output +1 +2 + + +-- !query 41 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 41 schema +struct +-- !query 41 output +1.0 +2.0 + + +-- !query 42 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 42 schema +struct +-- !query 42 output +1.0 +2.0 + + +-- !query 43 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 43 schema +struct +-- !query 43 output +1 +2 + + +-- !query 44 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 44 schema +struct +-- !query 44 output +1 +2 + + +-- !query 45 +SELECT cast(1 as bigint) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 45 schema +struct<> +-- !query 45 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> bigint at the first column of the second table; + + +-- !query 46 +SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 46 schema +struct<> +-- !query 46 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> bigint at the first column of the second table; + + +-- !query 47 +SELECT cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 47 schema +struct<> +-- !query 47 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> bigint at the first column of the second table; + + +-- !query 48 +SELECT cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
date <> bigint at the first column of the second table; + + +-- !query 49 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 49 schema +struct +-- !query 49 output +1.0 +2.0 + + +-- !query 50 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 50 schema +struct +-- !query 50 output +1.0 +2.0 + + +-- !query 51 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 51 schema +struct +-- !query 51 output +1.0 +2.0 + + +-- !query 52 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 52 schema +struct +-- !query 52 output +1.0 +2.0 + + +-- !query 53 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 53 schema +struct +-- !query 53 output +1.0 +2.0 + + +-- !query 54 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 54 schema +struct +-- !query 54 output +1.0 +2.0 + + +-- !query 55 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 55 schema +struct +-- !query 55 output +1.0 +2.0 + + +-- !query 56 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 56 schema +struct +-- !query 56 output +1.0 +2 + + +-- !query 57 +SELECT cast(1 as float) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 57 schema +struct<> +-- !query 57 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> float at the first column of the second table; + + +-- !query 58 +SELECT cast(1 as float) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 58 schema +struct<> +-- !query 58 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> float at the first column of the second table; + + +-- !query 59 +SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 59 schema +struct<> +-- !query 59 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> float at the first column of the second table; + + +-- !query 60 +SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 60 schema +struct<> +-- !query 60 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
date <> float at the first column of the second table; + + +-- !query 61 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 61 schema +struct +-- !query 61 output +1.0 +2.0 + + +-- !query 62 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 62 schema +struct +-- !query 62 output +1.0 +2.0 + + +-- !query 63 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 63 schema +struct +-- !query 63 output +1.0 +2.0 + + +-- !query 64 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 64 schema +struct +-- !query 64 output +1.0 +2.0 + + +-- !query 65 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 65 schema +struct +-- !query 65 output +1.0 +2.0 + + +-- !query 66 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 66 schema +struct +-- !query 66 output +1.0 +2.0 + + +-- !query 67 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 67 schema +struct +-- !query 67 output +1.0 +2.0 + + +-- !query 68 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 68 schema +struct +-- !query 68 output +1.0 +2 + + +-- !query 69 +SELECT cast(1 as double) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> double at the first column of the second table; + + +-- !query 70 +SELECT cast(1 as double) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 70 schema +struct<> +-- !query 70 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> double at the first column of the second table; + + +-- !query 71 +SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> double at the first column of the second table; + + +-- !query 72 +SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 72 schema +struct<> +-- !query 72 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
date <> double at the first column of the second table; + + +-- !query 73 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 73 schema +struct +-- !query 73 output +1 +2 + + +-- !query 74 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 74 schema +struct +-- !query 74 output +1 +2 + + +-- !query 75 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 75 schema +struct +-- !query 75 output +1 +2 + + +-- !query 76 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 76 schema +struct +-- !query 76 output +1 +2 + + +-- !query 77 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 77 schema +struct +-- !query 77 output +1.0 +2.0 + + +-- !query 78 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 78 schema +struct +-- !query 78 output +1.0 +2.0 + + +-- !query 79 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 79 schema +struct +-- !query 79 output +1 +2 + + +-- !query 80 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 80 schema +struct +-- !query 80 output +1 +2 + + +-- !query 81 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> decimal(10,0) at the first column of the second table; + + +-- !query 82 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 82 schema +struct<> +-- !query 82 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> decimal(10,0) at the first column of the second table; + + +-- !query 83 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> decimal(10,0) at the first column of the second table; + + +-- !query 84 +SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 84 schema +struct<> +-- !query 84 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
date <> decimal(10,0) at the first column of the second table; + + +-- !query 85 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 85 schema +struct +-- !query 85 output +1 +2 + + +-- !query 86 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 86 schema +struct +-- !query 86 output +1 +2 + + +-- !query 87 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 87 schema +struct +-- !query 87 output +1 +2 + + +-- !query 88 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 88 schema +struct +-- !query 88 output +1 +2 + + +-- !query 89 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 89 schema +struct +-- !query 89 output +1 +2.0 + + +-- !query 90 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 90 schema +struct +-- !query 90 output +1 +2.0 + + +-- !query 91 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 91 schema +struct +-- !query 91 output +1 +2 + + +-- !query 92 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 92 schema +struct +-- !query 92 output +1 +2 + + +-- !query 93 +SELECT cast(1 as string) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 93 schema +struct<> +-- !query 93 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> string at the first column of the second table; + + +-- !query 94 +SELECT cast(1 as string) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 94 schema +struct<> +-- !query 94 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> string at the first column of the second table; + + +-- !query 95 +SELECT cast(1 as string) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 95 schema +struct +-- !query 95 output +1 +2017-12-11 09:30:00 + + +-- !query 96 +SELECT cast(1 as string) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 96 schema +struct +-- !query 96 output +1 +2017-12-11 + + +-- !query 97 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 97 schema +struct<> +-- !query 97 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. tinyint <> binary at the first column of the second table; + + +-- !query 98 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 98 schema +struct<> +-- !query 98 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. smallint <> binary at the first column of the second table; + + +-- !query 99 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 99 schema +struct<> +-- !query 99 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. int <> binary at the first column of the second table; + + +-- !query 100 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 100 schema +struct<> +-- !query 100 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
bigint <> binary at the first column of the second table; + + +-- !query 101 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 101 schema +struct<> +-- !query 101 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. float <> binary at the first column of the second table; + + +-- !query 102 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 102 schema +struct<> +-- !query 102 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. double <> binary at the first column of the second table; + + +-- !query 103 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 103 schema +struct<> +-- !query 103 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. decimal(10,0) <> binary at the first column of the second table; + + +-- !query 104 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 104 schema +struct<> +-- !query 104 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. string <> binary at the first column of the second table; + + +-- !query 105 +SELECT cast('1' as binary) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 105 schema +struct +-- !query 105 output +1 +2 + + +-- !query 106 +SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 106 schema +struct<> +-- !query 106 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> binary at the first column of the second table; + + +-- !query 107 +SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 107 schema +struct<> +-- !query 107 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> binary at the first column of the second table; + + +-- !query 108 +SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 108 schema +struct<> +-- !query 108 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. date <> binary at the first column of the second table; + + +-- !query 109 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 109 schema +struct<> +-- !query 109 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. tinyint <> boolean at the first column of the second table; + + +-- !query 110 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 110 schema +struct<> +-- !query 110 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. smallint <> boolean at the first column of the second table; + + +-- !query 111 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 111 schema +struct<> +-- !query 111 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
int <> boolean at the first column of the second table; + + +-- !query 112 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 112 schema +struct<> +-- !query 112 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. bigint <> boolean at the first column of the second table; + + +-- !query 113 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 113 schema +struct<> +-- !query 113 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. float <> boolean at the first column of the second table; + + +-- !query 114 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 114 schema +struct<> +-- !query 114 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. double <> boolean at the first column of the second table; + + +-- !query 115 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 115 schema +struct<> +-- !query 115 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. decimal(10,0) <> boolean at the first column of the second table; + + +-- !query 116 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 116 schema +struct<> +-- !query 116 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. string <> boolean at the first column of the second table; + + +-- !query 117 +SELECT cast(1 as boolean) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 117 schema +struct<> +-- !query 117 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> boolean at the first column of the second table; + + +-- !query 118 +SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 118 schema +struct +-- !query 118 output +true + + +-- !query 119 +SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 119 schema +struct<> +-- !query 119 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. timestamp <> boolean at the first column of the second table; + + +-- !query 120 +SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 120 schema +struct<> +-- !query 120 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. date <> boolean at the first column of the second table; + + +-- !query 121 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 121 schema +struct<> +-- !query 121 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. tinyint <> timestamp at the first column of the second table; + + +-- !query 122 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 122 schema +struct<> +-- !query 122 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
smallint <> timestamp at the first column of the second table; + + +-- !query 123 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 123 schema +struct<> +-- !query 123 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. int <> timestamp at the first column of the second table; + + +-- !query 124 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 124 schema +struct<> +-- !query 124 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. bigint <> timestamp at the first column of the second table; + + +-- !query 125 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 125 schema +struct<> +-- !query 125 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. float <> timestamp at the first column of the second table; + + +-- !query 126 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 126 schema +struct<> +-- !query 126 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. double <> timestamp at the first column of the second table; + + +-- !query 127 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 127 schema +struct<> +-- !query 127 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. decimal(10,0) <> timestamp at the first column of the second table; + + +-- !query 128 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 128 schema +struct +-- !query 128 output +2 +2017-12-12 09:30:00 + + +-- !query 129 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 129 schema +struct<> +-- !query 129 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> timestamp at the first column of the second table; + + +-- !query 130 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 130 schema +struct<> +-- !query 130 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. boolean <> timestamp at the first column of the second table; + + +-- !query 131 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 131 schema +struct +-- !query 131 output +2017-12-11 09:30:00 +2017-12-12 09:30:00 + + +-- !query 132 +SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 132 schema +struct +-- !query 132 output +2017-12-11 00:00:00 +2017-12-12 09:30:00 + + +-- !query 133 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as tinyint) FROM t +-- !query 133 schema +struct<> +-- !query 133 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
tinyint <> date at the first column of the second table; + + +-- !query 134 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as smallint) FROM t +-- !query 134 schema +struct<> +-- !query 134 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. smallint <> date at the first column of the second table; + + +-- !query 135 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as int) FROM t +-- !query 135 schema +struct<> +-- !query 135 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. int <> date at the first column of the second table; + + +-- !query 136 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as bigint) FROM t +-- !query 136 schema +struct<> +-- !query 136 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. bigint <> date at the first column of the second table; + + +-- !query 137 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as float) FROM t +-- !query 137 schema +struct<> +-- !query 137 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. float <> date at the first column of the second table; + + +-- !query 138 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as double) FROM t +-- !query 138 schema +struct<> +-- !query 138 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. double <> date at the first column of the second table; + + +-- !query 139 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t +-- !query 139 schema +struct<> +-- !query 139 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. decimal(10,0) <> date at the first column of the second table; + + +-- !query 140 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as string) FROM t +-- !query 140 schema +struct +-- !query 140 output +2 +2017-12-12 + + +-- !query 141 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2' as binary) FROM t +-- !query 141 schema +struct<> +-- !query 141 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. binary <> date at the first column of the second table; + + +-- !query 142 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as boolean) FROM t +-- !query 142 schema +struct<> +-- !query 142 output +org.apache.spark.sql.AnalysisException +Union can only be performed on tables with the compatible column types. 
boolean <> date at the first column of the second table; + + +-- !query 143 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 143 schema +struct +-- !query 143 output +2017-12-11 09:30:00 +2017-12-12 00:00:00 + + +-- !query 144 +SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FROM t +-- !query 144 schema +struct +-- !query 144 output +2017-12-11 +2017-12-12 From 3a7494dfee714510f6a62d5533023f3e0d5ccdcd Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Wed, 20 Dec 2017 12:21:00 +0800 Subject: [PATCH 457/712] [SPARK-22827][CORE] Avoid throwing OutOfMemoryError in case of exception in spill ## What changes were proposed in this pull request? Currently, the task memory manager throws an OutOfMemoryError when an IO exception happens in spill() - https://github.com/apache/spark/blob/master/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java#L194. Similarly, there are many other places in the code where, if a task is not able to acquire memory due to an exception, we throw an OutOfMemoryError, which kills the entire executor and hence fails all the tasks running on that executor instead of just failing that single task. ## How was this patch tested? Unit tests Author: Sital Kedia Closes #20014 from sitalkedia/skedia/upstream_SPARK-22827. --- .../apache/spark/memory/MemoryConsumer.java | 4 +-- .../spark/memory/SparkOutOfMemoryError.java | 36 +++++++++++++++++++ .../spark/memory/TaskMemoryManager.java | 4 +-- .../shuffle/sort/ShuffleExternalSorter.java | 3 +- .../unsafe/sort/UnsafeExternalSorter.java | 3 +- .../unsafe/sort/UnsafeInMemorySorter.java | 3 +- .../org/apache/spark/executor/Executor.scala | 5 ++- .../TungstenAggregationIterator.scala | 3 +- 8 files changed, 50 insertions(+), 11 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 2dff241900e82..a7bd4b3799a25 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -88,7 +88,7 @@ public void spill() throws IOException { * `LongArray` is too large to fit in a single page. The caller side should take care of these * two exceptions, or make sure the `size` is small enough that won't trigger exceptions. * - * @throws OutOfMemoryError + * @throws SparkOutOfMemoryError * @throws TooLargePageException */ public LongArray allocateArray(long size) { @@ -154,6 +154,6 @@ private void throwOom(final MemoryBlock page, final long required) { taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); } } diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java new file mode 100644 index 0000000000000..ca00ca58e9713 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.memory; + +import org.apache.spark.annotation.Private; + +/** + * This exception is thrown when a task can not acquire memory from the Memory manager. + * Instead of throwing {@link OutOfMemoryError}, which kills the executor, + * we should use throw this exception, which just kills the current task. + */ +@Private +public final class SparkOutOfMemoryError extends OutOfMemoryError { + + public SparkOutOfMemoryError(String s) { + super(s); + } + + public SparkOutOfMemoryError(OutOfMemoryError e) { + super(e.getMessage()); + } +} diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index f6b5ea3c0ad26..e8d3730daa7a4 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -192,7 +192,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + c, e); - throw new OutOfMemoryError("error while calling spill() on " + c + " : " + throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : " + e.getMessage()); } } @@ -213,7 +213,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); - throw new OutOfMemoryError("error while calling spill() on " + consumer + " : " + throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : " + e.getMessage()); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index e80f9734ecf7b..c3a07b2abf896 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -33,6 +33,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.internal.config.package$; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.DummySerializerInstance; @@ -337,7 +338,7 @@ private void growPointerArrayIfNecessary() throws IOException { // The pointer array is too big to fix in a single page, spill. 
spill(); return; - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { // should have trigger spilling if (!inMemSorter.hasSpaceForAnotherRecord()) { logger.error("Unable to grow the pointer array"); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 8b8e15e3f78ed..66118f454159b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -25,6 +25,7 @@ import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -349,7 +350,7 @@ private void growPointerArrayIfNecessary() throws IOException { // The pointer array is too big to fix in a single page, spill. spill(); return; - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { // should have trigger spilling if (!inMemSorter.hasSpaceForAnotherRecord()) { logger.error("Unable to grow the pointer array"); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 3bb87a6ed653d..951d076420ee6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -24,6 +24,7 @@ import org.apache.spark.TaskContext; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; @@ -212,7 +213,7 @@ public boolean hasSpaceForAnotherRecord() { public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { - throw new OutOfMemoryError("Not enough memory to grow pointer array"); + throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } Platform.copyMemory( array.getBaseObject(), diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index af0a0ab656564..2c3a8ef74800b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -35,7 +35,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} import org.apache.spark.shuffle.FetchFailedException @@ -553,10 +553,9 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
- if (Utils.isFatalError(t)) { + if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) { uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } - } finally { runningTasks.remove(taskId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 756eeb642e2d0..9dc334c1ead3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.TaskContext import org.apache.spark.internal.Logging +import org.apache.spark.memory.SparkOutOfMemoryError import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -205,7 +206,7 @@ class TungstenAggregationIterator( buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // failed to allocate the first page - throw new OutOfMemoryError("No enough memory for aggregation") + throw new SparkOutOfMemoryError("No enough memory for aggregation") } } processRow(buffer, newInput) From 6e36d8d56279a2c5c92c8df8e89ee99b514817e7 Mon Sep 17 00:00:00 2001 From: Youngbin Kim Date: Tue, 19 Dec 2017 20:22:33 -0800 Subject: [PATCH 458/712] [SPARK-22829] Add new built-in function date_trunc() ## What changes were proposed in this pull request? Adding date_trunc() as a built-in function. `date_trunc` is common in other databases, but neither Spark nor Hive supports it. `date_trunc` is commonly used by data scientists and business intelligence applications such as Superset (https://github.com/apache/incubator-superset). We do have `trunc`, but it only works at the 'MONTH' and 'YEAR' levels on DateType input. date_trunc() in other databases: AWS Redshift: http://docs.aws.amazon.com/redshift/latest/dg/r_DATE_TRUNC.html PostgreSQL: https://www.postgresql.org/docs/9.1/static/functions-datetime.html Presto: https://prestodb.io/docs/current/functions/datetime.html ## How was this patch tested? Unit tests Author: Youngbin Kim Closes #20015 from youngbink/date_trunc. --- python/pyspark/sql/functions.py | 20 ++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/datetimeExpressions.scala | 170 ++++++++++++++---- .../sql/catalyst/util/DateTimeUtils.scala | 102 +++++++++-- .../expressions/DateExpressionsSuite.scala | 73 +++++++- .../catalyst/util/DateTimeUtilsSuite.scala | 70 ++++++++ .../org/apache/spark/sql/functions.scala | 15 ++ .../apache/spark/sql/DateFunctionsSuite.scala | 46 +++++ 8 files changed, 445 insertions(+), 52 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4e0faddb1c0df..54530055dfa85 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1099,7 +1099,7 @@ def trunc(date, format): """ Returns date truncated to the unit specified by the format. 
- :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' + :param format: 'year', 'yyyy', 'yy' or 'month', 'mon', 'mm' >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) >>> df.select(trunc(df.d, 'year').alias('year')).collect() @@ -1111,6 +1111,24 @@ def trunc(date, format): return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) +@since(2.3) +def date_trunc(format, timestamp): + """ + Returns timestamp truncated to the unit specified by the format. + + :param format: 'year', 'yyyy', 'yy', 'month', 'mon', 'mm', + 'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter' + + >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t']) + >>> df.select(date_trunc('year', df.t).alias('year')).collect() + [Row(year=datetime.datetime(1997, 1, 1, 0, 0))] + >>> df.select(date_trunc('mon', df.t).alias('month')).collect() + [Row(month=datetime.datetime(1997, 2, 1, 0, 0))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp))) + + @since(1.5) def next_day(date, dayOfWeek): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 11538bd31b4fd..5ddb39822617d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -392,6 +392,7 @@ object FunctionRegistry { expression[ToUnixTimestamp]("to_unix_timestamp"), expression[ToUTCTimestamp]("to_utc_timestamp"), expression[TruncDate]("trunc"), + expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[DayOfWeek]("dayofweek"), expression[WeekOfYear]("weekofyear"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index cfec7f82951a7..59c3e3d9947a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1294,80 +1294,79 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child: override def dataType: DataType = TimestampType } -/** - * Returns date truncated to the unit specified by the format. 
- */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.", - examples = """ - Examples: - > SELECT _FUNC_('2009-02-12', 'MM'); - 2009-02-01 - > SELECT _FUNC_('2015-10-27', 'YEAR'); - 2015-01-01 - """, - since = "1.5.0") -// scalastyle:on line.size.limit -case class TruncDate(date: Expression, format: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - override def left: Expression = date - override def right: Expression = format - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) - override def dataType: DataType = DateType +trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { + val instant: Expression + val format: Expression override def nullable: Boolean = true - override def prettyName: String = "trunc" private lazy val truncLevel: Int = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - override def eval(input: InternalRow): Any = { + /** + * @param input internalRow (time) + * @param maxLevel Maximum level that can be used for truncation (e.g MONTH for Date input) + * @param truncFunc function: (time, level) => time + */ + protected def evalHelper(input: InternalRow, maxLevel: Int)( + truncFunc: (Any, Int) => Any): Any = { val level = if (format.foldable) { truncLevel } else { DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) } - if (level == -1) { - // unknown format + if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) { + // unknown format or too large level null } else { - val d = date.eval(input) - if (d == null) { + val t = instant.eval(input) + if (t == null) { null } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], level) + truncFunc(t, level) } } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + protected def codeGenHelper( + ctx: CodegenContext, + ev: ExprCode, + maxLevel: Int, + orderReversed: Boolean = false)( + truncFunc: (String, String) => String) + : ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { - if (truncLevel == -1) { + if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { - val d = date.genCode(ctx) + val t = instant.genCode(ctx) + val truncFuncStr = truncFunc(t.value, truncLevel.toString) ev.copy(code = s""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; + ${t.code} + boolean ${ev.isNull} = ${t.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel); + ${ev.value} = $dtu.$truncFuncStr; }""") } } else { - nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + nullSafeCodeGen(ctx, ev, (left, right) => { val form = ctx.freshName("form") + val (dateVal, fmt) = if (orderReversed) { + (right, left) + } else { + (left, right) + } + val truncFuncStr = truncFunc(dateVal, form) s""" int $form = $dtu.parseTruncLevel($fmt); - if ($form == -1) { + if ($form == -1 || $form > $maxLevel) { ${ev.isNull} = true; } else { - ${ev.value} = $dtu.truncDate($dateVal, $form); + ${ev.value} = $dtu.$truncFuncStr } """ }) @@ -1375,6 +1374,101 @@ case class TruncDate(date: Expression, format: Expression) } } +/** + * Returns date truncated to the unit specified by the format. 
+ */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`. + `fmt` should be one of ["year", "yyyy", "yy", "mon", "month", "mm"] + """, + examples = """ + Examples: + > SELECT _FUNC_('2009-02-12', 'MM'); + 2009-02-01 + > SELECT _FUNC_('2015-10-27', 'YEAR'); + 2015-01-01 + """, + since = "1.5.0") +// scalastyle:on line.size.limit +case class TruncDate(date: Expression, format: Expression) + extends TruncInstant { + override def left: Expression = date + override def right: Expression = format + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def dataType: DataType = DateType + override def prettyName: String = "trunc" + override val instant = date + + override def eval(input: InternalRow): Any = { + evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (d: Any, level: Int) => + DateTimeUtils.truncDate(d.asInstanceOf[Int], level) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: String, fmt: String) => + s"truncDate($date, $fmt);" + } + } +} + +/** + * Returns timestamp truncated to the unit specified by the format. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(fmt, ts) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`. + `fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"] + """, + examples = """ + Examples: + > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR'); + 2015-01-01T00:00:00 + > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM'); + 2015-03-01T00:00:00 + > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD'); + 2015-03-05T00:00:00 + > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR'); + 2015-03-05T09:00:00 + """, + since = "2.3.0") +// scalastyle:on line.size.limit +case class TruncTimestamp( + format: Expression, + timestamp: Expression, + timeZoneId: Option[String] = None) + extends TruncInstant with TimeZoneAwareExpression { + override def left: Expression = format + override def right: Expression = timestamp + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) + override def dataType: TimestampType = TimestampType + override def prettyName: String = "date_trunc" + override val instant = timestamp + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + def this(format: Expression, timestamp: Expression) = this(format, timestamp, None) + + override def eval(input: InternalRow): Any = { + evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_SECOND) { (t: Any, level: Int) => + DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val tz = ctx.addReferenceObj("timeZone", timeZone) + codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, true) { + (date: String, fmt: String) => + s"truncTimestamp($date, $fmt, $tz);" + } + } +} + /** * Returns the number of days from startDate to endDate. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index b1ed25645b36c..fa69b8af62c85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -45,7 +45,8 @@ object DateTimeUtils { // it's 2440587.5, rounding up to compatible with Hive final val JULIAN_DAY_OF_EPOCH = 2440588 final val SECONDS_PER_DAY = 60 * 60 * 24L - final val MICROS_PER_SECOND = 1000L * 1000L + final val MICROS_PER_MILLIS = 1000L + final val MICROS_PER_SECOND = MICROS_PER_MILLIS * MILLIS_PER_SECOND final val MILLIS_PER_SECOND = 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY @@ -909,6 +910,15 @@ object DateTimeUtils { math.round(diff * 1e8) / 1e8 } + // Thursday = 0 since 1970/Jan/01 => Thursday + private val SUNDAY = 3 + private val MONDAY = 4 + private val TUESDAY = 5 + private val WEDNESDAY = 6 + private val THURSDAY = 0 + private val FRIDAY = 1 + private val SATURDAY = 2 + /* * Returns day of week from String. Starting from Thursday, marked as 0. * (Because 1970-01-01 is Thursday). @@ -916,13 +926,13 @@ object DateTimeUtils { def getDayOfWeekFromString(string: UTF8String): Int = { val dowString = string.toString.toUpperCase(Locale.ROOT) dowString match { - case "SU" | "SUN" | "SUNDAY" => 3 - case "MO" | "MON" | "MONDAY" => 4 - case "TU" | "TUE" | "TUESDAY" => 5 - case "WE" | "WED" | "WEDNESDAY" => 6 - case "TH" | "THU" | "THURSDAY" => 0 - case "FR" | "FRI" | "FRIDAY" => 1 - case "SA" | "SAT" | "SATURDAY" => 2 + case "SU" | "SUN" | "SUNDAY" => SUNDAY + case "MO" | "MON" | "MONDAY" => MONDAY + case "TU" | "TUE" | "TUESDAY" => TUESDAY + case "WE" | "WED" | "WEDNESDAY" => WEDNESDAY + case "TH" | "THU" | "THURSDAY" => THURSDAY + case "FR" | "FRI" | "FRIDAY" => FRIDAY + case "SA" | "SAT" | "SATURDAY" => SATURDAY case _ => -1 } } @@ -944,9 +954,16 @@ object DateTimeUtils { date + daysToMonthEnd } - private val TRUNC_TO_YEAR = 1 - private val TRUNC_TO_MONTH = 2 - private val TRUNC_INVALID = -1 + // Visible for testing. + private[sql] val TRUNC_TO_YEAR = 1 + private[sql] val TRUNC_TO_MONTH = 2 + private[sql] val TRUNC_TO_QUARTER = 3 + private[sql] val TRUNC_TO_WEEK = 4 + private[sql] val TRUNC_TO_DAY = 5 + private[sql] val TRUNC_TO_HOUR = 6 + private[sql] val TRUNC_TO_MINUTE = 7 + private[sql] val TRUNC_TO_SECOND = 8 + private[sql] val TRUNC_INVALID = -1 /** * Returns the trunc date from original date and trunc level. @@ -964,7 +981,62 @@ object DateTimeUtils { } /** - * Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID, + * Returns the trunc date time from original date time and trunc level. 
+ * Trunc level should be generated using `parseTruncLevel()`, should be between 1 and 8 + */ + def truncTimestamp(t: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = { + var millis = t / MICROS_PER_MILLIS + val truncated = level match { + case TRUNC_TO_YEAR => + val dDays = millisToDays(millis, timeZone) + daysToMillis(truncDate(dDays, level), timeZone) + case TRUNC_TO_MONTH => + val dDays = millisToDays(millis, timeZone) + daysToMillis(truncDate(dDays, level), timeZone) + case TRUNC_TO_DAY => + val offset = timeZone.getOffset(millis) + millis += offset + millis - millis % (MILLIS_PER_SECOND * SECONDS_PER_DAY) - offset + case TRUNC_TO_HOUR => + val offset = timeZone.getOffset(millis) + millis += offset + millis - millis % (60 * 60 * MILLIS_PER_SECOND) - offset + case TRUNC_TO_MINUTE => + millis - millis % (60 * MILLIS_PER_SECOND) + case TRUNC_TO_SECOND => + millis - millis % MILLIS_PER_SECOND + case TRUNC_TO_WEEK => + val dDays = millisToDays(millis, timeZone) + val prevMonday = getNextDateForDayOfWeek(dDays - 7, MONDAY) + daysToMillis(prevMonday, timeZone) + case TRUNC_TO_QUARTER => + val dDays = millisToDays(millis, timeZone) + millis = daysToMillis(truncDate(dDays, TRUNC_TO_MONTH), timeZone) + val cal = Calendar.getInstance() + cal.setTimeInMillis(millis) + val quarter = getQuarter(dDays) + val month = quarter match { + case 1 => Calendar.JANUARY + case 2 => Calendar.APRIL + case 3 => Calendar.JULY + case 4 => Calendar.OCTOBER + } + cal.set(Calendar.MONTH, month) + cal.getTimeInMillis() + case _ => + // caller make sure that this should never be reached + sys.error(s"Invalid trunc level: $level") + } + truncated * MICROS_PER_MILLIS + } + + def truncTimestamp(d: SQLTimestamp, level: Int): SQLTimestamp = { + truncTimestamp(d, level, defaultTimeZone()) + } + + /** + * Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, TRUNC_TO_DAY, TRUNC_TO_HOUR, + * TRUNC_TO_MINUTE, TRUNC_TO_SECOND, TRUNC_TO_WEEK, TRUNC_TO_QUARTER or TRUNC_INVALID, * TRUNC_INVALID means unsupported truncate level. 
*/ def parseTruncLevel(format: UTF8String): Int = { @@ -974,6 +1046,12 @@ object DateTimeUtils { format.toString.toUpperCase(Locale.ROOT) match { case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH + case "DAY" | "DD" => TRUNC_TO_DAY + case "HOUR" => TRUNC_TO_HOUR + case "MINUTE" => TRUNC_TO_MINUTE + case "SECOND" => TRUNC_TO_SECOND + case "WEEK" => TRUNC_TO_WEEK + case "QUARTER" => TRUNC_TO_QUARTER case _ => TRUNC_INVALID } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 89d99f9678cda..63f6ceeb21b96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -527,7 +527,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } - test("function trunc") { + test("TruncDate") { def testTrunc(input: Date, fmt: String, expected: Date): Unit = { checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)), expected) @@ -543,11 +543,82 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testTrunc(date, fmt, Date.valueOf("2015-07-01")) } testTrunc(date, "DD", null) + testTrunc(date, "SECOND", null) + testTrunc(date, "HOUR", null) testTrunc(date, null, null) testTrunc(null, "MON", null) testTrunc(null, null, null) } + test("TruncTimestamp") { + def testTrunc(input: Timestamp, fmt: String, expected: Timestamp): Unit = { + checkEvaluation( + TruncTimestamp(Literal.create(fmt, StringType), Literal.create(input, TimestampType)), + expected) + checkEvaluation( + TruncTimestamp( + NonFoldableLiteral.create(fmt, StringType), Literal.create(input, TimestampType)), + expected) + } + + withDefaultTimeZone(TimeZoneGMT) { + val inputDate = Timestamp.valueOf("2015-07-22 05:30:06") + + Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-01-01 00:00:00")) + } + + Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-07-01 00:00:00")) + } + + Seq("DAY", "day", "DD", "dd").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-07-22 00:00:00")) + } + + Seq("HOUR", "hour").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-07-22 05:00:00")) + } + + Seq("MINUTE", "minute").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-07-22 05:30:00")) + } + + Seq("SECOND", "second").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-07-22 05:30:06")) + } + + Seq("WEEK", "week").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-07-20 00:00:00")) + } + + Seq("QUARTER", "quarter").foreach { fmt => + testTrunc( + inputDate, fmt, + Timestamp.valueOf("2015-07-01 00:00:00")) + } + + testTrunc(inputDate, "INVALID", null) + testTrunc(inputDate, null, null) + testTrunc(null, "MON", null) + testTrunc(null, null, null) + } + } + test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index c8cf16d937352..625ff38943fa3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -563,6 +563,76 @@ class DateTimeUtilsSuite extends SparkFunSuite { } } + test("truncTimestamp") { + def testTrunc( + level: Int, + expected: String, + inputTS: SQLTimestamp, + timezone: TimeZone = DateTimeUtils.defaultTimeZone()): Unit = { + val truncated = + DateTimeUtils.truncTimestamp(inputTS, level, timezone) + val expectedTS = + DateTimeUtils.stringToTimestamp(UTF8String.fromString(expected)) + assert(truncated === expectedTS.get) + } + + val defaultInputTS = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-05T09:32:05.359")) + val defaultInputTS1 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-31T20:32:05.359")) + val defaultInputTS2 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-04-01T02:32:05.359")) + val defaultInputTS3 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-30T02:32:05.359")) + val defaultInputTS4 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-29T02:32:05.359")) + + testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", defaultInputTS1.get) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", defaultInputTS2.get) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", defaultInputTS3.get) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", defaultInputTS4.get) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", defaultInputTS.get) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", defaultInputTS1.get) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", defaultInputTS2.get) + + for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { + DateTimeTestUtils.withDefaultTimeZone(tz) { + val inputTS = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-05T09:32:05.359")) + val inputTS1 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-31T20:32:05.359")) + val inputTS2 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-04-01T02:32:05.359")) + val inputTS3 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-30T02:32:05.359")) + val inputTS4 = + DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-29T02:32:05.359")) + + testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, tz) + 
testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, tz) + } + } + } + test("daysToMillis and millisToDays") { val c = Calendar.getInstance(TimeZonePST) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3e4659b9eae60..052a3f533da5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2797,6 +2797,21 @@ object functions { TruncDate(date.expr, Literal(format)) } + /** + * Returns timestamp truncated to the unit specified by the format. + * + * @param format: 'year', 'yyyy', 'yy' for truncate by year, + * 'month', 'mon', 'mm' for truncate by month, + * 'day', 'dd' for truncate by day, + * Other options are: 'second', 'minute', 'hour', 'week', 'month', 'quarter' + * + * @group datetime_funcs + * @since 2.3.0 + */ + def date_trunc(format: String, timestamp: Column): Column = withExpr { + TruncTimestamp(Literal(format), timestamp.expr) + } + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. 
For example, 'GMT+1' would yield diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 3a8694839bb24..6bbf38516cdf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -435,6 +435,52 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) } + test("function date_trunc") { + val df = Seq( + (1, Timestamp.valueOf("2015-07-22 10:01:40.523")), + (2, Timestamp.valueOf("2014-12-31 05:29:06.876"))).toDF("i", "t") + + checkAnswer( + df.select(date_trunc("YY", col("t"))), + Seq(Row(Timestamp.valueOf("2015-01-01 00:00:00")), + Row(Timestamp.valueOf("2014-01-01 00:00:00")))) + + checkAnswer( + df.selectExpr("date_trunc('MONTH', t)"), + Seq(Row(Timestamp.valueOf("2015-07-01 00:00:00")), + Row(Timestamp.valueOf("2014-12-01 00:00:00")))) + + checkAnswer( + df.selectExpr("date_trunc('DAY', t)"), + Seq(Row(Timestamp.valueOf("2015-07-22 00:00:00")), + Row(Timestamp.valueOf("2014-12-31 00:00:00")))) + + checkAnswer( + df.selectExpr("date_trunc('HOUR', t)"), + Seq(Row(Timestamp.valueOf("2015-07-22 10:00:00")), + Row(Timestamp.valueOf("2014-12-31 05:00:00")))) + + checkAnswer( + df.selectExpr("date_trunc('MINUTE', t)"), + Seq(Row(Timestamp.valueOf("2015-07-22 10:01:00")), + Row(Timestamp.valueOf("2014-12-31 05:29:00")))) + + checkAnswer( + df.selectExpr("date_trunc('SECOND', t)"), + Seq(Row(Timestamp.valueOf("2015-07-22 10:01:40")), + Row(Timestamp.valueOf("2014-12-31 05:29:06")))) + + checkAnswer( + df.selectExpr("date_trunc('WEEK', t)"), + Seq(Row(Timestamp.valueOf("2015-07-20 00:00:00")), + Row(Timestamp.valueOf("2014-12-29 00:00:00")))) + + checkAnswer( + df.selectExpr("date_trunc('QUARTER', t)"), + Seq(Row(Timestamp.valueOf("2015-07-01 00:00:00")), + Row(Timestamp.valueOf("2014-10-01 00:00:00")))) + } + test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" From 13268a58f8f67bf994f0ad5076419774c45daeeb Mon Sep 17 00:00:00 2001 From: Fernando Pereira Date: Tue, 19 Dec 2017 20:47:12 -0800 Subject: [PATCH 459/712] [SPARK-22649][PYTHON][SQL] Adding localCheckpoint to Dataset API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This change adds local checkpoint support to datasets and respective bind from Python Dataframe API. If reliability requirements can be lowered to favor performance, as in cases of further quick transformations followed by a reliable save, localCheckpoints() fit very well. Furthermore, at the moment Reliable checkpoints still incur double computation (see #9428) In general it makes the API more complete as well. ## How was this patch tested? 
Python land quick use case: ```python >>> from time import sleep >>> from pyspark.sql import types as T >>> from pyspark.sql import functions as F >>> def f(x): sleep(1) return x*2 ...: >>> df1 = spark.range(30, numPartitions=6) >>> df2 = df1.select(F.udf(f, T.LongType())("id")) >>> %time _ = df2.collect() CPU times: user 7.79 ms, sys: 5.84 ms, total: 13.6 ms Wall time: 12.2 s >>> %time df3 = df2.localCheckpoint() CPU times: user 2.38 ms, sys: 2.3 ms, total: 4.68 ms Wall time: 10.3 s >>> %time _ = df3.collect() CPU times: user 5.09 ms, sys: 410 µs, total: 5.5 ms Wall time: 148 ms >>> sc.setCheckpointDir(".") >>> %time df3 = df2.checkpoint() CPU times: user 4.04 ms, sys: 1.63 ms, total: 5.67 ms Wall time: 20.3 s ``` Author: Fernando Pereira Closes #19805 from ferdonline/feature_dataset_localCheckpoint. --- python/pyspark/sql/dataframe.py | 14 +++ .../scala/org/apache/spark/sql/Dataset.scala | 49 +++++++- .../org/apache/spark/sql/DatasetSuite.scala | 107 ++++++++++-------- 3 files changed, 121 insertions(+), 49 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9864dc98c1f33..75395a754a831 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -368,6 +368,20 @@ def checkpoint(self, eager=True): jdf = self._jdf.checkpoint(eager) return DataFrame(jdf, self.sql_ctx) + @since(2.3) + def localCheckpoint(self, eager=True): + """Returns a locally checkpointed version of this Dataset. Checkpointing can be used to + truncate the logical plan of this DataFrame, which is especially useful in iterative + algorithms where the plan may grow exponentially. Local checkpoints are stored in the + executors using the caching subsystem and therefore they are not reliable. + + :param eager: Whether to checkpoint this DataFrame immediately + + .. note:: Experimental + """ + jdf = self._jdf.localCheckpoint(eager) + return DataFrame(jdf, self.sql_ctx) + @since(2.1) def withWatermark(self, eventTime, delayThreshold): """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c34cf0a7a7718..ef00562672a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -527,7 +527,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def checkpoint(): Dataset[T] = checkpoint(eager = true) + def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) /** * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the @@ -540,9 +540,52 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def checkpoint(eager: Boolean): Dataset[T] = { + def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true) + + /** + * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be + * used to truncate the logical plan of this Dataset, which is especially useful in iterative + * algorithms where the plan may grow exponentially. Local checkpoints are written to executor + * storage and despite potentially faster they are unreliable and may compromise job completion. 
+ * + * @group basic + * @since 2.3.0 + */ + @Experimental + @InterfaceStability.Evolving + def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) + + /** + * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to truncate + * the logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. Local checkpoints are written to executor storage and despite + * potentially faster they are unreliable and may compromise job completion. + * + * @group basic + * @since 2.3.0 + */ + @Experimental + @InterfaceStability.Evolving + def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint( + eager = eager, + reliableCheckpoint = false + ) + + /** + * Returns a checkpointed version of this Dataset. + * + * @param eager Whether to checkpoint this dataframe immediately + * @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the + * checkpoint directory. If false creates a local checkpoint using + * the caching subsystem + */ + private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { val internalRdd = queryExecution.toRdd.map(_.copy()) - internalRdd.checkpoint() + if (reliableCheckpoint) { + internalRdd.checkpoint() + } else { + internalRdd.localCheckpoint() + } if (eager) { internalRdd.count() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b02db7721aa7f..bd1e7adefc7a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1156,67 +1156,82 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } Seq(true, false).foreach { eager => - def testCheckpointing(testName: String)(f: => Unit): Unit = { - test(s"Dataset.checkpoint() - $testName (eager = $eager)") { - withTempDir { dir => - val originalCheckpointDir = spark.sparkContext.checkpointDir - - try { - spark.sparkContext.setCheckpointDir(dir.getCanonicalPath) + Seq(true, false).foreach { reliable => + def testCheckpointing(testName: String)(f: => Unit): Unit = { + test(s"Dataset.checkpoint() - $testName (eager = $eager, reliable = $reliable)") { + if (reliable) { + withTempDir { dir => + val originalCheckpointDir = spark.sparkContext.checkpointDir + + try { + spark.sparkContext.setCheckpointDir(dir.getCanonicalPath) + f + } finally { + // Since the original checkpointDir can be None, we need + // to set the variable directly. + spark.sparkContext.checkpointDir = originalCheckpointDir + } + } + } else { + // Local checkpoints dont require checkpoint_dir f - } finally { - // Since the original checkpointDir can be None, we need - // to set the variable directly. 
- spark.sparkContext.checkpointDir = originalCheckpointDir } } } - } - testCheckpointing("basic") { - val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc) - val cp = ds.checkpoint(eager) + testCheckpointing("basic") { + val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc) + val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager) - val logicalRDD = cp.logicalPlan match { - case plan: LogicalRDD => plan - case _ => - val treeString = cp.logicalPlan.treeString(verbose = true) - fail(s"Expecting a LogicalRDD, but got\n$treeString") - } + val logicalRDD = cp.logicalPlan match { + case plan: LogicalRDD => plan + case _ => + val treeString = cp.logicalPlan.treeString(verbose = true) + fail(s"Expecting a LogicalRDD, but got\n$treeString") + } - val dsPhysicalPlan = ds.queryExecution.executedPlan - val cpPhysicalPlan = cp.queryExecution.executedPlan + val dsPhysicalPlan = ds.queryExecution.executedPlan + val cpPhysicalPlan = cp.queryExecution.executedPlan - assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning } - assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering } + assertResult(dsPhysicalPlan.outputPartitioning) { + logicalRDD.outputPartitioning + } + assertResult(dsPhysicalPlan.outputOrdering) { + logicalRDD.outputOrdering + } - assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning } - assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering } + assertResult(dsPhysicalPlan.outputPartitioning) { + cpPhysicalPlan.outputPartitioning + } + assertResult(dsPhysicalPlan.outputOrdering) { + cpPhysicalPlan.outputOrdering + } - // For a lazy checkpoint() call, the first check also materializes the checkpoint. - checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) + // For a lazy checkpoint() call, the first check also materializes the checkpoint. + checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) - // Reads back from checkpointed data and check again. - checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) - } + // Reads back from checkpointed data and check again. 
+ checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) + } - testCheckpointing("should preserve partitioning information") { - val ds = spark.range(10).repartition('id % 2) - val cp = ds.checkpoint(eager) + testCheckpointing("should preserve partitioning information") { + val ds = spark.range(10).repartition('id % 2) + val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager) - val agg = cp.groupBy('id % 2).agg(count('id)) + val agg = cp.groupBy('id % 2).agg(count('id)) - agg.queryExecution.executedPlan.collectFirst { - case ShuffleExchangeExec(_, _: RDDScanExec, _) => - case BroadcastExchangeExec(_, _: RDDScanExec) => - }.foreach { _ => - fail( - "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " + - "preserves partitioning information:\n\n" + agg.queryExecution - ) - } + agg.queryExecution.executedPlan.collectFirst { + case ShuffleExchangeExec(_, _: RDDScanExec, _) => + case BroadcastExchangeExec(_, _: RDDScanExec) => + }.foreach { _ => + fail( + "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " + + "preserves partitioning information:\n\n" + agg.queryExecution + ) + } - checkAnswer(agg, ds.groupBy('id % 2).agg(count('id))) + checkAnswer(agg, ds.groupBy('id % 2).agg(count('id))) + } } } From 9962390af77d8085ce2a475b5983766c9c307031 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 19 Dec 2017 23:50:06 -0800 Subject: [PATCH 460/712] [SPARK-22781][SS] Support creating streaming dataset with ORC files ## What changes were proposed in this pull request? Like `Parquet`, users can use `ORC` with Apache Spark structured streaming. This PR adds `orc()` to `DataStreamReader`(Scala/Python) in order to support creating streaming dataset with ORC file format more easily like the other file formats. Also, this adds a test coverage for ORC data source and updates the document. **BEFORE** ```scala scala> spark.readStream.schema("a int").orc("/tmp/orc_ss").writeStream.format("console").start() :24: error: value orc is not a member of org.apache.spark.sql.streaming.DataStreamReader spark.readStream.schema("a int").orc("/tmp/orc_ss").writeStream.format("console").start() ``` **AFTER** ```scala scala> spark.readStream.schema("a int").orc("/tmp/orc_ss").writeStream.format("console").start() res0: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.execution.streaming.StreamingQueryWrapper678b3746 scala> ------------------------------------------- Batch: 0 ------------------------------------------- +---+ | a| +---+ | 1| +---+ ``` ## How was this patch tested? Pass the newly added test cases. Author: Dongjoon Hyun Closes #19975 from dongjoon-hyun/SPARK-22781. --- .../structured-streaming-programming-guide.md | 2 +- python/pyspark/sql/streaming.py | 17 +++ .../sql/streaming/DataStreamReader.scala | 15 +++ .../sql/streaming/FileStreamSinkSuite.scala | 4 + .../sql/streaming/FileStreamSourceSuite.scala | 111 ++++++++++++++++++ 5 files changed, 148 insertions(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 93bef8d5bb7e2..31fcfabb9cacc 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -493,7 +493,7 @@ returned by `SparkSession.readStream()`. In [R](api/R/read.stream.html), with th #### Input Sources There are a few built-in sources. - - **File source** - Reads files written in a directory as a stream of data. 
Supported file formats are text, csv, json, parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. Note that the files must be atomically placed in the given directory, which in most file systems, can be achieved by file move operations. + - **File source** - Reads files written in a directory as a stream of data. Supported file formats are text, csv, json, orc, parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. Note that the files must be atomically placed in the given directory, which in most file systems, can be achieved by file move operations. - **Kafka source** - Reads data from Kafka. It's compatible with Kafka broker versions 0.10.0 or higher. See the [Kafka Integration Guide](structured-streaming-kafka-integration.html) for more details. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 0cf702143c773..d0aba28788ac9 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -490,6 +490,23 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, else: raise TypeError("path can be only a single string") + @since(2.3) + def orc(self, path): + """Loads a ORC file stream, returning the result as a :class:`DataFrame`. + + .. note:: Evolving. + + >>> orc_sdf = spark.readStream.schema(sdf_schema).orc(tempfile.mkdtemp()) + >>> orc_sdf.isStreaming + True + >>> orc_sdf.schema == sdf_schema + True + """ + if isinstance(path, basestring): + return self._df(self._jreader.orc(path)) + else: + raise TypeError("path can be only a single string") + @since(2.0) def parquet(self, path): """Loads a Parquet file stream, returning the result as a :class:`DataFrame`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index a42e28053a96a..41aa02c2b5e35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -298,6 +298,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo */ def csv(path: String): DataFrame = format("csv").load(path) + /** + * Loads a ORC file stream, returning the result as a `DataFrame`. + * + * You can set the following ORC-specific option(s) for reading ORC files: + *
+   * <ul>
+   * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+   * considered in every trigger.</li>
+   * </ul>
    + * + * @since 2.3.0 + */ + def orc(path: String): DataFrame = { + format("orc").load(path) + } + /** * Loads a Parquet file stream, returning the result as a `DataFrame`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 08db06b94904b..2a2552211857a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -305,6 +305,10 @@ class FileStreamSinkSuite extends StreamTest { testFormat(Some("parquet")) } + test("orc") { + testFormat(Some("orc")) + } + test("text") { testFormat(Some("text")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index c5b57bca18313..f4fa7fa7954d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -87,6 +87,28 @@ abstract class FileStreamSourceTest } } + case class AddOrcFileData(data: DataFrame, src: File, tmp: File) extends AddFileData { + override def addData(source: FileStreamSource): Unit = { + AddOrcFileData.writeToFile(data, src, tmp) + } + } + + object AddOrcFileData { + def apply(seq: Seq[String], src: File, tmp: File): AddOrcFileData = { + AddOrcFileData(seq.toDS().toDF(), src, tmp) + } + + /** Write orc files in a temp dir, and move the individual files to the 'src' dir */ + def writeToFile(df: DataFrame, src: File, tmp: File): Unit = { + val tmpDir = Utils.tempFileWith(new File(tmp, "orc")) + df.write.orc(tmpDir.getCanonicalPath) + src.mkdirs() + tmpDir.listFiles().foreach { f => + f.renameTo(new File(src, s"${f.getName}")) + } + } + } + case class AddParquetFileData(data: DataFrame, src: File, tmp: File) extends AddFileData { override def addData(source: FileStreamSource): Unit = { AddParquetFileData.writeToFile(data, src, tmp) @@ -249,6 +271,42 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + // =============== ORC file stream schema tests ================ + + test("FileStreamSource schema: orc, existing files, no schema") { + withTempDir { src => + Seq("a", "b", "c").toDS().as("userColumn").toDF().write + .mode(org.apache.spark.sql.SaveMode.Overwrite) + .orc(src.getCanonicalPath) + + // Without schema inference, should throw error + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "false") { + intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema( + format = Some("orc"), path = Some(src.getCanonicalPath), schema = None) + } + } + + // With schema inference, should infer correct schema + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + val schema = createFileStreamSourceAndGetSchema( + format = Some("orc"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } + } + } + + test("FileStreamSource schema: orc, existing files, schema") { + withTempPath { src => + Seq("a", "b", "c").toDS().as("oldUserColumn").toDF() + .write.orc(new File(src, "1").getCanonicalPath) + val userSchema = new StructType().add("userColumn", StringType) + val schema = createFileStreamSourceAndGetSchema( + format = Some("orc"), path = Some(src.getCanonicalPath), schema = Some(userSchema)) + assert(schema === userSchema) + } + } + // 
=============== Parquet file stream schema tests ================ test("FileStreamSource schema: parquet, existing files, no schema") { @@ -508,6 +566,59 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + // =============== ORC file stream tests ================ + + test("read from orc files") { + withTempDirs { case (src, tmp) => + val fileStream = createFileStream("orc", src.getCanonicalPath, Some(valueSchema)) + val filtered = fileStream.filter($"value" contains "keep") + + testStream(filtered)( + AddOrcFileData(Seq("drop1", "keep2", "keep3"), src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddOrcFileData(Seq("drop4", "keep5", "keep6"), src, tmp), + StartStream(), + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddOrcFileData(Seq("drop7", "keep8", "keep9"), src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + } + } + + test("read from orc files with changing schema") { + withTempDirs { case (src, tmp) => + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + + // Add a file so that we can infer its schema + AddOrcFileData.writeToFile(Seq("value0").toDF("k"), src, tmp) + + val fileStream = createFileStream("orc", src.getCanonicalPath) + + // FileStreamSource should infer the column "k" + assert(fileStream.schema === StructType(Seq(StructField("k", StringType)))) + + // After creating DF and before starting stream, add data with different schema + // Should not affect the inferred schema any more + AddOrcFileData.writeToFile(Seq(("value1", 0)).toDF("k", "v"), src, tmp) + + testStream(fileStream)( + // Should not pick up column v in the file added before start + AddOrcFileData(Seq("value2").toDF("k"), src, tmp), + CheckAnswer("value0", "value1", "value2"), + + // Should read data in column k, and ignore v + AddOrcFileData(Seq(("value3", 1)).toDF("k", "v"), src, tmp), + CheckAnswer("value0", "value1", "value2", "value3"), + + // Should ignore rows that do not have the necessary k column + AddOrcFileData(Seq("value5").toDF("v"), src, tmp), + CheckAnswer("value0", "value1", "value2", "value3", null) + ) + } + } + } + // =============== Parquet file stream tests ================ test("read from parquet files") { From 7570eab6bee57172ee3746207261307690a57b72 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 20 Dec 2017 11:31:11 -0600 Subject: [PATCH 461/712] [SPARK-22788][STREAMING] Use correct hadoop config for fs append support. Still look at the old one in case any Spark user is setting it explicitly, though. Author: Marcelo Vanzin Closes #19983 from vanzin/SPARK-22788. 
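
The essence of the change is that the correct Hadoop key `dfs.support.append` is now consulted first (defaulting to true), while the legacy Spark-specific key `hdfs.append.support` is still honored for users who set it explicitly. A minimal sketch of the resulting check (function name simplified for illustration, not the exact Spark code; see the diff below) is:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, RawLocalFileSystem}

// Sketch of the append-support decision after this patch:
// prefer the proper HDFS key (default true), fall back to the legacy
// Spark key, and always allow append on the raw local file system.
def supportsAppend(conf: Configuration, fs: FileSystem): Boolean = {
  conf.getBoolean("dfs.support.append", true) ||
    conf.getBoolean("hdfs.append.support", false) ||
    fs.isInstanceOf[RawLocalFileSystem]
}
```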
--- .../scala/org/apache/spark/streaming/util/HdfsUtils.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 6a3b3200dccdb..a6997359d64d2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -29,7 +29,9 @@ private[streaming] object HdfsUtils { // If the file exists and we have append support, append instead of creating a new file val stream: FSDataOutputStream = { if (dfs.isFile(dfsPath)) { - if (conf.getBoolean("hdfs.append.support", false) || dfs.isInstanceOf[RawLocalFileSystem]) { + if (conf.getBoolean("dfs.support.append", true) || + conf.getBoolean("hdfs.append.support", false) || + dfs.isInstanceOf[RawLocalFileSystem]) { dfs.append(dfsPath) } else { throw new IllegalStateException("File exists and there is no append support!") From 7798c9e6ef55dbadfc9eb896fa3f366c76dc187b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 20 Dec 2017 10:43:10 -0800 Subject: [PATCH 462/712] [SPARK-22824] Restore old offset for binary compatibility ## What changes were proposed in this pull request? Some users depend on source compatibility with the org.apache.spark.sql.execution.streaming.Offset class. Although this is not a stable interface, we can keep it in place for now to simplify upgrades to 2.3. Author: Jose Torres Closes #20012 from joseph-torres/binary-compat. --- .../spark/sql/kafka010/KafkaSource.scala | 1 - .../sql/kafka010/KafkaSourceOffset.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 1 - .../spark/sql/sources/v2/reader/Offset.java | 6 +- .../streaming/FileStreamSource.scala | 1 - .../streaming/FileStreamSourceOffset.scala | 2 - .../sql/execution/streaming/LongOffset.scala | 2 - .../streaming/MicroBatchExecution.scala | 1 - .../spark/sql/execution/streaming/Offset.java | 61 +++++++++++++++++++ .../sql/execution/streaming/OffsetSeq.scala | 1 - .../execution/streaming/OffsetSeqLog.scala | 1 - .../streaming/RateSourceProvider.scala | 3 +- .../streaming/RateStreamOffset.scala | 4 +- .../{Offset.scala => SerializedOffset.scala} | 3 - .../sql/execution/streaming/Source.scala | 1 - .../execution/streaming/StreamExecution.scala | 1 - .../execution/streaming/StreamProgress.scala | 2 - .../ContinuousRateStreamSource.scala | 3 +- .../sql/execution/streaming/memory.scala | 1 - .../sql/execution/streaming/socket.scala | 1 - .../{ => sources}/RateStreamSourceV2.scala | 7 ++- .../streaming/{ => sources}/memoryV2.scala | 3 +- .../streaming/MemorySinkV2Suite.scala | 1 + .../streaming/RateSourceV2Suite.scala | 1 + .../sql/streaming/FileStreamSourceSuite.scala | 1 - .../spark/sql/streaming/OffsetSuite.scala | 3 +- .../spark/sql/streaming/StreamSuite.scala | 1 - .../spark/sql/streaming/StreamTest.scala | 1 - .../streaming/StreamingAggregationSuite.scala | 1 - .../StreamingQueryListenerSuite.scala | 1 - .../sql/streaming/StreamingQuerySuite.scala | 2 - .../test/DataStreamReaderWriterSuite.scala | 1 - .../sql/streaming/util/BlockingSource.scala | 3 +- 33 files changed, 82 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{Offset.scala => SerializedOffset.scala} (95%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{ => 
sources}/RateStreamSourceV2.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{ => sources}/memoryV2.scala (98%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 87f31fcc20ae6..e9cff04ba5f2e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index 6e24423df4abc..b5da415b3097e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition -import org.apache.spark.sql.execution.streaming.SerializedOffset -import org.apache.spark.sql.sources.v2.reader.Offset +import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 9cac0e5ae7117..2034b9be07f24 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -38,7 +38,6 @@ import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java index 1ebd35356f1a3..ce1c489742054 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java @@ -23,7 +23,7 @@ * restart checkpoints. Sources should provide an Offset implementation which they can use to * reconstruct the stream position where the offset was taken. */ -public abstract class Offset { +public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is * used for saving offsets to the offset log. 
@@ -41,8 +41,8 @@ public abstract class Offset { */ @Override public boolean equals(Object obj) { - if (obj instanceof Offset) { - return this.json().equals(((Offset) obj).json()); + if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) { + return this.json().equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); } else { return false; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index a33b785126765..0debd7db84757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -27,7 +27,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala index 431e5b99e3e98..a2b49d944a688 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala @@ -22,8 +22,6 @@ import scala.util.control.Exception._ import org.json4s.NoTypeHints import org.json4s.jackson.Serialization -import org.apache.spark.sql.sources.v2.reader.Offset - /** * Offset for the [[FileStreamSource]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 7ea31462ca7b0..5f0b195fcfcb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.sources.v2.reader.Offset - /** * A simple offset for sources that produce a single linear stream of data. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index a67dda99dc01b..4a3de8bae4bc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java new file mode 100644 index 0000000000000..80aa5505db991 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming; + +/** + * This is an internal, deprecated interface. New source implementations should use the + * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported + * in the long term. + * + * This class will be removed in a future release. + */ +public abstract class Offset { + /** + * A JSON-serialized representation of an Offset that is + * used for saving offsets to the offset log. + * Note: We assume that equivalent/equal offsets serialize to + * identical JSON strings. + * + * @return JSON string encoding + */ + public abstract String json(); + + /** + * Equality based on JSON string representation. We leverage the + * JSON representation for normalization between the Offset's + * in memory and on disk representations. 
+ */ + @Override + public boolean equals(Object obj) { + if (obj instanceof Offset) { + return this.json().equals(((Offset) obj).json()); + } else { + return false; + } + } + + @Override + public int hashCode() { + return this.json().hashCode(); + } + + @Override + public String toString() { + return this.json(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index dcc5935890c8d..4e0a468b962a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -23,7 +23,6 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} -import org.apache.spark.sql.sources.v2.reader.Offset /** * An ordered collection of offsets, used to track the progress of processing data from one or more diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index bfdbc65296165..e3f4abcf9f1dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.Offset /** * This class is used to log offsets to persistent files in HDFS. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 50671a46599e6..41761324cf6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -30,9 +30,10 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader +import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, MicroBatchReader} import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 13679dfbe446b..726d8574af52b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.streaming import org.json4s.DefaultFormats import org.json4s.jackson.Serialization -import 
org.apache.spark.sql.sources.v2.reader.Offset +import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, (Long, Long)]) - extends Offset { + extends v2.reader.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SerializedOffset.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SerializedOffset.scala index 73f0c6221c5c1..129cfed860eb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SerializedOffset.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.sources.v2.reader.Offset - - /** * Used when loading a JSON serialized offset from external storage. * We are currently not responsible for converting JSON serialized diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index dbb408ffc98d8..311942f6dbd84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 7946889e85e37..129995dcf3607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 770db401c9fd7..a3f3662e6f4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.streaming import scala.collection.{immutable, GenTraversableOnce} -import org.apache.spark.sql.sources.v2.reader.Offset - /** * A helper class that looks like a Map[Source, Offset]. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 77fc26730e52c..4c3a1ee201ac1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -25,7 +25,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index db0717510a2cb..3041d4d703cb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index 440cae016a173..0b22cbc46e6bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -31,7 +31,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamSourceV2.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 102551c238bfb..45dc7d75cbc8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.util.Optional @@ -27,6 +27,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming.RateStreamOffset import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} @@ -59,7 +60,9 @@ class RateStreamV2Reader(options: DataSourceV2Options) private var start: RateStreamOffset = _ private var end: RateStreamOffset = _ - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + override def setOffsetRange( + start: Optional[Offset], + end: Optional[Offset]): Unit = { this.start = start.orElse( RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) .asInstanceOf[RateStreamOffset] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memoryV2.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 437040cc12472..94c5dd63089b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import javax.annotation.concurrent.GuardedBy @@ -26,6 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} +import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.streaming.OutputMode diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index be4b490754986..00d4f0b8503d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index ef801ceb1310c..6514c5f0fdfeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ 
import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index f4fa7fa7954d6..39bb572740617 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala index 429748261f1ea..f208f9bd9b6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.{LongOffset, SerializedOffset} -import org.apache.spark.sql.sources.v2.reader.Offset +import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, SerializedOffset} trait OffsetSuite extends SparkFunSuite { /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index fa4b2dd6a6c9b..755490308b5b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreCon import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index fb88c5d327043..71a474ef63e84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 38aa5171314f2..97e065193fd05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index fc9ac2a56c4e5..9ff02dee288fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala 
index ad4d3abd01aa5..2fa4595dab376 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -33,12 +33,10 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType import org.apache.spark.util.ManualClock - class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { import AwaitTerminationTester._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 952908f21ca60..aa163d2211c38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{ProcessingTime => DeprecatedProcessingTime, _} import org.apache.spark.sql.streaming.Trigger._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala index 9a35f097e6e40..19ab2ff13e14e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.streaming.util import java.util.concurrent.CountDownLatch import org.apache.spark.sql.{SQLContext, _} -import org.apache.spark.sql.execution.streaming.{LongOffset, Sink, Source} +import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source} import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, StructField, StructType} From d762d110d410b8d67d74eb4d2950cc556ac74123 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 20 Dec 2017 12:50:03 -0600 Subject: [PATCH 463/712] [SPARK-22832][ML] BisectingKMeans unpersist unused datasets ## What changes were proposed in this pull request? unpersist unused datasets ## How was this patch tested? existing tests and local check in Spark-Shell Author: Zheng RuiFeng Closes #20017 from zhengruifeng/bkm_unpersist. 
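Editor's note (not part of the original description): the fix follows the usual cache-then-release cycle on intermediate RDDs. A minimal, hypothetical sketch of that pattern — the helper name and shape are illustrative only, not code from this patch:

```scala
// Illustrative sketch: cache the new intermediate RDD, materialize it, then release
// the superseded one without blocking the driver (unpersist(false)).
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

def swapCached[T](old: RDD[T], updated: RDD[T]): RDD[T] = {
  val cached = updated.persist(StorageLevel.MEMORY_AND_DISK)
  cached.count()            // force materialization before dropping the old copy
  if (old != null) {
    old.unpersist(false)    // non-blocking release of the dataset that is no longer needed
  }
  cached
}
```

The patch applies exactly this release step to `preIndices`, `indices` and `norms` once they are no longer used.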
--- .../spark/mllib/clustering/BisectingKMeans.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index ae98e24a75681..9b9c70cfe5109 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -197,7 +197,9 @@ class BisectingKMeans private ( newClusters = summarize(d, newAssignments) newClusterCenters = newClusters.mapValues(_.center).map(identity) } - if (preIndices != null) preIndices.unpersist() + if (preIndices != null) { + preIndices.unpersist(false) + } preIndices = indices indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys .persist(StorageLevel.MEMORY_AND_DISK) @@ -212,7 +214,13 @@ class BisectingKMeans private ( } level += 1 } - if(indices != null) indices.unpersist() + if (preIndices != null) { + preIndices.unpersist(false) + } + if (indices != null) { + indices.unpersist(false) + } + norms.unpersist(false) val clusters = activeClusters ++ inactiveClusters val root = buildTree(clusters) new BisectingKMeansModel(root) From c89b43118347677f122db190c9033394c15cee30 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 20 Dec 2017 11:19:57 -0800 Subject: [PATCH 464/712] [SPARK-22849] ivy.retrieve pattern should also consider `classifier` ## What changes were proposed in this pull request? In the previous PR https://github.com/apache/spark/pull/5755#discussion_r157848354, we dropped `(-[classifier])` from the retrieval pattern. We should add it back; otherwise, > If this pattern for instance doesn't has the [type] or [classifier] token, Ivy will download the source/javadoc artifacts to the same file as the regular jar. ## How was this patch tested? The existing tests Author: gatorsmile Closes #20037 from gatorsmile/addClassifier. --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index ab834bb682041..cbe1f2c3e08a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -1271,7 +1271,7 @@ private[spark] object SparkSubmitUtils { // retrieve all resolved dependencies ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision].[ext]", + "[organization]_[artifact]-[revision](-[classifier]).[ext]", retrieveOptions.setConfs(Array(ivyConfName))) resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } finally { From 792915c8449b606cfdd50401fb349194a2558c36 Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Wed, 20 Dec 2017 14:47:49 -0600 Subject: [PATCH 465/712] [SPARK-22830] Scala Coding style has been improved in Spark Examples ## What changes were proposed in this pull request? * Under Spark Scala Examples: Some of the syntax were written like Java way, It has been re-written as per scala style guide. * Most of all changes are followed to println() statement. ## How was this patch tested? Since, All changes proposed are re-writing println statements in scala way, manual run used to test println. 
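Editor's note (illustrative, not taken verbatim from the diff): the rewrite applied throughout the examples replaces Java-style string concatenation with Scala string interpolation, e.g.:

```scala
// Self-contained before/after of the style change (values are made up).
val i = 1
val (start, end) = (0L, 42L)
println("Iteration " + i + " took " + (end - start) + " ms")   // Java-style concatenation
println(s"Iteration $i took ${end - start} ms")                // preferred Scala interpolation
```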
Author: chetkhatri Closes #20016 from chetkhatri/scala-style-spark-examples. --- .../apache/spark/examples/BroadcastTest.scala | 2 +- .../spark/examples/DFSReadWriteTest.scala | 24 ++++++------ .../org/apache/spark/examples/HdfsTest.scala | 2 +- .../org/apache/spark/examples/LocalALS.scala | 3 +- .../apache/spark/examples/LocalFileLR.scala | 6 +-- .../apache/spark/examples/LocalKMeans.scala | 4 +- .../org/apache/spark/examples/LocalLR.scala | 6 +-- .../org/apache/spark/examples/LocalPi.scala | 2 +- .../examples/SimpleSkewedGroupByTest.scala | 2 +- .../org/apache/spark/examples/SparkALS.scala | 4 +- .../apache/spark/examples/SparkHdfsLR.scala | 6 +-- .../apache/spark/examples/SparkKMeans.scala | 2 +- .../org/apache/spark/examples/SparkLR.scala | 6 +-- .../apache/spark/examples/SparkPageRank.scala | 2 +- .../org/apache/spark/examples/SparkPi.scala | 2 +- .../org/apache/spark/examples/SparkTC.scala | 2 +- .../spark/examples/graphx/Analytics.scala | 37 ++++++++++--------- 17 files changed, 54 insertions(+), 58 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 25718f904cc49..3311de12dbd97 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -42,7 +42,7 @@ object BroadcastTest { val arr1 = (0 until num).toArray for (i <- 0 until 3) { - println("Iteration " + i) + println(s"Iteration $i") println("===========") val startTime = System.nanoTime val barr1 = sc.broadcast(arr1) diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 3bff7ce736d08..1a779716ec4c0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -49,12 +49,10 @@ object DFSReadWriteTest { } private def printUsage(): Unit = { - val usage: String = "DFS Read-Write Test\n" + - "\n" + - "Usage: localFile dfsDir\n" + - "\n" + - "localFile - (string) local file to use in test\n" + - "dfsDir - (string) DFS directory for read/write tests\n" + val usage = """DFS Read-Write Test + |Usage: localFile dfsDir + |localFile - (string) local file to use in test + |dfsDir - (string) DFS directory for read/write tests""".stripMargin println(usage) } @@ -69,13 +67,13 @@ object DFSReadWriteTest { localFilePath = new File(args(i)) if (!localFilePath.exists) { - System.err.println("Given path (" + args(i) + ") does not exist.\n") + System.err.println(s"Given path (${args(i)}) does not exist") printUsage() System.exit(1) } if (!localFilePath.isFile) { - System.err.println("Given path (" + args(i) + ") is not a file.\n") + System.err.println(s"Given path (${args(i)}) is not a file") printUsage() System.exit(1) } @@ -108,7 +106,7 @@ object DFSReadWriteTest { .getOrCreate() println("Writing local file to DFS") - val dfsFilename = dfsDirPath + "/dfs_read_write_test" + val dfsFilename = s"$dfsDirPath/dfs_read_write_test" val fileRDD = spark.sparkContext.parallelize(fileContents) fileRDD.saveAsTextFile(dfsFilename) @@ -127,11 +125,11 @@ object DFSReadWriteTest { spark.stop() if (localWordCount == dfsWordCount) { - println(s"Success! Local Word Count ($localWordCount) " + - s"and DFS Word Count ($dfsWordCount) agree.") + println(s"Success! 
Local Word Count $localWordCount and " + + s"DFS Word Count $dfsWordCount agree.") } else { - println(s"Failure! Local Word Count ($localWordCount) " + - s"and DFS Word Count ($dfsWordCount) disagree.") + println(s"Failure! Local Word Count $localWordCount " + + s"and DFS Word Count $dfsWordCount disagree.") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index aa8de69839e28..e1f985ece8c06 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -39,7 +39,7 @@ object HdfsTest { val start = System.currentTimeMillis() for (x <- mapped) { x + 2 } val end = System.currentTimeMillis() - println("Iteration " + iter + " took " + (end-start) + " ms") + println(s"Iteration $iter took ${end-start} ms") } spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 97aefac025e55..3f9cea35d6503 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -129,8 +129,7 @@ object LocalALS { println(s"Iteration $iter:") ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray - println("RMSE = " + rmse(R, ms, us)) - println() + println(s"RMSE = ${rmse(R, ms, us)}") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 8dbb7ee4e5307..5512e33e41ac3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -58,10 +58,10 @@ object LocalFileLR { // Initialize w to a random value val w = DenseVector.fill(D) {2 * rand.nextDouble - 1} - println("Initial w: " + w) + println(s"Initial w: $w") for (i <- 1 to ITERATIONS) { - println("On iteration " + i) + println(s"On iteration $i") val gradient = DenseVector.zeros[Double](D) for (p <- points) { val scale = (1 / (1 + math.exp(-p.y * (w.dot(p.x)))) - 1) * p.y @@ -71,7 +71,7 @@ object LocalFileLR { } fileSrc.close() - println("Final w: " + w) + println(s"Final w: $w") } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 963c9a56d6cac..f5162a59522f0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -88,7 +88,7 @@ object LocalKMeans { kPoints.put(i, iter.next()) } - println("Initial centers: " + kPoints) + println(s"Initial centers: $kPoints") while(tempDist > convergeDist) { val closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) @@ -114,7 +114,7 @@ object LocalKMeans { } } - println("Final centers: " + kPoints) + println(s"Final centers: $kPoints") } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index eb5221f085937..bde8ccd305960 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -61,10 +61,10 @@ object LocalLR { val data = generateData // 
Initialize w to a random value val w = DenseVector.fill(D) {2 * rand.nextDouble - 1} - println("Initial w: " + w) + println(s"Initial w: $w") for (i <- 1 to ITERATIONS) { - println("On iteration " + i) + println(s"On iteration $i") val gradient = DenseVector.zeros[Double](D) for (p <- data) { val scale = (1 / (1 + math.exp(-p.y * (w.dot(p.x)))) - 1) * p.y @@ -73,7 +73,7 @@ object LocalLR { w -= gradient } - println("Final w: " + w) + println(s"Final w: $w") } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index 121b768e4198e..a93c15c85cfc1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -28,7 +28,7 @@ object LocalPi { val y = random * 2 - 1 if (x*x + y*y <= 1) count += 1 } - println("Pi is roughly " + 4 * count / 100000.0) + println(s"Pi is roughly ${4 * count / 100000.0}") } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 8e1a574c92221..e64dcbd182d94 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -59,7 +59,7 @@ object SimpleSkewedGroupByTest { // Enforce that everything has been calculated and in cache pairs1.count - println("RESULT: " + pairs1.groupByKey(numReducers).count) + println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") // Print how many keys each reducer got (for debugging) // println("RESULT: " + pairs1.groupByKey(numReducers) // .map{case (k,v) => (k, v.size)} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index a99ddd9fd37db..d3e7b7a967de7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -135,10 +135,8 @@ object SparkALS { .map(i => update(i, usb.value(i), msb.value, Rc.value.transpose())) .collect() usb = sc.broadcast(us) // Re-broadcast us because it was updated - println("RMSE = " + rmse(R, ms, us)) - println() + println(s"RMSE = ${rmse(R, ms, us)}") } - spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9d675bbc18f38..23eaa879114a9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -79,17 +79,17 @@ object SparkHdfsLR { // Initialize w to a random value val w = DenseVector.fill(D) {2 * rand.nextDouble - 1} - println("Initial w: " + w) + println(s"Initial w: $w") for (i <- 1 to ITERATIONS) { - println("On iteration " + i) + println(s"On iteration $i") val gradient = points.map { p => p.x * (1 / (1 + exp(-p.y * (w.dot(p.x)))) - 1) * p.y }.reduce(_ + _) w -= gradient } - println("Final w: " + w) + println(s"Final w: $w") spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index fec3160e9f37b..b005cb6971c16 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -95,7 +95,7 @@ object SparkKMeans { for (newP <- newPoints) { kPoints(newP._1) = newP._2 } - println("Finished iteration (delta = " + tempDist + ")") + println(s"Finished iteration (delta = $tempDist)") } println("Final centers:") diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index c18e3d31f149e..4b1497345af82 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -73,17 +73,17 @@ object SparkLR { // Initialize w to a random value val w = DenseVector.fill(D) {2 * rand.nextDouble - 1} - println("Initial w: " + w) + println(s"Initial w: $w") for (i <- 1 to ITERATIONS) { - println("On iteration " + i) + println(s"On iteration $i") val gradient = points.map { p => p.x * (1 / (1 + exp(-p.y * (w.dot(p.x)))) - 1) * p.y }.reduce(_ + _) w -= gradient } - println("Final w: " + w) + println(s"Final w: $w") spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 5d8831265e4ad..9299bad5d3290 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -77,7 +77,7 @@ object SparkPageRank { } val output = ranks.collect() - output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) + output.foreach(tup => println(s"${tup._1} has rank: ${tup._2} .")) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index a5cacf17a5cca..828d98b5001d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -36,7 +36,7 @@ object SparkPi { val y = random * 2 - 1 if (x*x + y*y <= 1) 1 else 0 }.reduce(_ + _) - println("Pi is roughly " + 4.0 * count / (n - 1)) + println(s"Pi is roughly ${4.0 * count / (n - 1)}") spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 558295ab928af..f5d42141f5dd2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -68,7 +68,7 @@ object SparkTC { nextCount = tc.count() } while (nextCount != oldCount) - println("TC has " + tc.count() + " edges.") + println(s"TC has ${tc.count()} edges.") spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 619e585b6ca17..92936bd30dbc0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -27,6 +27,7 @@ import org.apache.spark.graphx.lib._ import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel + /** * Driver program for running graph algorithms. 
*/ @@ -34,12 +35,12 @@ object Analytics extends Logging { def main(args: Array[String]): Unit = { if (args.length < 2) { - System.err.println( - "Usage: Analytics --numEPart= [other options]") - System.err.println("Supported 'taskType' as follows:") - System.err.println(" pagerank Compute PageRank") - System.err.println(" cc Compute the connected components of vertices") - System.err.println(" triangles Count the number of triangles") + val usage = """Usage: Analytics --numEPart= + |[other options] Supported 'taskType' as follows: + |pagerank Compute PageRank + |cc Compute the connected components of vertices + |triangles Count the number of triangles""".stripMargin + System.err.println(usage) System.exit(1) } @@ -48,7 +49,7 @@ object Analytics extends Logging { val optionsList = args.drop(2).map { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } val options = mutable.Map(optionsList: _*) @@ -74,14 +75,14 @@ object Analytics extends Logging { val numIterOpt = options.remove("numIter").map(_.toInt) options.foreach { - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } println("======================================") println("| PageRank |") println("======================================") - val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")")) + val sc = new SparkContext(conf.setAppName(s"PageRank($fname)")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, numEdgePartitions = numEPart, @@ -89,18 +90,18 @@ object Analytics extends Logging { vertexStorageLevel = vertexStorageLevel).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) - println("GRAPHX: Number of vertices " + graph.vertices.count) - println("GRAPHX: Number of edges " + graph.edges.count) + println(s"GRAPHX: Number of vertices ${graph.vertices.count}") + println(s"GRAPHX: Number of edges ${graph.edges.count}") val pr = (numIterOpt match { case Some(numIter) => PageRank.run(graph, numIter) case None => PageRank.runUntilConvergence(graph, tol) }).vertices.cache() - println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_ + _)) + println(s"GRAPHX: Total rank: ${pr.map(_._2).reduce(_ + _)}") if (!outFname.isEmpty) { - logWarning("Saving pageranks of pages to " + outFname) + logWarning(s"Saving pageranks of pages to $outFname") pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname) } @@ -108,14 +109,14 @@ object Analytics extends Logging { case "cc" => options.foreach { - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } println("======================================") println("| Connected Components |") println("======================================") - val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")")) + val sc = new SparkContext(conf.setAppName(s"ConnectedComponents($fname)")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, @@ -123,19 +124,19 @@ object Analytics extends Logging { val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) val cc = ConnectedComponents.run(graph) - println("Components: " + cc.vertices.map { case (vid, data) 
=> data }.distinct()) + println(s"Components: ${cc.vertices.map { case (vid, data) => data }.distinct()}") sc.stop() case "triangles" => options.foreach { - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } println("======================================") println("| Triangle Count |") println("======================================") - val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")")) + val sc = new SparkContext(conf.setAppName(s"TriangleCount($fname)")) val graph = GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true, numEdgePartitions = numEPart, From b176014857b056e45cbc60ac77f1a95bc31cb0fa Mon Sep 17 00:00:00 2001 From: wuyi Date: Wed, 20 Dec 2017 14:27:56 -0800 Subject: [PATCH 466/712] [SPARK-22847][CORE] Remove redundant code in AppStatusListener while assigning schedulingPool for stage ## What changes were proposed in this pull request? In AppStatusListener's onStageSubmitted(event: SparkListenerStageSubmitted) method, there are duplicate code: ``` // schedulingPool was assigned twice with the same code stage.schedulingPool = Option(event.properties).flatMap { p => Option(p.getProperty("spark.scheduler.pool")) }.getOrElse(SparkUI.DEFAULT_POOL_NAME) ... ... ... stage.schedulingPool = Option(event.properties).flatMap { p => Option(p.getProperty("spark.scheduler.pool")) }.getOrElse(SparkUI.DEFAULT_POOL_NAME) ``` But, it does not make any sense to do this and there are no comment to explain for this. ## How was this patch tested? N/A Author: wuyi Closes #20033 from Ngone51/dev-spark-22847. --- .../scala/org/apache/spark/status/AppStatusListener.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 1fb7b76d43d04..87eb84d94c005 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -329,10 +329,6 @@ private[spark] class AppStatusListener( .toSeq stage.jobIds = stage.jobs.map(_.jobId).toSet - stage.schedulingPool = Option(event.properties).flatMap { p => - Option(p.getProperty("spark.scheduler.pool")) - }.getOrElse(SparkUI.DEFAULT_POOL_NAME) - stage.description = Option(event.properties).flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } From 0114c89d049724b95f7823b957bf33790216316b Mon Sep 17 00:00:00 2001 From: foxish Date: Wed, 20 Dec 2017 16:14:36 -0800 Subject: [PATCH 467/712] [SPARK-22845][SCHEDULER] Modify spark.kubernetes.allocation.batch.delay to take time instead of int ## What changes were proposed in this pull request? Fixing configuration that was taking an int which should take time. Discussion in https://github.com/apache/spark/pull/19946#discussion_r156682354 Made the granularity milliseconds as opposed to seconds since there's a use-case for sub-second reactions to scale-up rapidly especially with dynamic allocation. ## How was this patch tested? TODO: manual run of integration tests against this PR. PTAL cc/ mccheah liyinan926 kimoonkim vanzin mridulm jiangxb1987 ueshin Author: foxish Closes #20032 from foxish/fix-time-conf. 
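Editor's note (hypothetical usage, values are illustrative): after this change the setting is expressed as a time string parsed to milliseconds, so sub-second delays such as "500ms" become possible, whereas before it was a bare integer number of seconds.

```scala
// Sketch of setting the config after this change; any time string Spark's time parser
// accepts (e.g. "500ms", "1s", "1m") would work here.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.kubernetes.allocation.batch.delay", "500ms")
```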
--- .../main/scala/org/apache/spark/deploy/k8s/Config.scala | 8 ++++---- .../cluster/k8s/KubernetesClusterSchedulerBackend.scala | 2 +- .../k8s/KubernetesClusterSchedulerBackendSuite.scala | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 04aadb4b06af4..45f527959cbe1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -102,10 +102,10 @@ private[spark] object Config extends Logging { val KUBERNETES_ALLOCATION_BATCH_DELAY = ConfigBuilder("spark.kubernetes.allocation.batch.delay") - .doc("Number of seconds to wait between each round of executor allocation.") - .longConf - .checkValue(value => value > 0, "Allocation batch delay should be a positive integer") - .createWithDefault(1) + .doc("Time to wait between each round of executor allocation.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(value => value > 0, "Allocation batch delay must be a positive time value.") + .createWithDefaultString("1s") val KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS = ConfigBuilder("spark.kubernetes.executor.lostCheck.maxAttempts") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index e79c987852db2..9de4b16c30d3c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -217,7 +217,7 @@ private[spark] class KubernetesClusterSchedulerBackend( .watch(new ExecutorPodsWatcher())) allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS) + allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) if (!Utils.isDynamicAllocationEnabled(conf)) { doRequestTotalExecutors(initialExecutors) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 13c09033a50ee..b2f26f205a329 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -46,7 +46,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val NAMESPACE = "test-namespace" private val SPARK_DRIVER_HOST = "localhost" private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = 60L + private val POD_ALLOCATION_INTERVAL = "1m" private val DRIVER_URL = RpcEndpointAddress( SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val FIRST_EXECUTOR_POD = new PodBuilder() @@ -144,7 +144,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with 
BeforeAn .set(KUBERNETES_NAMESPACE, NAMESPACE) .set("spark.driver.host", SPARK_DRIVER_HOST) .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL) + .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) @@ -162,8 +162,8 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn when(allocatorExecutor.scheduleWithFixedDelay( allocatorRunnable.capture(), mockitoEq(0L), - mockitoEq(POD_ALLOCATION_INTERVAL), - mockitoEq(TimeUnit.SECONDS))).thenReturn(null) + mockitoEq(TimeUnit.MINUTES.toMillis(1)), + mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) // Creating Futures in Scala backed by a Java executor service resolves to running // ExecutorService#execute (as opposed to submit) doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) From fb0562f34605cd27fd39d09e6664a46e55eac327 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 20 Dec 2017 17:51:42 -0800 Subject: [PATCH 468/712] [SPARK-22810][ML][PYSPARK] Expose Python API for LinearRegression with huber loss. ## What changes were proposed in this pull request? Expose Python API for _LinearRegression_ with _huber_ loss. ## How was this patch tested? Unit test. Author: Yanbo Liang Closes #19994 from yanboliang/spark-22810. --- .../ml/param/_shared_params_code_gen.py | 3 +- python/pyspark/ml/param/shared.py | 23 +++++++ python/pyspark/ml/regression.py | 64 +++++++++++++++---- python/pyspark/ml/tests.py | 21 ++++++ 4 files changed, 96 insertions(+), 15 deletions(-) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 130d1a0bae7f0..d55d209d09398 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -154,7 +154,8 @@ def get$Name(self): ("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2", "TypeConverters.toInt"), ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", - "1", "TypeConverters.toInt")] + "1", "TypeConverters.toInt"), + ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4041d9c43b236..e5c5ddfba6c1f 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -632,6 +632,29 @@ def getParallelism(self): return self.getOrDefault(self.parallelism) +class HasLoss(Params): + """ + Mixin for param loss: the loss function to be optimized. + """ + + loss = Param(Params._dummy(), "loss", "the loss function to be optimized.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasLoss, self).__init__() + + def setLoss(self, value): + """ + Sets the value of :py:attr:`loss`. + """ + return self._set(loss=value) + + def getLoss(self): + """ + Gets the value of loss or its default value. + """ + return self.getOrDefault(self.loss) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. 
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9d5b768091cf4..f0812bd1d4a39 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -39,23 +39,26 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth, + HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth, HasLoss, JavaMLWritable, JavaMLReadable): """ Linear regression. - The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^ + The learning objective is to minimize the specified loss function, with regularization. + This supports two kinds of loss: - This supports multiple types of regularization: - - * none (a.k.a. ordinary least squares) + * squaredError (a.k.a squared loss) + * huber (a hybrid of squared error for relatively small errors and absolute error for \ + relatively large ones, and we estimate the scale parameter from training data) - * L2 (ridge regression) + This supports multiple types of regularization: - * L1 (Lasso) + * none (a.k.a. ordinary least squares) + * L2 (ridge regression) + * L1 (Lasso) + * L2 + L1 (elastic net) - * L2 + L1 (elastic net) + Note: Fitting with huber loss only supports none and L2 regularization. >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([ @@ -98,19 +101,28 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString) + loss = Param(Params._dummy(), "loss", "The loss function to be optimized. Supported " + + "options: squaredError, huber.", typeConverter=TypeConverters.toString) + + epsilon = Param(Params._dummy(), "epsilon", "The shape parameter to control the amount of " + + "robustness. Must be > 1.0. 
Only valid when loss is huber", + typeConverter=TypeConverters.toFloat) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto", weightCol=None, aggregationDepth=2): + standardization=True, solver="auto", weightCol=None, aggregationDepth=2, + loss="squaredError", epsilon=1.35): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto", weightCol=None, aggregationDepth=2) + standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \ + loss="squaredError", epsilon=1.35) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) - self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) + self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -118,11 +130,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto", weightCol=None, aggregationDepth=2): + standardization=True, solver="auto", weightCol=None, aggregationDepth=2, + loss="squaredError", epsilon=1.35): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto", weightCol=None, aggregationDepth=2) + standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \ + loss="squaredError", epsilon=1.35) Sets params for linear regression. """ kwargs = self._input_kwargs @@ -131,6 +145,20 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) + @since("2.3.0") + def setEpsilon(self, value): + """ + Sets the value of :py:attr:`epsilon`. + """ + return self._set(epsilon=value) + + @since("2.3.0") + def getEpsilon(self): + """ + Gets the value of epsilon or its default value. + """ + return self.getOrDefault(self.epsilon) + class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ @@ -155,6 +183,14 @@ def intercept(self): """ return self._call_java("intercept") + @property + @since("2.3.0") + def scale(self): + """ + The value by which \|y - X'w\| is scaled down when loss is "huber", otherwise 1.0. 
+ """ + return self._call_java("scale") + @property @since("2.0.0") def summary(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index be1521154f042..afcb0881c4dcb 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1726,6 +1726,27 @@ def test_offset(self): self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) +class LinearRegressionTest(SparkSessionTestCase): + + def test_linear_regression_with_huber_loss(self): + + data_path = "data/mllib/sample_linear_regression_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + lir = LinearRegression(loss="huber", epsilon=2.0) + model = lir.fit(df) + + expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, + 1.2612, -0.333, -0.5694, -0.6311, 0.6053] + expectedIntercept = 0.1607 + expectedScale = 9.758 + + self.assertTrue( + np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) + self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) + self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) + + class LogisticRegressionTest(SparkSessionTestCase): def test_binomial_logistic_regression_with_bound(self): From 9c289a5cb46e00cd60db4794357f070dfdf80691 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 21 Dec 2017 10:02:30 +0800 Subject: [PATCH 469/712] [SPARK-22387][SQL] Propagate session configs to data source read/write options ## What changes were proposed in this pull request? Introduce a new interface `SessionConfigSupport` for `DataSourceV2`, it can help to propagate session configs with the specified key-prefix to all data source operations in this session. ## How was this patch tested? Add new test suite `DataSourceV2UtilsSuite`. Author: Xingbo Jiang Closes #19861 from jiangxb1987/datasource-configs. --- .../sql/sources/v2/SessionConfigSupport.java | 39 +++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 11 +++- .../apache/spark/sql/DataFrameWriter.scala | 15 +++-- .../datasources/v2/DataSourceV2Utils.scala | 58 +++++++++++++++++++ .../sources/v2/DataSourceV2UtilsSuite.scala | 49 ++++++++++++++++ 5 files changed, 164 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java new file mode 100644 index 0000000000000..0b5b6ac675f2c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; + +import java.util.List; +import java.util.Map; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * propagate session configs with the specified key-prefix to all data source operations in this + * session. + */ +@InterfaceStability.Evolving +public interface SessionConfigSupport { + + /** + * Key prefix of the session configs to propagate. Spark will extract all session configs that + * starts with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` + * into `xxx -> yyy`, and propagate them to all data source operations in this session. + */ + String keyPrefix(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 39fec8f983b65..c43ee91294a27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils +import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -184,9 +185,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val options = new DataSourceV2Options(extraOptions.asJava) + val ds = cls.newInstance() + val options = new DataSourceV2Options((extraOptions ++ + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = sparkSession.sessionState.conf)).asJava) - val reader = (cls.newInstance(), userSpecifiedSchema) match { + val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 59a01e61124f7..7ccda0ad36d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -30,9 +30,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport} +import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType /** @@ -236,14 
+237,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - cls.newInstance() match { - case ds: WriteSupport => - val options = new DataSourceV2Options(extraOptions.asJava) + val ds = cls.newInstance() + ds match { + case ws: WriteSupport => + val options = new DataSourceV2Options((extraOptions ++ + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = df.sparkSession.sessionState.conf)).asJava) // Using a timestamp and a random UUID to distinguish different writing jobs. This is good // enough as there won't be tons of writing jobs created at the same second. val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) .format(new Date()) + "-" + UUID.randomUUID() - val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options) + val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) if (writer.isPresent) { runCommand(df.sparkSession, "save") { WriteToDataSourceV2(writer.get(), df.logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala new file mode 100644 index 0000000000000..5267f5f1580c3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.regex.Pattern + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} + +private[sql] object DataSourceV2Utils extends Logging { + + /** + * Helper method that extracts and transforms session configs into k/v pairs, the k/v pairs will + * be used to create data source options. + * Only extract when `ds` implements [[SessionConfigSupport]], in this case we may fetch the + * specified key-prefix from `ds`, and extract session configs with config keys that start with + * `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy` will + * be transformed into `xxx -> yyy`. + * + * @param ds a [[DataSourceV2]] object + * @param conf the session conf + * @return an immutable map that contains all the extracted and transformed k/v pairs. 
+ */ + def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match { + case cs: SessionConfigSupport => + val keyPrefix = cs.keyPrefix() + require(keyPrefix != null, "The data source config key prefix can't be null.") + + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") + + conf.getAllConfs.flatMap { case (key, value) => + val m = pattern.matcher(key) + if (m.matches() && m.groupCount() > 0) { + Seq((m.group(1), value)) + } else { + Seq.empty + } + } + + case _ => Map.empty + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala new file mode 100644 index 0000000000000..4911e3225552d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils +import org.apache.spark.sql.internal.SQLConf + +class DataSourceV2UtilsSuite extends SparkFunSuite { + + private val keyPrefix = new DataSourceV2WithSessionConfig().keyPrefix + + test("method withSessionConfig() should propagate session configs correctly") { + // Only match configs with keys start with "spark.datasource.${keyPrefix}". + val conf = new SQLConf + conf.setConfString(s"spark.datasource.$keyPrefix.foo.bar", "false") + conf.setConfString(s"spark.datasource.$keyPrefix.whateverConfigName", "123") + conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") + conf.setConfString("spark.datasource.another.config.name", "123") + conf.setConfString(s"spark.datasource.$keyPrefix.", "123") + val cs = classOf[DataSourceV2WithSessionConfig].newInstance() + val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) + assert(confs.size == 2) + assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) + assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) + assert(confs.keySet.contains("foo.bar")) + assert(confs.keySet.contains("whateverConfigName")) + } +} + +class DataSourceV2WithSessionConfig extends SimpleDataSourceV2 with SessionConfigSupport { + + override def keyPrefix: String = "userDefinedDataSource" +} From d3ae3e1e894f88a8500752d9633fe9ad00da5f20 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 20 Dec 2017 19:53:35 -0800 Subject: [PATCH 470/712] [SPARK-19634][SQL][ML][FOLLOW-UP] Improve interface of dataframe vectorized summarizer ## What changes were proposed in this pull request? Make several improvements in dataframe vectorized summarizer. 1. 
Make the summarizer return `Vector` type for all metrics (except "count"). It will return "WrappedArray" type before which won't be very convenient. 2. Make `MetricsAggregate` inherit `ImplicitCastInputTypes` trait. So it can check and implicitly cast input values. 3. Add "weight" parameter for all single metric method. 4. Update doc and improve the example code in doc. 5. Simplified test cases. ## How was this patch tested? Test added and simplified. Author: WeichenXu Closes #19156 from WeichenXu123/improve_vec_summarizer. --- .../org/apache/spark/ml/stat/Summarizer.scala | 128 ++++--- .../spark/ml/stat/JavaSummarizerSuite.java | 64 ++++ .../spark/ml/stat/SummarizerSuite.scala | 362 ++++++++++-------- 3 files changed, 341 insertions(+), 213 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index cae41edb7aca8..9bed74a9f2c05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ @@ -41,7 +41,7 @@ sealed abstract class SummaryBuilder { /** * Returns an aggregate object that contains the summary of the column with the requested metrics. * @param featuresCol a column that contains features Vector object. - * @param weightCol a column that contains weight value. + * @param weightCol a column that contains weight value. Default weight is 1.0. * @return an aggregate column that contains the statistics. The exact content of this * structure is determined during the creation of the builder. */ @@ -50,6 +50,7 @@ sealed abstract class SummaryBuilder { @Since("2.3.0") def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0)) + } /** @@ -60,15 +61,18 @@ sealed abstract class SummaryBuilder { * This class lets users pick the statistics they would like to extract for a given column. Here is * an example in Scala: * {{{ - * val dataframe = ... // Some dataframe containing a feature column - * val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features")) - * val Row(Row(min_, max_)) = allStats.first() + * import org.apache.spark.ml.linalg._ + * import org.apache.spark.sql.Row + * val dataframe = ... 
// Some dataframe containing a feature column and a weight column + * val multiStatsDF = dataframe.select( + * Summarizer.metrics("min", "max", "count").summary($"features", $"weight") + * val Row(Row(minVec, maxVec, count)) = multiStatsDF.first() * }}} * * If one wants to get a single metric, shortcuts are also available: * {{{ * val meanDF = dataframe.select(Summarizer.mean($"features")) - * val Row(mean_) = meanDF.first() + * val Row(meanVec) = meanDF.first() * }}} * * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD @@ -94,8 +98,7 @@ object Summarizer extends Logging { * - min: the minimum for each coefficient. * - normL2: the Euclidian norm for each coefficient. * - normL1: the L1 norm of each coefficient (sum of the absolute values). - * @param firstMetric the metric being provided - * @param metrics additional metrics that can be provided. + * @param metrics metrics that can be provided. * @return a builder. * @throws IllegalArgumentException if one of the metric names is not understood. * @@ -103,37 +106,79 @@ object Summarizer extends Logging { * interface. */ @Since("2.3.0") - def metrics(firstMetric: String, metrics: String*): SummaryBuilder = { - val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics) + @scala.annotation.varargs + def metrics(metrics: String*): SummaryBuilder = { + require(metrics.size >= 1, "Should include at least one metric") + val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics) new SummaryBuilderImpl(typedMetrics, computeMetrics) } @Since("2.3.0") - def mean(col: Column): Column = getSingleMetric(col, "mean") + def mean(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "mean") + } + + @Since("2.3.0") + def mean(col: Column): Column = mean(col, lit(1.0)) + + @Since("2.3.0") + def variance(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "variance") + } + + @Since("2.3.0") + def variance(col: Column): Column = variance(col, lit(1.0)) + + @Since("2.3.0") + def count(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "count") + } + + @Since("2.3.0") + def count(col: Column): Column = count(col, lit(1.0)) @Since("2.3.0") - def variance(col: Column): Column = getSingleMetric(col, "variance") + def numNonZeros(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "numNonZeros") + } + + @Since("2.3.0") + def numNonZeros(col: Column): Column = numNonZeros(col, lit(1.0)) + + @Since("2.3.0") + def max(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "max") + } + + @Since("2.3.0") + def max(col: Column): Column = max(col, lit(1.0)) @Since("2.3.0") - def count(col: Column): Column = getSingleMetric(col, "count") + def min(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "min") + } @Since("2.3.0") - def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros") + def min(col: Column): Column = min(col, lit(1.0)) @Since("2.3.0") - def max(col: Column): Column = getSingleMetric(col, "max") + def normL1(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "normL1") + } @Since("2.3.0") - def min(col: Column): Column = getSingleMetric(col, "min") + def normL1(col: Column): Column = normL1(col, lit(1.0)) @Since("2.3.0") - def normL1(col: Column): Column = getSingleMetric(col, "normL1") + def normL2(col: Column, weightCol: Column): Column = { + getSingleMetric(col, weightCol, "normL2") + } 
@Since("2.3.0") - def normL2(col: Column): Column = getSingleMetric(col, "normL2") + def normL2(col: Column): Column = normL2(col, lit(1.0)) - private def getSingleMetric(col: Column, metric: String): Column = { - val c1 = metrics(metric).summary(col) + private def getSingleMetric(col: Column, weightCol: Column, metric: String): Column = { + val c1 = metrics(metric).summary(col, weightCol) c1.getField(metric).as(s"$metric($col)") } } @@ -187,8 +232,7 @@ private[ml] object SummaryBuilderImpl extends Logging { StructType(fields) } - private val arrayDType = ArrayType(DoubleType, containsNull = false) - private val arrayLType = ArrayType(LongType, containsNull = false) + private val vectorUDT = new VectorUDT /** * All the metrics that can be currently computed by Spark for vectors. @@ -197,14 +241,14 @@ private[ml] object SummaryBuilderImpl extends Logging { * metrics that need to de computed internally to get the final result. */ private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq( - ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)), - ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)), + ("mean", Mean, vectorUDT, Seq(ComputeMean, ComputeWeightSum)), + ("variance", Variance, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)), ("count", Count, LongType, Seq()), - ("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)), - ("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)), - ("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)), - ("normL2", NormL2, arrayDType, Seq(ComputeM2)), - ("normL1", NormL1, arrayDType, Seq(ComputeL1)) + ("numNonZeros", NumNonZeros, vectorUDT, Seq(ComputeNNZ)), + ("max", Max, vectorUDT, Seq(ComputeMax, ComputeNNZ)), + ("min", Min, vectorUDT, Seq(ComputeMin, ComputeNNZ)), + ("normL2", NormL2, vectorUDT, Seq(ComputeM2)), + ("normL1", NormL1, vectorUDT, Seq(ComputeL1)) ) /** @@ -527,27 +571,28 @@ private[ml] object SummaryBuilderImpl extends Logging { weightExpr: Expression, mutableAggBufferOffset: Int, inputAggBufferOffset: Int) - extends TypedImperativeAggregate[SummarizerBuffer] { + extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes { - override def eval(state: SummarizerBuffer): InternalRow = { + override def eval(state: SummarizerBuffer): Any = { val metrics = requestedMetrics.map { - case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray) - case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray) + case Mean => vectorUDT.serialize(state.mean) + case Variance => vectorUDT.serialize(state.variance) case Count => state.count - case NumNonZeros => UnsafeArrayData.fromPrimitiveArray( - state.numNonzeros.toArray.map(_.toLong)) - case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray) - case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray) - case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray) - case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray) + case NumNonZeros => vectorUDT.serialize(state.numNonzeros) + case Max => vectorUDT.serialize(state.max) + case Min => vectorUDT.serialize(state.min) + case NormL2 => vectorUDT.serialize(state.normL2) + case NormL1 => vectorUDT.serialize(state.normL1) } InternalRow.apply(metrics: _*) } + override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil + override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = { - val 
features = udt.deserialize(featuresExpr.eval(row)) + val features = vectorUDT.deserialize(featuresExpr.eval(row)) val weight = weightExpr.eval(row).asInstanceOf[Double] state.add(features, weight) state @@ -591,7 +636,4 @@ private[ml] object SummaryBuilderImpl extends Logging { override def prettyName: String = "aggregate_metrics" } - - private[this] val udt = new VectorUDT - } diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java new file mode 100644 index 0000000000000..38ab39aa0f492 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.Dataset; +import static org.apache.spark.sql.functions.col; +import org.apache.spark.ml.feature.LabeledPoint; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; + +public class JavaSummarizerSuite extends SharedSparkSession { + + private transient Dataset dataset; + + @Override + public void setUp() throws IOException { + super.setUp(); + List points = new ArrayList(); + points.add(new LabeledPoint(0.0, Vectors.dense(1.0, 2.0))); + points.add(new LabeledPoint(0.0, Vectors.dense(3.0, 4.0))); + + dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @Test + public void testSummarizer() { + dataset.select(col("features")); + Row result = dataset + .select(Summarizer.metrics("mean", "max", "count").summary(col("features"))) + .first().getStruct(0); + Vector meanVec = result.getAs("mean"); + Vector maxVec = result.getAs("max"); + long count = result.getAs("count"); + + assertEquals(2L, count); + assertArrayEquals(new double[]{2.0, 3.0}, meanVec.toArray(), 0.0); + assertArrayEquals(new double[]{3.0, 4.0}, maxVec.toArray(), 0.0); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala index 1ea851ef2d676..5e4f402989697 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.stat -import org.scalatest.exceptions.TestFailedException - import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import 
org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.Row class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -35,237 +32,262 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { import SummaryBuilderImpl._ private case class ExpectedMetrics( - mean: Seq[Double], - variance: Seq[Double], + mean: Vector, + variance: Vector, count: Long, - numNonZeros: Seq[Long], - max: Seq[Double], - min: Seq[Double], - normL2: Seq[Double], - normL1: Seq[Double]) + numNonZeros: Vector, + max: Vector, + min: Vector, + normL2: Vector, + normL1: Vector) /** - * The input is expected to be either a sparse vector, a dense vector or an array of doubles - * (which will be converted to a dense vector) - * The expected is the list of all the known metrics. + * The input is expected to be either a sparse vector, a dense vector. * - * The tests take an list of input vectors and a list of all the summary values that - * are expected for this input. They currently test against some fixed subset of the - * metrics, but should be made fuzzy in the future. + * The tests take an list of input vectors, and compare results with + * `mllib.stat.MultivariateOnlineSummarizer`. They currently test against some fixed subset + * of the metrics, but should be made fuzzy in the future. */ - private def testExample(name: String, input: Seq[Any], exp: ExpectedMetrics): Unit = { + private def testExample(name: String, inputVec: Seq[(Vector, Double)], + exp: ExpectedMetrics, expWithoutWeight: ExpectedMetrics): Unit = { - def inputVec: Seq[Vector] = input.map { - case x: Array[Double @unchecked] => Vectors.dense(x) - case x: Seq[Double @unchecked] => Vectors.dense(x.toArray) - case x: Vector => x - case x => throw new Exception(x.toString) + val summarizer = { + val _summarizer = new MultivariateOnlineSummarizer + inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1), v._2)) + _summarizer } - val summarizer = { + val summarizerWithoutWeight = { val _summarizer = new MultivariateOnlineSummarizer - inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v))) + inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1))) _summarizer } // Because the Spark context is reset between tests, we cannot hold a reference onto it. 
def wrappedInit() = { - val df = inputVec.map(Tuple1.apply).toDF("features") - val col = df.col("features") - (df, col) + val df = inputVec.toDF("features", "weight") + val featuresCol = df.col("features") + val weightCol = df.col("weight") + (df, featuresCol, weightCol) } registerTest(s"$name - mean only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), summarizer.mean)) + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("mean").summary(c, w), mean(c, w)).first(), + Row(Row(summarizer.mean), exp.mean)) } - registerTest(s"$name - mean only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(mean(c)), Seq(exp.mean)) + registerTest(s"$name - mean only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("mean").summary(c), mean(c)).first(), + Row(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean)) } registerTest(s"$name - variance only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("variance").summary(c), variance(c)), - Seq(Row(exp.variance), summarizer.variance)) + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("variance").summary(c, w), variance(c, w)).first(), + Row(Row(summarizer.variance), exp.variance)) } - registerTest(s"$name - variance only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(variance(c)), Seq(summarizer.variance)) + registerTest(s"$name - variance only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("variance").summary(c), variance(c)).first(), + Row(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance)) } registerTest(s"$name - count only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("count").summary(c), count(c)), - Seq(Row(exp.count), exp.count)) + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("count").summary(c, w), count(c, w)).first(), + Row(Row(summarizer.count), exp.count)) } - registerTest(s"$name - count only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(count(c)), - Seq(exp.count)) + registerTest(s"$name - count only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("count").summary(c), count(c)).first(), + Row(Row(summarizerWithoutWeight.count), expWithoutWeight.count)) } registerTest(s"$name - numNonZeros only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), - Seq(Row(exp.numNonZeros), exp.numNonZeros)) + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("numNonZeros").summary(c, w), numNonZeros(c, w)).first(), + Row(Row(summarizer.numNonzeros), exp.numNonZeros)) } - registerTest(s"$name - numNonZeros only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(numNonZeros(c)), - Seq(exp.numNonZeros)) + registerTest(s"$name - numNonZeros only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)).first(), + Row(Row(summarizerWithoutWeight.numNonzeros), expWithoutWeight.numNonZeros)) } registerTest(s"$name - min only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("min").summary(c), min(c)), - Seq(Row(exp.min), exp.min)) + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("min").summary(c, w), min(c, w)).first(), + Row(Row(summarizer.min), exp.min)) + } + + registerTest(s"$name - min only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("min").summary(c), min(c)).first(), + 
Row(Row(summarizerWithoutWeight.min), expWithoutWeight.min)) } registerTest(s"$name - max only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("max").summary(c), max(c)), - Seq(Row(exp.max), exp.max)) + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("max").summary(c, w), max(c, w)).first(), + Row(Row(summarizer.max), exp.max)) } - registerTest(s"$name - normL1 only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("normL1").summary(c), normL1(c)), - Seq(Row(exp.normL1), exp.normL1)) + registerTest(s"$name - max only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("max").summary(c), max(c)).first(), + Row(Row(summarizerWithoutWeight.max), expWithoutWeight.max)) } - registerTest(s"$name - normL2 only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("normL2").summary(c), normL2(c)), - Seq(Row(exp.normL2), exp.normL2)) + registerTest(s"$name - normL1 only") { + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("normL1").summary(c, w), normL1(c, w)).first(), + Row(Row(summarizer.normL1), exp.normL1)) } - registerTest(s"$name - all metrics at once") { - val (df, c) = wrappedInit() - compare(df.select( - metrics("mean", "variance", "count", "numNonZeros").summary(c), - mean(c), variance(c), count(c), numNonZeros(c)), - Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros), - exp.mean, exp.variance, exp.count, exp.numNonZeros)) + registerTest(s"$name - normL1 only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("normL1").summary(c), normL1(c)).first(), + Row(Row(summarizerWithoutWeight.normL1), expWithoutWeight.normL1)) } - } - private def denseData(input: Seq[Seq[Double]]): DataFrame = { - input.map(_.toArray).map(Vectors.dense).map(Tuple1.apply).toDF("features") - } + registerTest(s"$name - normL2 only") { + val (df, c, w) = wrappedInit() + compareRow(df.select(metrics("normL2").summary(c, w), normL2(c, w)).first(), + Row(Row(summarizer.normL2), exp.normL2)) + } - private def compare(df: DataFrame, exp: Seq[Any]): Unit = { - val coll = df.collect().toSeq - val Seq(row) = coll - val res = row.toSeq - val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" } - assert(res.size === exp.size, (res.size, exp.size)) - for (((x1, x2), name) <- res.zip(exp).zip(names)) { - compareStructures(x1, x2, name) + registerTest(s"$name - normL2 only w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select(metrics("normL2").summary(c), normL2(c)).first(), + Row(Row(summarizerWithoutWeight.normL2), expWithoutWeight.normL2)) } - } - // Compares structured content. 
- private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match { - case (y1: Seq[Double @unchecked], v1: OldVector) => - compareStructures(y1, v1.toArray.toSeq, name) - case (d1: Double, d2: Double) => - assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name) - case (r1: GenericRowWithSchema, r2: Row) => - assert(r1.size === r2.size, (r1, r2)) - for (((fname, x1), x2) <- r1.schema.fieldNames.zip(r1.toSeq).zip(r2.toSeq)) { - compareStructures(x1, x2, s"$name.$fname") - } - case (r1: Row, r2: Row) => - assert(r1.size === r2.size, (r1, r2)) - for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) } - case (v1: Vector, v2: Vector) => - assert2(v1 ~== v2 absTol 1e-4, name) - case (l1: Long, l2: Long) => assert(l1 === l2) - case (s1: Seq[_], s2: Seq[_]) => - assert(s1.size === s2.size, s"$name ${(s1, s2)}") - for (((x1, idx), x2) <- s1.zipWithIndex.zip(s2)) { - compareStructures(x1, x2, s"$name.$idx") - } - case (arr1: Array[_], arr2: Array[_]) => - assert(arr1.toSeq === arr2.toSeq) - case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2") - } + registerTest(s"$name - multiple metrics at once") { + val (df, c, w) = wrappedInit() + compareRow(df.select( + metrics("mean", "variance", "count", "numNonZeros").summary(c, w)).first(), + Row(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros)) + ) + } - private def assert2(x: => Boolean, hint: String): Unit = { - try { - assert(x, hint) - } catch { - case tfe: TestFailedException => - throw new TestFailedException(Some(s"Failure with hint $hint"), Some(tfe), 1) + registerTest(s"$name - multiple metrics at once w/o weight") { + val (df, c, _) = wrappedInit() + compareRow(df.select( + metrics("mean", "variance", "count", "numNonZeros").summary(c)).first(), + Row(Row(expWithoutWeight.mean, expWithoutWeight.variance, + expWithoutWeight.count, expWithoutWeight.numNonZeros)) + ) } } - test("debugging test") { - val df = denseData(Nil) - val c = df.col("features") - val c1 = metrics("mean").summary(c) - val res = df.select(c1) - intercept[SparkException] { - compare(res, Seq.empty) + private def compareRow(r1: Row, r2: Row): Unit = { + assert(r1.size === r2.size, (r1, r2)) + r1.toSeq.zip(r2.toSeq).foreach { + case (v1: Vector, v2: Vector) => + assert(v1 ~== v2 absTol 1e-4) + case (v1: Vector, v2: OldVector) => + assert(v1 ~== v2.asML absTol 1e-4) + case (l1: Long, l2: Long) => + assert(l1 === l2) + case (r1: Row, r2: Row) => + compareRow(r1, r2) + case (x1: Any, x2: Any) => + throw new Exception(s"type mismatch: ${x1.getClass} ${x2.getClass} $x1 $x2") } } - test("basic error handling") { - val df = denseData(Nil) + test("no element") { + val df = Seq[Tuple1[Vector]]().toDF("features") val c = df.col("features") - val res = df.select(metrics("mean").summary(c), mean(c)) intercept[SparkException] { - compare(res, Seq.empty) + df.select(metrics("mean").summary(c), mean(c)).first() } + compareRow(df.select(metrics("count").summary(c), count(c)).first(), + Row(Row(0L), 0L)) } - test("no element, working metrics") { - val df = denseData(Nil) - val c = df.col("features") - val res = df.select(metrics("count").summary(c), count(c)) - compare(res, Seq(Row(0L), 0L)) - } + val singleElem = Vectors.dense(0.0, 1.0, 2.0) + testExample("single element", Seq((singleElem, 2.0)), + ExpectedMetrics( + mean = singleElem, + variance = Vectors.dense(0.0, 0.0, 0.0), + count = 1L, + numNonZeros = Vectors.dense(0.0, 1.0, 1.0), + max = singleElem, + min = singleElem, + normL1 = Vectors.dense(0.0, 2.0, 
4.0), + normL2 = Vectors.dense(0.0, 1.414213, 2.828427) + ), + ExpectedMetrics( + mean = singleElem, + variance = Vectors.dense(0.0, 0.0, 0.0), + count = 1L, + numNonZeros = Vectors.dense(0.0, 1.0, 1.0), + max = singleElem, + min = singleElem, + normL1 = singleElem, + normL2 = singleElem + ) + ) + + testExample("multiple elements (dense)", + Seq( + (Vectors.dense(-1.0, 0.0, 6.0), 0.5), + (Vectors.dense(3.0, -3.0, 0.0), 2.8), + (Vectors.dense(1.0, -3.0, 0.0), 0.0) + ), + ExpectedMetrics( + mean = Vectors.dense(2.393939, -2.545454, 0.909090), + variance = Vectors.dense(8.0, 4.5, 18.0), + count = 2L, + numNonZeros = Vectors.dense(2.0, 1.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(8.9, 8.4, 3.0), + normL2 = Vectors.dense(5.069516, 5.019960, 4.242640) + ), + ExpectedMetrics( + mean = Vectors.dense(1.0, -2.0, 2.0), + variance = Vectors.dense(4.0, 3.0, 12.0), + count = 3L, + numNonZeros = Vectors.dense(3.0, 2.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(5.0, 6.0, 6.0), + normL2 = Vectors.dense(3.316624, 4.242640, 6.0) + ) + ) - val singleElem = Seq(0.0, 1.0, 2.0) - testExample("single element", Seq(singleElem), ExpectedMetrics( - mean = singleElem, - variance = Seq(0.0, 0.0, 0.0), - count = 1, - numNonZeros = Seq(0, 1, 1), - max = singleElem, - min = singleElem, - normL1 = singleElem, - normL2 = singleElem - )) - - testExample("two elements", Seq(Seq(0.0, 1.0, 2.0), Seq(0.0, -1.0, -2.0)), ExpectedMetrics( - mean = Seq(0.0, 0.0, 0.0), - // TODO: I have a doubt about these values, they are not normalized. - variance = Seq(0.0, 2.0, 8.0), - count = 2, - numNonZeros = Seq(0, 2, 2), - max = Seq(0.0, 1.0, 2.0), - min = Seq(0.0, -1.0, -2.0), - normL1 = Seq(0.0, 2.0, 4.0), - normL2 = Seq(0.0, math.sqrt(2.0), math.sqrt(2.0) * 2.0) - )) - - testExample("dense vector input", - Seq(Seq(-1.0, 0.0, 6.0), Seq(3.0, -3.0, 0.0)), + testExample("multiple elements (sparse)", + Seq( + (Vectors.dense(-1.0, 0.0, 6.0).toSparse, 0.5), + (Vectors.dense(3.0, -3.0, 0.0).toSparse, 2.8), + (Vectors.dense(1.0, -3.0, 0.0).toSparse, 0.0) + ), + ExpectedMetrics( + mean = Vectors.dense(2.393939, -2.545454, 0.909090), + variance = Vectors.dense(8.0, 4.5, 18.0), + count = 2L, + numNonZeros = Vectors.dense(2.0, 1.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(8.9, 8.4, 3.0), + normL2 = Vectors.dense(5.069516, 5.019960, 4.242640) + ), ExpectedMetrics( - mean = Seq(1.0, -1.5, 3.0), - variance = Seq(8.0, 4.5, 18.0), - count = 2, - numNonZeros = Seq(2, 1, 1), - max = Seq(3.0, 0.0, 6.0), - min = Seq(-1.0, -3, 0.0), - normL1 = Seq(4.0, 3.0, 6.0), - normL2 = Seq(math.sqrt(10), 3, 6.0) + mean = Vectors.dense(1.0, -2.0, 2.0), + variance = Vectors.dense(4.0, 3.0, 12.0), + count = 3L, + numNonZeros = Vectors.dense(3.0, 2.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(5.0, 6.0, 6.0), + normL2 = Vectors.dense(3.316624, 4.242640, 6.0) ) ) From cb9fc8d9b6d385c9d193801edaea2d52e29a90fb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 21 Dec 2017 14:54:38 +0800 Subject: [PATCH 471/712] [SPARK-22848][SQL] Eliminate mutable state from Stack ## What changes were proposed in this pull request? This PR eliminates mutable states from the generated code for `Stack`. ## How was this patch tested? Existing test suites Author: Kazuaki Ishizaki Closes #20035 from kiszk/SPARK-22848. 
--- .../spark/sql/catalyst/expressions/generators.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 1cd73a92a8635..69af7a250a5ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -214,11 +214,11 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName - ctx.addMutableState( - s"$wrapperClass", - ev.value, - v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", useFreshName = false) - ev.copy(code = code, isNull = "false") + ev.copy(code = + s""" + |$code + |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); + """.stripMargin, isNull = "false") } } From 59d52631eb86394f1d981419cb744c20bd4e0b87 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 21 Dec 2017 20:43:56 +0900 Subject: [PATCH 472/712] [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 ## What changes were proposed in this pull request? Upgrade Spark to Arrow 0.8.0 for Java and Python. Also includes an upgrade of Netty to 4.1.17 to resolve dependency requirements. The highlights that pertain to Spark for the update from Arrow version 0.4.1 to 0.8.0 include: * Java refactoring for a simpler API * Java reduced heap usage and streamlined hot code paths * Type support for DecimalType, ArrayType * Improved type casting support in Python * Simplified type checking in Python ## How was this patch tested? Existing tests Author: Bryan Cutler Author: Shixiong Zhu Closes #19884 from BryanCutler/arrow-upgrade-080-SPARK-22324.
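The Python-side highlights above come down to two pyarrow idioms that the updated serializers in this patch rely on: the pyarrow.types predicates for checking Arrow types, and Array.from_pandas(...).cast(..., safe=False) for coercing values during conversion. A minimal standalone sketch of that conversion pattern, assuming only that pandas and pyarrow >= 0.8.0 are installed (the helper name and sample data are illustrative, not part of this patch):

import pandas as pd
import pyarrow as pa

def series_to_arrow(s, t=None):
    # Record nulls explicitly via a mask instead of relying on NaN semantics.
    mask = s.isnull()
    if t is not None and pa.types.is_timestamp(t):
        # Timestamp columns are filled before conversion and then cast without
        # safety checks, mirroring the approach taken in pyspark/serializers.py below.
        return pa.Array.from_pandas(s.fillna(0), mask=mask).cast(t, safe=False)
    return pa.Array.from_pandas(s, mask=mask, type=t)

batch = pa.RecordBatch.from_arrays(
    [series_to_arrow(pd.Series([1.0, 2.0, None]), pa.float64())], ["_0"])
print(batch.num_rows)  # 3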
--- .../spark/network/crypto/TransportCipher.java | 41 +- .../network/protocol/MessageWithHeader.java | 39 +- .../spark/network/sasl/SaslEncryption.java | 41 +- .../network/util/AbstractFileRegion.java | 53 ++ .../apache/spark/network/ProtocolSuite.java | 4 +- .../protocol/MessageWithHeaderSuite.java | 7 +- .../org/apache/spark/storage/DiskStore.scala | 9 +- dev/deps/spark-deps-hadoop-2.6 | 10 +- dev/deps/spark-deps-hadoop-2.7 | 10 +- pom.xml | 12 +- python/pyspark/serializers.py | 27 +- python/pyspark/sql/dataframe.py | 2 + python/pyspark/sql/functions.py | 13 +- python/pyspark/sql/group.py | 2 +- python/pyspark/sql/session.py | 3 + python/pyspark/sql/tests.py | 12 +- python/pyspark/sql/types.py | 25 +- python/pyspark/sql/udf.py | 16 +- python/pyspark/sql/utils.py | 9 + .../vectorized/ArrowColumnVector.java | 136 ++--- .../sql/execution/arrow/ArrowConverters.scala | 13 +- .../sql/execution/arrow/ArrowWriter.scala | 132 ++-- .../execution/python/ArrowPythonRunner.scala | 27 +- .../arrow/ArrowConvertersSuite.scala | 571 ++---------------- .../vectorized/ArrowColumnVectorSuite.scala | 150 +++-- .../vectorized/ColumnarBatchSuite.scala | 20 +- 26 files changed, 515 insertions(+), 869 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 7376d1ddc4818..e04524dde0a75 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -30,10 +30,10 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; -import io.netty.util.AbstractReferenceCounted; import org.apache.commons.crypto.stream.CryptoInputStream; import org.apache.commons.crypto.stream.CryptoOutputStream; +import org.apache.spark.network.util.AbstractFileRegion; import org.apache.spark.network.util.ByteArrayReadableChannel; import org.apache.spark.network.util.ByteArrayWritableChannel; @@ -161,7 +161,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { } } - private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + private static class EncryptedMessage extends AbstractFileRegion { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; @@ -199,10 +199,45 @@ public long position() { } @Override - public long transfered() { + public long transferred() { return transferred; } + @Override + public EncryptedMessage touch(Object o) { + super.touch(o); + if (region != null) { + region.touch(o); + } + if (buf != null) { + buf.touch(o); + } + return this; + } + + @Override + public EncryptedMessage retain(int increment) { + super.retain(increment); + if (region != null) { + region.retain(increment); + } + if (buf != null) { + buf.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (region != null) { + region.release(decrement); + } + if (buf != null) { + buf.release(decrement); + } + return super.release(decrement); + } + @Override public long transferTo(WritableByteChannel target, long position) throws IOException { Preconditions.checkArgument(position == transfered(), "Invalid position."); diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index 4f8781b42a0e4..897d0f9e4fb89 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -25,17 +25,17 @@ import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.AbstractFileRegion; /** * A wrapper message that holds two separate pieces (a header and a body). * * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. */ -class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { +class MessageWithHeader extends AbstractFileRegion { @Nullable private final ManagedBuffer managedBuffer; private final ByteBuf header; @@ -91,7 +91,7 @@ public long position() { } @Override - public long transfered() { + public long transferred() { return totalBytesTransferred; } @@ -160,4 +160,37 @@ private int writeNioBuffer( return ret; } + + @Override + public MessageWithHeader touch(Object o) { + super.touch(o); + header.touch(o); + ReferenceCountUtil.touch(body, o); + return this; + } + + @Override + public MessageWithHeader retain(int increment) { + super.retain(increment); + header.retain(increment); + ReferenceCountUtil.retain(body, increment); + if (managedBuffer != null) { + for (int i = 0; i < increment; i++) { + managedBuffer.retain(); + } + } + return this; + } + + @Override + public boolean release(int decrement) { + header.release(decrement); + ReferenceCountUtil.release(body, decrement); + if (managedBuffer != null) { + for (int i = 0; i < decrement; i++) { + managedBuffer.release(); + } + } + return super.release(decrement); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 3d71ebaa7ea0c..16ab4efcd4f5f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -32,8 +32,8 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.FileRegion; import io.netty.handler.codec.MessageToMessageDecoder; -import io.netty.util.AbstractReferenceCounted; +import org.apache.spark.network.util.AbstractFileRegion; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; @@ -129,7 +129,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) } @VisibleForTesting - static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + static class EncryptedMessage extends AbstractFileRegion { private final SaslEncryptionBackend backend; private final boolean isByteBuf; @@ -183,10 +183,45 @@ public long position() { * Returns an approximation of the amount of data transferred. See {@link #count()}. 
*/ @Override - public long transfered() { + public long transferred() { return transferred; } + @Override + public EncryptedMessage touch(Object o) { + super.touch(o); + if (buf != null) { + buf.touch(o); + } + if (region != null) { + region.touch(o); + } + return this; + } + + @Override + public EncryptedMessage retain(int increment) { + super.retain(increment); + if (buf != null) { + buf.retain(increment); + } + if (region != null) { + region.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (region != null) { + region.release(decrement); + } + if (buf != null) { + buf.release(decrement); + } + return super.release(decrement); + } + /** * Transfers data from the original message to the channel, encrypting it in the process. * diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java new file mode 100644 index 0000000000000..8651297d97ec2 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; + +public abstract class AbstractFileRegion extends AbstractReferenceCounted implements FileRegion { + + @Override + @SuppressWarnings("deprecation") + public final long transfered() { + return transferred(); + } + + @Override + public AbstractFileRegion retain() { + super.retain(); + return this; + } + + @Override + public AbstractFileRegion retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public AbstractFileRegion touch() { + super.touch(); + return this; + } + + @Override + public AbstractFileRegion touch(Object o) { + return this; + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index bb1c40c4b0e06..bc94f7ca63a96 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -56,7 +56,7 @@ private void testServerToClient(Message msg) { NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { - clientChannel.writeInbound(serverChannel.readOutbound()); + clientChannel.writeOneInbound(serverChannel.readOutbound()); } assertEquals(1, clientChannel.inboundMessages().size()); @@ -72,7 +72,7 @@ private void testClientToServer(Message msg) { NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { - serverChannel.writeInbound(clientChannel.readOutbound()); + serverChannel.writeOneInbound(clientChannel.readOutbound()); } assertEquals(1, serverChannel.inboundMessages().size()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index b341c5681e00c..ecb66fcf2ff76 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -23,8 +23,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; +import org.apache.spark.network.util.AbstractFileRegion; import org.junit.Test; import org.mockito.Mockito; @@ -108,7 +107,7 @@ private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exc return Unpooled.wrappedBuffer(channel.getData()); } - private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + private static class TestFileRegion extends AbstractFileRegion { private final int writeCount; private final int writesPerCall; @@ -130,7 +129,7 @@ public long position() { } @Override - public long transfered() { + public long transferred() { return 8 * written; } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 97abd92d4b70f..39249d411b582 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -26,12 +26,11 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ListBuffer import com.google.common.io.Closeables -import 
io.netty.channel.{DefaultFileRegion, FileRegion} -import io.netty.util.AbstractReferenceCounted +import io.netty.channel.DefaultFileRegion import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils +import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer @@ -266,7 +265,7 @@ private class EncryptedBlockData( } private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) - extends AbstractReferenceCounted with FileRegion { + extends AbstractFileRegion { private var _transferred = 0L @@ -277,7 +276,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: override def position(): Long = 0 - override def transfered(): Long = _transferred + override def transferred(): Long = _transferred override def transferTo(target: WritableByteChannel, pos: Long): Long = { assert(pos == transfered(), "Invalid position.") diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 1831f3378e852..fea642737cc0c 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -14,9 +14,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -82,7 +82,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar +hppc-0.7.2.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar @@ -144,7 +144,7 @@ metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar -netty-all-4.0.47.Final.jar +netty-all-4.1.17.Final.jar objenesis-2.1.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index fe14c05987327..6dd44333f21ca 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -14,9 +14,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -82,7 +82,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar +hppc-0.7.2.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar @@ -145,7 +145,7 @@ metrics-json-3.1.5.jar metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar -netty-all-4.0.47.Final.jar +netty-all-4.1.17.Final.jar objenesis-2.1.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar diff --git a/pom.xml b/pom.xml index 52db79eaf036b..92f897095f087 100644 --- a/pom.xml +++ b/pom.xml @@ -185,7 +185,7 @@ 2.8 1.8 1.0.0 - 0.4.0 + 0.8.0 ${java.home} @@ -580,7 +580,7 @@ io.netty netty-all - 4.0.47.Final + 4.1.17.Final io.netty @@ -1972,6 +1972,14 @@ com.fasterxml.jackson.core jackson-databind + + io.netty + netty-buffer + + + io.netty + netty-common + io.netty netty-handler diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 
37e7cf3fa662e..88d6a191babca 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -223,27 +223,14 @@ def _create_batch(series, timezone): series = [series] series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - # If a nullable integer series has been promoted to floating point with NaNs, need to cast - # NOTE: this is not necessary with Arrow >= 0.7 - def cast_series(s, t): - if type(t) == pa.TimestampType: - # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 - return _check_series_convert_timestamps_internal(s.fillna(0), timezone)\ - .values.astype('datetime64[us]', copy=False) - # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1 - elif t is not None and t == pa.date32(): - # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8 - return s.dt.date - elif t is None or s.dtype == t.to_pandas_dtype(): - return s - else: - return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - - # Some object types don't support masks in Arrow, see ARROW-1721 def create_array(s, t): - casted = cast_series(s, t) - mask = None if casted.dtype == 'object' else s.isnull() - return pa.Array.from_pandas(casted, mask=mask, type=t) + mask = s.isnull() + # Ensure timestamp series are in expected form for Spark internal representation + if t is not None and pa.types.is_timestamp(t): + s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) + # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 + return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 75395a754a831..440684d3edfa6 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1906,7 +1906,9 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.utils import _require_minimum_pyarrow_version import pyarrow + _require_minimum_pyarrow_version() tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 54530055dfa85..ddd8df3b15bf6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2159,16 +2159,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(StringType()) + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP + >>> @pandas_udf(StringType()) # doctest: +SKIP ... def to_upper(s): ... return s.str.upper() ... - >>> @pandas_udf("integer", PandasUDFType.SCALAR) + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP ... def add_one(x): ... return x + 1 ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df = spark.createDataFrame([(1, "John Doe", 21)], + ... ("id", "name", "age")) # doctest: +SKIP >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ ... 
.show() # doctest: +SKIP +----------+--------------+------------+ @@ -2189,8 +2190,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 4d47dd6a3e878..09fae46adf014 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -218,7 +218,7 @@ def apply(self, udf): >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e2435e09af23d..86db16eca7889 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -495,11 +495,14 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): from pyspark.serializers import ArrowSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, \ _old_pandas_exception_message, TimestampType + from pyspark.sql.utils import _require_minimum_pyarrow_version try: from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype except ImportError as e: raise ImportError(_old_pandas_exception_message(e)) + _require_minimum_pyarrow_version() + # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): arrow_types = [to_arrow_type(f.dataType) for f in schema.fields] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b4d32d8de8a22..6fdfda1cc831b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3339,10 +3339,11 @@ def test_createDataFrame_with_single_data_type(self): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): + import pandas as pd # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '7_timestamp_t'] = 1 + pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted pdf.ix[1, '2_int_t'] = None pdf_copy = pdf.copy(deep=True) @@ -3356,6 +3357,7 @@ def test_schema_conversion_roundtrip(self): self.assertEquals(self.schema, schema_rt) +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType @@ -3671,9 +3673,9 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, StringType()) + f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType())) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + with self.assertRaisesRegexp(Exception, 
'Unsupported.*type.*conversion'): df.select(f(col('id'))).collect() def test_vectorized_udf_return_scalar(self): @@ -3974,12 +3976,12 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, - 'id long, v string', + 'id long, v array', PandasUDFType.GROUP_MAP ) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): df.groupby('id').apply(foo).sort('id').toPandas() def test_wrong_args(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 78abc32a35a1c..46d9a417414b5 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1642,29 +1642,28 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. """ - # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type - import pyarrow as pa - if at == pa.bool_(): + import pyarrow.types as types + if types.is_boolean(at): spark_type = BooleanType() - elif at == pa.int8(): + elif types.is_int8(at): spark_type = ByteType() - elif at == pa.int16(): + elif types.is_int16(at): spark_type = ShortType() - elif at == pa.int32(): + elif types.is_int32(at): spark_type = IntegerType() - elif at == pa.int64(): + elif types.is_int64(at): spark_type = LongType() - elif at == pa.float32(): + elif types.is_float32(at): spark_type = FloatType() - elif at == pa.float64(): + elif types.is_float64(at): spark_type = DoubleType() - elif type(at) == pa.DecimalType: + elif types.is_decimal(at): spark_type = DecimalType(precision=at.precision, scale=at.scale) - elif at == pa.string(): + elif types.is_string(at): spark_type = StringType() - elif at == pa.date32(): + elif types.is_date32(at): spark_type = DateType() - elif type(at) == pa.TimestampType: + elif types.is_timestamp(at): spark_type = TimestampType() else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c3301a41ccd5a..50c87ba1ac882 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -33,19 +33,23 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ + evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: import inspect + from pyspark.sql.utils import _require_minimum_pyarrow_version + + _require_minimum_pyarrow_version() argspec = inspect.getargspec(f) - if len(argspec.args) == 0 and argspec.varargs is None: + + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ + argspec.varargs is None: raise ValueError( "Invalid function: 0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - elif evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - import inspect - argspec = inspect.getargspec(f) - if len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1: raise ValueError( "Invalid function: pandas_udfs with function type GROUP_MAP " "must take a single arg that is a pandas DataFrame." 
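
For context, a hypothetical usage sketch (not part of the patch) of what the `_create_udf` validation above now rejects at definition time; it assumes pyspark at this revision with pyarrow >= 0.8.0 installed, so the version check passes and the signature checks are reached:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import LongType

# Rejected on the SQL_PANDAS_SCALAR_UDF path: a 0-arg scalar pandas_udf.
try:
    slen = pandas_udf(lambda: 1, LongType())
except ValueError as e:
    print(e)  # Invalid function: 0-arg pandas_udfs are not supported. ...

# Rejected on the SQL_PANDAS_GROUP_MAP_UDF path: the function must take exactly
# one argument (a pandas DataFrame).
try:
    @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
    def bad(pdf, extra_arg):
        return pdf
except ValueError as e:
    print(e)  # Invalid function: pandas_udfs with function type GROUP_MAP must take a single arg ...
```
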
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 7bc6a59ad3b26..cc7dabb64b3ec 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -110,3 +110,12 @@ def toJArray(gateway, jtype, arr): for i in range(0, len(arr)): jarr[i] = arr[i] return jarr + + +def _require_minimum_pyarrow_version(): + """ Raise ImportError if minimum version of pyarrow is not installed + """ + from distutils.version import LooseVersion + import pyarrow + if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): + raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process") diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index d53e1fcab0c5a..528f66f342dc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -275,30 +275,30 @@ public byte[] getBinary(int rowId) { public ArrowColumnVector(ValueVector vector) { super(ArrowUtils.fromArrowField(vector.getField())); - if (vector instanceof NullableBitVector) { - accessor = new BooleanAccessor((NullableBitVector) vector); - } else if (vector instanceof NullableTinyIntVector) { - accessor = new ByteAccessor((NullableTinyIntVector) vector); - } else if (vector instanceof NullableSmallIntVector) { - accessor = new ShortAccessor((NullableSmallIntVector) vector); - } else if (vector instanceof NullableIntVector) { - accessor = new IntAccessor((NullableIntVector) vector); - } else if (vector instanceof NullableBigIntVector) { - accessor = new LongAccessor((NullableBigIntVector) vector); - } else if (vector instanceof NullableFloat4Vector) { - accessor = new FloatAccessor((NullableFloat4Vector) vector); - } else if (vector instanceof NullableFloat8Vector) { - accessor = new DoubleAccessor((NullableFloat8Vector) vector); - } else if (vector instanceof NullableDecimalVector) { - accessor = new DecimalAccessor((NullableDecimalVector) vector); - } else if (vector instanceof NullableVarCharVector) { - accessor = new StringAccessor((NullableVarCharVector) vector); - } else if (vector instanceof NullableVarBinaryVector) { - accessor = new BinaryAccessor((NullableVarBinaryVector) vector); - } else if (vector instanceof NullableDateDayVector) { - accessor = new DateAccessor((NullableDateDayVector) vector); - } else if (vector instanceof NullableTimeStampMicroTZVector) { - accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector); + if (vector instanceof BitVector) { + accessor = new BooleanAccessor((BitVector) vector); + } else if (vector instanceof TinyIntVector) { + accessor = new ByteAccessor((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + accessor = new ShortAccessor((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + accessor = new IntAccessor((IntVector) vector); + } else if (vector instanceof BigIntVector) { + accessor = new LongAccessor((BigIntVector) vector); + } else if (vector instanceof Float4Vector) { + accessor = new FloatAccessor((Float4Vector) vector); + } else if (vector instanceof Float8Vector) { + accessor = new DoubleAccessor((Float8Vector) vector); + } else if (vector instanceof DecimalVector) { + accessor = new DecimalAccessor((DecimalVector) vector); + } else if (vector instanceof VarCharVector) { + accessor = new StringAccessor((VarCharVector) 
vector); + } else if (vector instanceof VarBinaryVector) { + accessor = new BinaryAccessor((VarBinaryVector) vector); + } else if (vector instanceof DateDayVector) { + accessor = new DateAccessor((DateDayVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); @@ -321,23 +321,21 @@ public ArrowColumnVector(ValueVector vector) { private abstract static class ArrowVectorAccessor { private final ValueVector vector; - private final ValueVector.Accessor nulls; ArrowVectorAccessor(ValueVector vector) { this.vector = vector; - this.nulls = vector.getAccessor(); } final boolean isNullAt(int rowId) { - return nulls.isNull(rowId); + return vector.isNull(rowId); } final int getValueCount() { - return nulls.getValueCount(); + return vector.getValueCount(); } final int getNullCount() { - return nulls.getNullCount(); + return vector.getNullCount(); } final void close() { @@ -395,11 +393,11 @@ int getArrayOffset(int rowId) { private static class BooleanAccessor extends ArrowVectorAccessor { - private final NullableBitVector.Accessor accessor; + private final BitVector accessor; - BooleanAccessor(NullableBitVector vector) { + BooleanAccessor(BitVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -410,11 +408,11 @@ final boolean getBoolean(int rowId) { private static class ByteAccessor extends ArrowVectorAccessor { - private final NullableTinyIntVector.Accessor accessor; + private final TinyIntVector accessor; - ByteAccessor(NullableTinyIntVector vector) { + ByteAccessor(TinyIntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -425,11 +423,11 @@ final byte getByte(int rowId) { private static class ShortAccessor extends ArrowVectorAccessor { - private final NullableSmallIntVector.Accessor accessor; + private final SmallIntVector accessor; - ShortAccessor(NullableSmallIntVector vector) { + ShortAccessor(SmallIntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -440,11 +438,11 @@ final short getShort(int rowId) { private static class IntAccessor extends ArrowVectorAccessor { - private final NullableIntVector.Accessor accessor; + private final IntVector accessor; - IntAccessor(NullableIntVector vector) { + IntAccessor(IntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -455,11 +453,11 @@ final int getInt(int rowId) { private static class LongAccessor extends ArrowVectorAccessor { - private final NullableBigIntVector.Accessor accessor; + private final BigIntVector accessor; - LongAccessor(NullableBigIntVector vector) { + LongAccessor(BigIntVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -470,11 +468,11 @@ final long getLong(int rowId) { private static class FloatAccessor extends ArrowVectorAccessor { - private final NullableFloat4Vector.Accessor accessor; + private final Float4Vector accessor; - FloatAccessor(NullableFloat4Vector vector) { + FloatAccessor(Float4Vector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -485,11 +483,11 @@ final float getFloat(int rowId) { private static class DoubleAccessor extends ArrowVectorAccessor { 
- private final NullableFloat8Vector.Accessor accessor; + private final Float8Vector accessor; - DoubleAccessor(NullableFloat8Vector vector) { + DoubleAccessor(Float8Vector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -500,11 +498,11 @@ final double getDouble(int rowId) { private static class DecimalAccessor extends ArrowVectorAccessor { - private final NullableDecimalVector.Accessor accessor; + private final DecimalVector accessor; - DecimalAccessor(NullableDecimalVector vector) { + DecimalAccessor(DecimalVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -516,12 +514,12 @@ final Decimal getDecimal(int rowId, int precision, int scale) { private static class StringAccessor extends ArrowVectorAccessor { - private final NullableVarCharVector.Accessor accessor; + private final VarCharVector accessor; private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); - StringAccessor(NullableVarCharVector vector) { + StringAccessor(VarCharVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -539,11 +537,11 @@ final UTF8String getUTF8String(int rowId) { private static class BinaryAccessor extends ArrowVectorAccessor { - private final NullableVarBinaryVector.Accessor accessor; + private final VarBinaryVector accessor; - BinaryAccessor(NullableVarBinaryVector vector) { + BinaryAccessor(VarBinaryVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -554,11 +552,11 @@ final byte[] getBinary(int rowId) { private static class DateAccessor extends ArrowVectorAccessor { - private final NullableDateDayVector.Accessor accessor; + private final DateDayVector accessor; - DateAccessor(NullableDateDayVector vector) { + DateAccessor(DateDayVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -569,11 +567,11 @@ final int getInt(int rowId) { private static class TimestampAccessor extends ArrowVectorAccessor { - private final NullableTimeStampMicroTZVector.Accessor accessor; + private final TimeStampMicroTZVector accessor; - TimestampAccessor(NullableTimeStampMicroTZVector vector) { + TimestampAccessor(TimeStampMicroTZVector vector) { super(vector); - this.accessor = vector.getAccessor(); + this.accessor = vector; } @Override @@ -584,21 +582,21 @@ final long getLong(int rowId) { private static class ArrayAccessor extends ArrowVectorAccessor { - private final UInt4Vector.Accessor accessor; + private final ListVector accessor; ArrayAccessor(ListVector vector) { super(vector); - this.accessor = vector.getOffsetVector().getAccessor(); + this.accessor = vector; } @Override final int getArrayLength(int rowId) { - return accessor.get(rowId + 1) - accessor.get(rowId); + return accessor.getInnerValueCountAt(rowId); } @Override final int getArrayOffset(int rowId) { - return accessor.get(rowId); + return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 3cafb344ef553..bcfc412430263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -24,8 +24,8 @@ import 
scala.collection.JavaConverters._ import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.ArrowRecordBatch +import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter} +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext @@ -86,13 +86,9 @@ private[sql] object ArrowConverters { val root = VectorSchemaRoot.create(arrowSchema, allocator) val arrowWriter = ArrowWriter.create(root) - var closed = false - context.addTaskCompletionListener { _ => - if (!closed) { - root.close() - allocator.close() - } + root.close() + allocator.close() } new Iterator[ArrowPayload] { @@ -100,7 +96,6 @@ private[sql] object ArrowConverters { override def hasNext: Boolean = rowIter.hasNext || { root.close() allocator.close() - closed = true false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index e4af4f65da127..0258056d9de49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -46,17 +46,17 @@ object ArrowWriter { private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() (ArrowUtils.fromArrowField(field), vector) match { - case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector) - case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector) - case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector) - case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector) - case (LongType, vector: NullableBigIntVector) => new LongWriter(vector) - case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector) - case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) - case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) - case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) - case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) - case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) + case (BooleanType, vector: BitVector) => new BooleanWriter(vector) + case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: IntVector) => new IntegerWriter(vector) + case (LongType, vector: BigIntVector) => new LongWriter(vector) + case (FloatType, vector: Float4Vector) => new FloatWriter(vector) + case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (StringType, vector: VarCharVector) => new StringWriter(vector) + case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) + case (DateType, vector: DateDayVector) => new DateWriter(vector) + case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) @@ -103,7 +103,6 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { private[arrow] abstract class ArrowFieldWriter { def valueVector: ValueVector - def valueMutator: ValueVector.Mutator def 
name: String = valueVector.getField().getName() def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField()) @@ -124,161 +123,144 @@ private[arrow] abstract class ArrowFieldWriter { } def finish(): Unit = { - valueMutator.setValueCount(count) + valueVector.setValueCount(count) } def reset(): Unit = { - valueMutator.reset() + // TODO: reset() should be in a common interface + valueVector match { + case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() + case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case _ => + } count = 0 } } -private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter { - - override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator() +private[arrow] class BooleanWriter(val valueVector: BitVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) + valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) } } -private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator() +private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getByte(ordinal)) + valueVector.setSafe(count, input.getByte(ordinal)) } } -private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator() +private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getShort(ordinal)) + valueVector.setSafe(count, input.getShort(ordinal)) } } -private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator() +private[arrow] class IntegerWriter(val valueVector: IntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getInt(ordinal)) + valueVector.setSafe(count, input.getInt(ordinal)) } } -private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { - - override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator() +private[arrow] class LongWriter(val valueVector: BigIntVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getLong(ordinal)) + valueVector.setSafe(count, input.getLong(ordinal)) } } -private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { - - override def valueMutator: 
NullableFloat4Vector#Mutator = valueVector.getMutator() +private[arrow] class FloatWriter(val valueVector: Float4Vector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getFloat(ordinal)) + valueVector.setSafe(count, input.getFloat(ordinal)) } } -private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { - - override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator() +private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getDouble(ordinal)) + valueVector.setSafe(count, input.getDouble(ordinal)) } } -private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { - - override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() +private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val utf8 = input.getUTF8String(ordinal) val utf8ByteBuffer = utf8.getByteBuffer // todo: for off-heap UTF8String, how to pass in to arrow without copy? - valueMutator.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) + valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) } } private[arrow] class BinaryWriter( - val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter { - - override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator() + val valueVector: VarBinaryVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val bytes = input.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) + valueVector.setSafe(count, bytes, 0, bytes.length) } } -private[arrow] class DateWriter(val valueVector: NullableDateDayVector) extends ArrowFieldWriter { - - override def valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator() +private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getInt(ordinal)) + valueVector.setSafe(count, input.getInt(ordinal)) } } private[arrow] class TimestampWriter( - val valueVector: NullableTimeStampMicroTZVector) extends ArrowFieldWriter { - - override def valueMutator: NullableTimeStampMicroTZVector#Mutator = valueVector.getMutator() + val valueVector: TimeStampMicroTZVector) extends ArrowFieldWriter { override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setSafe(count, input.getLong(ordinal)) + valueVector.setSafe(count, input.getLong(ordinal)) } } @@ -286,20 +268,18 @@ private[arrow] class ArrayWriter( val valueVector: ListVector, 
val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { - override def valueMutator: ListVector#Mutator = valueVector.getMutator() - override def setNull(): Unit = { } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val array = input.getArray(ordinal) var i = 0 - valueMutator.startNewValue(count) + valueVector.startNewValue(count) while (i < array.numElements()) { elementWriter.write(array, i) i += 1 } - valueMutator.endValue(count, array.numElements()) + valueVector.endValue(count, array.numElements()) } override def finish(): Unit = { @@ -317,8 +297,6 @@ private[arrow] class StructWriter( val valueVector: NullableMapVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { - override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator() - override def setNull(): Unit = { var i = 0 while (i < children.length) { @@ -326,7 +304,7 @@ private[arrow] class StructWriter( children(i).count += 1 i += 1 } - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { @@ -336,7 +314,7 @@ private[arrow] class StructWriter( children(i).write(struct, i) i += 1 } - valueMutator.setIndexDefined(count) + valueVector.setIndexDefined(count) } override def finish(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 9a94d771a01b0..5cc8ed3535654 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.apache.spark._ import org.apache.spark.api.python._ @@ -74,13 +74,9 @@ class ArrowPythonRunner( val root = VectorSchemaRoot.create(arrowSchema, allocator) val arrowWriter = ArrowWriter.create(root) - var closed = false - context.addTaskCompletionListener { _ => - if (!closed) { - root.close() - allocator.close() - } + root.close() + allocator.close() } val writer = new ArrowStreamWriter(root, null, dataOut) @@ -102,7 +98,6 @@ class ArrowPythonRunner( writer.end() root.close() allocator.close() - closed = true } } } @@ -126,18 +121,11 @@ class ArrowPythonRunner( private var schema: StructType = _ private var vectors: Array[ColumnVector] = _ - private var closed = false - context.addTaskCompletionListener { _ => - // todo: we need something like `reader.end()`, which release all the resources, but leave - // the input stream open. `reader.close()` will close the socket and we can't reuse worker. - // So here we simply not close the reader, which is problematic. - if (!closed) { - if (root != null) { - root.close() - } - allocator.close() + if (reader != null) { + reader.close(false) } + allocator.close() } private var batchLoaded = true @@ -154,9 +142,8 @@ class ArrowPythonRunner( batch.setNumRows(root.getRowCount) batch } else { - root.close() + reader.close(false) allocator.close() - closed = true // Reach end of stream. Call `read()` again to read control data. 
read() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 57958f7239224..fd5a3df6abc68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -25,7 +25,7 @@ import java.util.Locale import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} -import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.ipc.JsonFileReader import org.apache.arrow.vector.util.Validator import org.scalatest.BeforeAndAfterAll @@ -76,16 +76,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 16 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_s", | "type" : { @@ -94,16 +85,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 16 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -143,16 +125,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_i", | "type" : { @@ -161,16 +134,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -210,16 +174,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 64 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_l", | "type" : { @@ -228,16 +183,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 64 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -276,16 +222,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "SINGLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_f", | "type" : { @@ -293,16 +230,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | 
"precision" : "SINGLE" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -341,16 +269,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_d", | "type" : { @@ -358,16 +277,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -408,16 +318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -449,16 +350,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 16 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b", | "type" : { @@ -466,16 +358,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "SINGLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "c", | "type" : { @@ -484,16 +367,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "d", | "type" : { @@ -501,16 +375,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | }, { | "name" : "e", | "type" : { @@ -519,16 +384,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 64 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -583,57 +439,21 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "utf8" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | 
"typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | }, { | "name" : "lower_case", | "type" : { | "name" : "utf8" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | }, { | "name" : "null_str", | "type" : { | "name" : "utf8" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -681,16 +501,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "bool" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -721,16 +532,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 8 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -760,19 +562,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "binary" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -807,16 +597,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "unit" : "DAY" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -855,16 +636,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "timezone" : "America/Los_Angeles" | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -904,16 +676,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "SINGLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "NaN_d", | "type" : { @@ -921,16 +684,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "precision" : "DOUBLE" | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "children" : [ ] | } ] | 
}, | "batches" : [ { @@ -939,12 +693,12 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "NaN_f", | "count" : 2, | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ 1.2000000476837158, "NaN" ] + | "DATA" : [ 1.2000000476837158, NaN ] | }, { | "name" : "NaN_d", | "count" : 2, | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ "NaN", 1.2 ] + | "DATA" : [ NaN, 1.2 ] | } ] | } ] |} @@ -976,26 +730,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "b_arr", | "nullable" : true, @@ -1010,26 +746,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "c_arr", | "nullable" : true, @@ -1044,26 +762,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "d_arr", | "nullable" : true, @@ -1084,36 +784,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 + | "children" : [ ] | } ] - | } + | } ] | } ] | }, | "batches" : [ { @@ -1204,23 +877,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "b_struct", | "nullable" : true, @@ -1235,23 +893,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : 
"VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "c_struct", | "nullable" : false, @@ -1266,23 +909,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } + | "children" : [ ] + | } ] | }, { | "name" : "d_struct", | "nullable" : true, @@ -1303,30 +931,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32, | "isSigned" : true | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | } ] - | } - | } ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 + | "children" : [ ] | } ] - | } + | } ] | } ] | }, | "batches" : [ { @@ -1413,16 +1020,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b", | "type" : { @@ -1431,16 +1029,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -1471,16 +1060,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b", | "type" : { @@ -1489,16 +1069,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -1600,16 +1171,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "b_i", | "type" : { @@ -1618,16 +1180,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | 
"typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -1658,16 +1211,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | }, { | "name" : "a_i", | "type" : { @@ -1676,16 +1220,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "bitWidth" : 32 | }, | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } + | "children" : [ ] | } ] | }, | "batches" : [ { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index e460d0721e7bf..03490ad15a655 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -30,15 +30,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("boolean") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableBitVector] + .createVector(allocator).asInstanceOf[BitVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, if (i % 2 == 0) 1 else 0) + vector.setSafe(i, if (i % 2 == 0) 1 else 0) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) @@ -58,15 +57,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("byte") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableTinyIntVector] + .createVector(allocator).asInstanceOf[TinyIntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toByte) + vector.setSafe(i, i.toByte) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) @@ -86,15 +84,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("short") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableSmallIntVector] + .createVector(allocator).asInstanceOf[SmallIntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toShort) + vector.setSafe(i, i.toShort) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) @@ -114,15 +111,14 @@ class 
ArrowColumnVectorSuite extends SparkFunSuite { test("int") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableIntVector] + .createVector(allocator).asInstanceOf[IntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i) + vector.setSafe(i, i) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) @@ -142,15 +138,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("long") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableBigIntVector] + .createVector(allocator).asInstanceOf[BigIntVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toLong) + vector.setSafe(i, i.toLong) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) @@ -170,15 +165,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("float") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableFloat4Vector] + .createVector(allocator).asInstanceOf[Float4Vector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toFloat) + vector.setSafe(i, i.toFloat) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) @@ -198,15 +192,14 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("double") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableFloat8Vector] + .createVector(allocator).asInstanceOf[Float8Vector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => - mutator.setSafe(i, i.toDouble) + vector.setSafe(i, i.toDouble) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) @@ -226,16 +219,15 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("string") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableVarCharVector] + .createVector(allocator).asInstanceOf[VarCharVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") - mutator.setSafe(i, utf8, 0, utf8.length) + vector.setSafe(i, utf8, 0, utf8.length) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) 
assert(columnVector.dataType === StringType) @@ -253,16 +245,15 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableVarBinaryVector] + .createVector(allocator).asInstanceOf[VarBinaryVector] vector.allocateNew() - val mutator = vector.getMutator() (0 until 10).foreach { i => val utf8 = s"str$i".getBytes("utf8") - mutator.setSafe(i, utf8, 0, utf8.length) + vector.setSafe(i, utf8, 0, utf8.length) } - mutator.setNull(10) - mutator.setValueCount(11) + vector.setNull(10) + vector.setValueCount(11) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) @@ -282,31 +273,29 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) .createVector(allocator).asInstanceOf[ListVector] vector.allocateNew() - val mutator = vector.getMutator() - val elementVector = vector.getDataVector().asInstanceOf[NullableIntVector] - val elementMutator = elementVector.getMutator() + val elementVector = vector.getDataVector().asInstanceOf[IntVector] // [1, 2] - mutator.startNewValue(0) - elementMutator.setSafe(0, 1) - elementMutator.setSafe(1, 2) - mutator.endValue(0, 2) + vector.startNewValue(0) + elementVector.setSafe(0, 1) + elementVector.setSafe(1, 2) + vector.endValue(0, 2) // [3, null, 5] - mutator.startNewValue(1) - elementMutator.setSafe(2, 3) - elementMutator.setNull(3) - elementMutator.setSafe(4, 5) - mutator.endValue(1, 3) + vector.startNewValue(1) + elementVector.setSafe(2, 3) + elementVector.setNull(3) + elementVector.setSafe(4, 5) + vector.endValue(1, 3) // null // [] - mutator.startNewValue(3) - mutator.endValue(3, 0) + vector.startNewValue(3) + vector.endValue(3, 0) - elementMutator.setValueCount(5) - mutator.setValueCount(4) + elementVector.setValueCount(5) + vector.setValueCount(4) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) @@ -338,38 +327,35 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) .createVector(allocator).asInstanceOf[NullableMapVector] vector.allocateNew() - val mutator = vector.getMutator() - val intVector = vector.getChildByOrdinal(0).asInstanceOf[NullableIntVector] - val intMutator = intVector.getMutator() - val longVector = vector.getChildByOrdinal(1).asInstanceOf[NullableBigIntVector] - val longMutator = longVector.getMutator() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] // (1, 1L) - mutator.setIndexDefined(0) - intMutator.setSafe(0, 1) - longMutator.setSafe(0, 1L) + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) // (2, null) - mutator.setIndexDefined(1) - intMutator.setSafe(1, 2) - longMutator.setNull(1) + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) // (null, 3L) - mutator.setIndexDefined(2) - intMutator.setNull(2) - longMutator.setSafe(2, 3L) + vector.setIndexDefined(2) + intVector.setNull(2) + longVector.setSafe(2, 3L) // null - mutator.setNull(3) + vector.setNull(3) // (5, 5L) - mutator.setIndexDefined(4) - intMutator.setSafe(4, 5) - longMutator.setSafe(4, 5L) + vector.setIndexDefined(4) + intVector.setSafe(4, 5) + 
longVector.setSafe(4, 5L) - intMutator.setValueCount(5) - longMutator.setValueCount(5) - mutator.setValueCount(5) + intVector.setValueCount(5) + longVector.setValueCount(5) + vector.setValueCount(5) val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index d3ed8276b8f10..7848ebdcab6d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random -import org.apache.arrow.vector.NullableIntVector +import org.apache.arrow.vector.IntVector import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode @@ -1137,22 +1137,20 @@ class ColumnarBatchSuite extends SparkFunSuite { test("create columnar batch from Arrow column vectors") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableIntVector] + .createVector(allocator).asInstanceOf[IntVector] vector1.allocateNew() - val mutator1 = vector1.getMutator() val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableIntVector] + .createVector(allocator).asInstanceOf[IntVector] vector2.allocateNew() - val mutator2 = vector2.getMutator() (0 until 10).foreach { i => - mutator1.setSafe(i, i) - mutator2.setSafe(i + 1, i) + vector1.setSafe(i, i) + vector2.setSafe(i + 1, i) } - mutator1.setNull(10) - mutator1.setValueCount(11) - mutator2.setNull(0) - mutator2.setValueCount(11) + vector1.setNull(10) + vector1.setValueCount(11) + vector2.setNull(0) + vector2.setValueCount(11) val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) From 0abaf31be7ab9e030ea9433938b9123596954814 Mon Sep 17 00:00:00 2001 From: Erik LaBianca Date: Thu, 21 Dec 2017 09:38:21 -0600 Subject: [PATCH 473/712] [SPARK-22852][BUILD] Exclude -Xlint:unchecked from sbt javadoc flags ## What changes were proposed in this pull request? Moves the -Xlint:unchecked flag in the sbt build configuration from Compile to (Compile, compile) scope, allowing publish and publishLocal commands to work. ## How was this patch tested? Successfully published the spark-launcher subproject from within sbt successfully, where it fails without this patch. Author: Erik LaBianca Closes #20040 from easel/javadoc-xlint. --- project/SparkBuild.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 75703380cdb4a..83054945afcf0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -239,14 +239,14 @@ object SparkBuild extends PomBuild { javacOptions in Compile ++= Seq( "-encoding", "UTF-8", - "-source", javacJVMVersion.value, - "-Xlint:unchecked" + "-source", javacJVMVersion.value ), - // This -target option cannot be set in the Compile configuration scope since `javadoc` doesn't - // play nicely with it; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 for - // additional discussion and explanation. 
+ // This -target and Xlint:unchecked options cannot be set in the Compile configuration scope since + // `javadoc` doesn't play nicely with them; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 + // for additional discussion and explanation. javacOptions in (Compile, compile) ++= Seq( - "-target", javacJVMVersion.value + "-target", javacJVMVersion.value, + "-Xlint:unchecked" ), scalacOptions in Compile ++= Seq( From 4c2efde9314a5f67052ac87bfa1472ebb9aca74a Mon Sep 17 00:00:00 2001 From: Erik LaBianca Date: Thu, 21 Dec 2017 10:08:38 -0600 Subject: [PATCH 474/712] [SPARK-22855][BUILD] Add -no-java-comments to sbt docs/scalacOptions Prevents Scala 2.12 scaladoc from blowing up attempting to parse java comments. ## What changes were proposed in this pull request? Adds -no-java-comments to docs/scalacOptions under Scala 2.12. Also moves scaladoc configs out of the TestSettings and into the standard sharedSettings section in SparkBuild.scala. ## How was this patch tested? SBT_OPTS=-Dscala-2.12 sbt ++2.12.4 tags/publishLocal Author: Erik LaBianca Closes #20042 from easel/scaladoc-212. --- project/SparkBuild.scala | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 83054945afcf0..7469f11df0294 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -254,6 +254,21 @@ object SparkBuild extends PomBuild { "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc ), + // Remove certain packages from Scaladoc + scalacOptions in (Compile, doc) := Seq( + "-groups", + "-skip-packages", Seq( + "org.apache.spark.api.python", + "org.apache.spark.network", + "org.apache.spark.deploy", + "org.apache.spark.util.collection" + ).mkString(":"), + "-doc-title", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " ScalaDoc" + ) ++ { + // Do not attempt to scaladoc javadoc comments under 2.12 since it can't handle inner classes + if (scalaBinaryVersion.value == "2.12") Seq("-no-java-comments") else Seq.empty + }, + // Implements -Xfatal-warnings, ignoring deprecation warnings. // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. compile in Compile := { @@ -828,18 +843,7 @@ object TestSettings { } Seq.empty[File] }).value, - concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), - // Remove certain packages from Scaladoc - scalacOptions in (Compile, doc) := Seq( - "-groups", - "-skip-packages", Seq( - "org.apache.spark.api.python", - "org.apache.spark.network", - "org.apache.spark.deploy", - "org.apache.spark.util.collection" - ).mkString(":"), - "-doc-title", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " ScalaDoc" - ) + concurrentRestrictions in Global += Tags.limit(Tags.Test, 1) ) } From 8a0ed5a5ee64a6e854c516f80df5a9729435479b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 22 Dec 2017 00:21:27 +0800 Subject: [PATCH 475/712] [SPARK-22668][SQL] Ensure no global variables in arguments of method split by CodegenContext.splitExpressions() ## What changes were proposed in this pull request? Passing global variables to the split method is dangerous, as any mutating to it is ignored and may lead to unexpected behavior. To prevent this, one approach is to make sure no expression would output global variables: Localizing lifetime of mutable states in expressions. Another approach is, when calling `ctx.splitExpression`, make sure we don't use children's output as parameter names. 
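For illustration, a rough sketch of the generated-code shape this guards against (the names below are invented; the actual check added by this patch is in the CodeGenerator.scala hunk further down). `splitExpressions` emits private methods, and Java passes primitives by value, so if a class-level ("global") mutable state is handed to such a method as a parameter, a write inside the method never reaches the member variable:

```scala
// Hedged, self-contained sketch (not Spark code; names invented) of the unsafe
// shape. The Java snippet is held in a Scala string the same way the codegen
// classes build generated code.
val unsafeGeneratedShape =
  """
    |private boolean globalIsNull = true;           // member ("global") state
    |
    |private void split_0(boolean globalIsNull) {   // member used as a parameter
    |  globalIsNull = false;                        // mutation is silently lost
    |}
    |
    |public void consume() {
    |  split_0(globalIsNull);
    |  // globalIsNull is still true here -- the unexpected behavior described above.
    |}
  """.stripMargin
```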
Approach 1 is actually hard to do, as we need to check all expressions and operators that support whole-stage codegen. Approach 2 is easier as the callers of `ctx.splitExpressions` are not too many. Besides, approach 2 is more flexible, as children's output may be other stuff that can't be parameter name: literal, inlined statement(a + 1), etc. close https://github.com/apache/spark/pull/19865 close https://github.com/apache/spark/pull/19938 ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20021 from cloud-fan/codegen. --- .../sql/catalyst/expressions/arithmetic.scala | 18 +++++------ .../expressions/codegen/CodeGenerator.scala | 32 ++++++++++++++++--- .../expressions/conditionalExpressions.scala | 8 ++--- .../expressions/nullExpressions.scala | 9 +++--- .../sql/catalyst/expressions/predicates.scala | 2 +- 5 files changed, 43 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d3a8cb5804717..8bb14598a6d7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && ($tmpIsNull || + |if (!${eval.isNull} && (${ev.isNull} || | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; |} """.stripMargin @@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |$codes - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } @@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && ($tmpIsNull || + |if (!${eval.isNull} && (${ev.isNull} || | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; |} """.stripMargin @@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |$codes - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 41a920ba3d677..9adf632ddcde8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -128,7 +128,7 @@ class CodegenContext { * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling * `Expression.genCode`. */ - final var INPUT_ROW = "i" + var INPUT_ROW = "i" /** * Holding a list of generated columns as input of current operator, will be used by @@ -146,22 +146,30 @@ class CodegenContext { * as a member variable * * They will be kept as member variables in generated classes like `SpecificProjection`. + * + * Exposed for tests only. */ - val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = + private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = mutable.ArrayBuffer.empty[(String, String)] /** * The mapping between mutable state types and corrseponding compacted arrays. * The keys are java type string. The values are [[MutableStateArrays]] which encapsulates * the compacted arrays for the mutable states with the same java type. + * + * Exposed for tests only. */ - val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = + private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = mutable.Map.empty[String, MutableStateArrays] // An array holds the code that will initialize each state - val mutableStateInitCode: mutable.ArrayBuffer[String] = + // Exposed for tests only. + private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] + // Tracks the names of all the mutable states. + private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty + /** * This class holds a set of names of mutableStateArrays that is used for compacting mutable * states for a certain type, and holds the next available slot of the current compacted array. @@ -172,7 +180,11 @@ class CodegenContext { private[this] var currentIndex = 0 - private def createNewArray() = arrayNames.append(freshName("mutableStateArray")) + private def createNewArray() = { + val newArrayName = freshName("mutableStateArray") + mutableStateNames += newArrayName + arrayNames.append(newArrayName) + } def getCurrentIndex: Int = currentIndex @@ -241,6 +253,7 @@ class CodegenContext { val initCode = initFunc(varName) inlinedMutableStates += ((javaType, varName)) mutableStateInitCode += initCode + mutableStateNames += varName varName } else { val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays) @@ -930,6 +943,15 @@ class CodegenContext { // inline execution if only one block blocks.head } else { + if (Utils.isTesting) { + // Passing global variables to the split method is dangerous, as any mutating to it is + // ignored and may lead to unexpected behavior. 
+ arguments.foreach { case (_, name) => + assert(!mutableStateNames.contains(name), + s"split function argument $name cannot be a global variable.") + } + } + val func = freshName(funcName) val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") val functions = blocks.zipWithIndex.map { case (body, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1a9b68222a7f4..142dfb02be0a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -190,7 +190,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult") + ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -205,7 +205,7 @@ case class CaseWhen( |if (!${cond.isNull} && ${cond.value}) { | ${res.code} | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); - | $tmpResult = ${res.value}; + | ${ev.value} = ${res.value}; | continue; |} """.stripMargin @@ -216,7 +216,7 @@ case class CaseWhen( s""" |${res.code} |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); - |$tmpResult = ${res.value}; + |${ev.value} = ${res.value}; """.stripMargin } @@ -264,13 +264,11 @@ case class CaseWhen( ev.copy(code = s""" |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; - |$tmpResult = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); |// TRUE if any condition is met and the result is null, or no any condition is met. |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); - |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b4f895fffda38..470d5da041ea5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... 
} while (false); loop val evals = children.map { e => @@ -80,7 +80,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; | continue; |} @@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { foldFunctions = _.map { funcCall => s""" |${ev.value} = $funcCall; - |if (!$tmpIsNull) { + |if (!${ev.isNull}) { | continue; |} """.stripMargin @@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac9f56f78eb2e..f4ee3d10f3f43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { - | $tmpResult = 0; + | $tmpResult = $NOT_MATCHED; | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes From d3a1d9527bcd6675cc45773f01d4558cf4b46b3d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 22 Dec 2017 01:08:13 +0800 Subject: [PATCH 476/712] [SPARK-22786][SQL] only use AppStatusPlugin in history server ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/19681 we introduced a new interface called `AppStatusPlugin`, to register listeners and set up the UI for both live and history UI. However I think it's an overkill for live UI. For example, we should not register `SQLListener` if users are not using SQL functions. Previously we register the `SQLListener` and set up SQL tab when `SparkSession` is firstly created, which indicates users are going to use SQL functions. But in #19681 , we register the SQL functions during `SparkContext` creation. The same thing should apply to streaming too. I think we should keep the previous behavior, and only use this new interface for history server. To reflect this change, I also rename the new interface to `SparkHistoryUIPlugin` This PR also refines the tests for sql listener. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #19981 from cloud-fan/listener. 
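For illustration, a minimal sketch of what a module-side plugin looks like under the reworked interface (the trait and the real SQL implementation, `SQLHistoryServerPlugin`, appear in the diff below). The module, listener and tab names here are invented; the trait is `private[spark]`, so an implementation has to live inside a Spark module and is discovered through a `META-INF/services` entry rather than being registered on the live listener bus:

```scala
package org.apache.spark.mymodule.ui  // hypothetical module inside Spark

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.status.{AppHistoryServerPlugin, ElementTrackingStore}
import org.apache.spark.ui.SparkUI

// Replays this module's events into the history server's KV store.
// Left empty here; a real listener would override the callbacks it needs.
class MyModuleHistoryListener(conf: SparkConf, store: ElementTrackingStore) extends SparkListener

class MyModuleHistoryServerPlugin extends AppHistoryServerPlugin {
  // Called by FsHistoryProvider while replaying event logs.
  override def createListeners(conf: SparkConf, store: ElementTrackingStore): Seq[SparkListener] =
    Seq(new MyModuleHistoryListener(conf, store))

  // Called once the rebuilt SparkUI exists; only attach a tab if there is data to show.
  override def setupUI(ui: SparkUI): Unit = {
    // e.g. if the store has data for this module: new MyModuleTab(ui)  (hypothetical)
  }
}
```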
--- .../scala/org/apache/spark/SparkContext.scala | 16 +- .../deploy/history/FsHistoryProvider.scala | 14 +- .../spark/status/AppHistoryServerPlugin.scala | 38 +++ .../spark/status/AppStatusListener.scala | 2 +- .../apache/spark/status/AppStatusPlugin.scala | 71 ------ .../apache/spark/status/AppStatusStore.scala | 17 +- .../org/apache/spark/ui/StagePageSuite.scala | 20 +- ...apache.spark.status.AppHistoryServerPlugin | 1 + .../org.apache.spark.status.AppStatusPlugin | 1 - .../execution/ui/SQLAppStatusListener.scala | 23 +- .../sql/execution/ui/SQLAppStatusStore.scala | 62 +---- .../execution/ui/SQLHistoryServerPlugin.scala | 36 +++ .../spark/sql/internal/SharedState.scala | 18 +- .../execution/metric/SQLMetricsSuite.scala | 8 - .../metric/SQLMetricsTestUtils.scala | 10 +- ....scala => SQLAppStatusListenerSuite.scala} | 238 +++++++++--------- 16 files changed, 256 insertions(+), 319 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppHistoryServerPlugin.scala delete mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin delete mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLHistoryServerPlugin.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/ui/{SQLListenerSuite.scala => SQLAppStatusListenerSuite.scala} (71%) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 92e13ce1ba042..fcbeddd2a9ac3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -53,7 +53,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend -import org.apache.spark.status.{AppStatusPlugin, AppStatusStore} +import org.apache.spark.status.AppStatusStore import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -416,7 +416,8 @@ class SparkContext(config: SparkConf) extends Logging { // Initialize the app status store and listener before SparkEnv is created so that it gets // all events. - _statusStore = AppStatusStore.createLiveStore(conf, l => listenerBus.addToStatusQueue(l)) + _statusStore = AppStatusStore.createLiveStore(conf) + listenerBus.addToStatusQueue(_statusStore.listener.get) // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) @@ -445,14 +446,9 @@ class SparkContext(config: SparkConf) extends Logging { // For tests, do not enable the UI None } - _ui.foreach { ui => - // Load any plugins that might want to modify the UI. 
- AppStatusPlugin.loadPlugins().foreach(_.setupUI(ui)) - - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - ui.bind() - } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + _ui.foreach(_.bind()) _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index fa2c5194aa41b..a299b79850613 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -44,7 +44,6 @@ import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.status._ import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo} -import org.apache.spark.status.config._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} import org.apache.spark.util.kvstore._ @@ -322,15 +321,18 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) (new InMemoryStore(), true) } + val plugins = ServiceLoader.load( + classOf[AppHistoryServerPlugin], Utils.getContextOrSparkClassLoader).asScala val trackingStore = new ElementTrackingStore(kvstore, conf) if (needReplay) { val replayBus = new ReplayListenerBus() val listener = new AppStatusListener(trackingStore, conf, false, lastUpdateTime = Some(attempt.info.lastUpdated.getTime())) replayBus.addListener(listener) - AppStatusPlugin.loadPlugins().foreach { plugin => - plugin.setupListeners(conf, trackingStore, l => replayBus.addListener(l), false) - } + for { + plugin <- plugins + listener <- plugin.createListeners(conf, trackingStore) + } replayBus.addListener(listener) try { val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) @@ -353,9 +355,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.info.attemptId), attempt.info.startTime.getTime(), attempt.info.appSparkVersion) - AppStatusPlugin.loadPlugins().foreach { plugin => - plugin.setupUI(ui) - } + plugins.foreach(_.setupUI(ui)) val loadedUI = LoadedAppUI(ui) diff --git a/core/src/main/scala/org/apache/spark/status/AppHistoryServerPlugin.scala b/core/src/main/scala/org/apache/spark/status/AppHistoryServerPlugin.scala new file mode 100644 index 0000000000000..d144a0e998fa1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppHistoryServerPlugin.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler.SparkListener +import org.apache.spark.ui.SparkUI + +/** + * An interface for creating history listeners(to replay event logs) defined in other modules like + * SQL, and setup the UI of the plugin to rebuild the history UI. + */ +private[spark] trait AppHistoryServerPlugin { + /** + * Creates listeners to replay the event logs. + */ + def createListeners(conf: SparkConf, store: ElementTrackingStore): Seq[SparkListener] + + /** + * Sets up UI of this plugin to rebuild the history UI. + */ + def setupUI(ui: SparkUI): Unit +} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 87eb84d94c005..4db797e1d24c6 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -48,7 +48,7 @@ private[spark] class AppStatusListener( import config._ - private var sparkVersion = SPARK_VERSION + private val sparkVersion = SPARK_VERSION private var appInfo: v1.ApplicationInfo = null private var appSummary = new AppSummary(0, 0) private var coresPerTask: Int = 1 diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala b/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala deleted file mode 100644 index 4cada5c7b0de4..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/AppStatusPlugin.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.status - -import java.util.ServiceLoader - -import scala.collection.JavaConverters._ - -import org.apache.spark.SparkConf -import org.apache.spark.scheduler.SparkListener -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.Utils -import org.apache.spark.util.kvstore.KVStore - -/** - * An interface that defines plugins for collecting and storing application state. - * - * The plugin implementations are invoked for both live and replayed applications. For live - * applications, it's recommended that plugins defer creation of UI tabs until there's actual - * data to be shown. - */ -private[spark] trait AppStatusPlugin { - - /** - * Install listeners to collect data about the running application and populate the given - * store. - * - * @param conf The Spark configuration. - * @param store The KVStore where to keep application data. - * @param addListenerFn Function to register listeners with a bus. 
- * @param live Whether this is a live application (or an application being replayed by the - * HistoryServer). - */ - def setupListeners( - conf: SparkConf, - store: ElementTrackingStore, - addListenerFn: SparkListener => Unit, - live: Boolean): Unit - - /** - * Install any needed extensions (tabs, pages, etc) to a Spark UI. The plugin can detect whether - * the app is live or replayed by looking at the UI's SparkContext field `sc`. - * - * @param ui The Spark UI instance for the application. - */ - def setupUI(ui: SparkUI): Unit - -} - -private[spark] object AppStatusPlugin { - - def loadPlugins(): Iterable[AppStatusPlugin] = { - ServiceLoader.load(classOf[AppStatusPlugin], Utils.getContextOrSparkClassLoader).asScala - } - -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 9987419b170f6..5a942f5284018 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -17,16 +17,14 @@ package org.apache.spark.status -import java.io.File -import java.util.{Arrays, List => JList} +import java.util.{List => JList} import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} -import org.apache.spark.scheduler.SparkListener import org.apache.spark.status.api.v1 import org.apache.spark.ui.scope._ -import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.Distribution import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** @@ -34,7 +32,7 @@ import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} */ private[spark] class AppStatusStore( val store: KVStore, - listener: Option[AppStatusListener] = None) { + val listener: Option[AppStatusListener] = None) { def applicationInfo(): v1.ApplicationInfo = { store.view(classOf[ApplicationInfoWrapper]).max(1).iterator().next().info @@ -346,17 +344,10 @@ private[spark] object AppStatusStore { /** * Create an in-memory store for a live application. - * - * @param conf Configuration. - * @param addListenerFn Function to register a listener with a bus. 
*/ - def createLiveStore(conf: SparkConf, addListenerFn: SparkListener => Unit): AppStatusStore = { + def createLiveStore(conf: SparkConf): AppStatusStore = { val store = new ElementTrackingStore(new InMemoryStore(), conf) val listener = new AppStatusListener(store, conf, true) - addListenerFn(listener) - AppStatusPlugin.loadPlugins().foreach { p => - p.setupListeners(conf, store, addListenerFn, true) - } new AppStatusStore(store, listener = Some(listener)) } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 46932a02f1a1b..661d0d48d2f37 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -22,7 +22,6 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.mockito.Matchers.anyString import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.apache.spark._ @@ -30,7 +29,6 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore import org.apache.spark.ui.jobs.{StagePage, StagesTab} -import org.apache.spark.util.Utils class StagePageSuite extends SparkFunSuite with LocalSparkContext { @@ -55,12 +53,12 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * This also runs a dummy stage to populate the page with useful content. */ private def renderStagePage(conf: SparkConf): Seq[Node] = { - val bus = new ReplayListenerBus() - val store = AppStatusStore.createLiveStore(conf, l => bus.addListener(l)) + val statusStore = AppStatusStore.createLiveStore(conf) + val listener = statusStore.listener.get try { val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) - when(tab.store).thenReturn(store) + when(tab.store).thenReturn(statusStore) val request = mock(classOf[HttpServletRequest]) when(tab.conf).thenReturn(conf) @@ -68,7 +66,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { when(tab.headerTabs).thenReturn(Seq.empty) when(request.getParameter("id")).thenReturn("0") when(request.getParameter("attempt")).thenReturn("0") - val page = new StagePage(tab, store) + val page = new StagePage(tab, statusStore) // Simulate a stage in job progress listener val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") @@ -77,17 +75,17 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { taskId => val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) - bus.postToAll(SparkListenerStageSubmitted(stageInfo)) - bus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + listener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis()) val taskMetrics = TaskMetrics.empty taskMetrics.incPeakExecutionMemory(peakExecutionMemory) - bus.postToAll(SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) + listener.onTaskEnd(SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) } - bus.postToAll(SparkListenerStageCompleted(stageInfo)) + listener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) page.render(request) } finally { - store.close() + statusStore.close() } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin 
b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin new file mode 100644 index 0000000000000..0bba2f88b92a5 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryServerPlugin diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin b/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin deleted file mode 100644 index ac6d7f6962f85..0000000000000 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.status.AppStatusPlugin +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.sql.execution.ui.SQLAppStatusPlugin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index cf0000c6393a3..aa78fa015dbef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -30,15 +30,11 @@ import org.apache.spark.sql.execution.metric._ import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity} import org.apache.spark.status.config._ -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.kvstore.KVStore -private[sql] class SQLAppStatusListener( +class SQLAppStatusListener( conf: SparkConf, kvstore: ElementTrackingStore, - live: Boolean, - ui: Option[SparkUI] = None) - extends SparkListener with Logging { + live: Boolean) extends SparkListener with Logging { // How often to flush intermediate state of a live execution to the store. When replaying logs, // never flush (only do the very last write). @@ -50,7 +46,10 @@ private[sql] class SQLAppStatusListener( private val liveExecutions = new ConcurrentHashMap[Long, LiveExecutionData]() private val stageMetrics = new ConcurrentHashMap[Int, LiveStageMetrics]() - private var uiInitialized = false + // Returns true if this listener has no live data. Exposed for tests only. + private[sql] def noLiveData(): Boolean = { + liveExecutions.isEmpty && stageMetrics.isEmpty + } kvstore.addTrigger(classOf[SQLExecutionUIData], conf.get(UI_RETAINED_EXECUTIONS)) { count => cleanupExecutions(count) @@ -230,14 +229,6 @@ private[sql] class SQLAppStatusListener( } private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = { - // Install the SQL tab in a live app if it hasn't been initialized yet. 
- if (!uiInitialized) { - ui.foreach { _ui => - new SQLTab(new SQLAppStatusStore(kvstore, Some(this)), _ui) - } - uiInitialized = true - } - val SparkListenerSQLExecutionStart(executionId, description, details, physicalPlanDescription, sparkPlanInfo, time) = event @@ -389,7 +380,7 @@ private class LiveStageMetrics( val accumulatorIds: Array[Long], val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics]) -private[sql] class LiveTaskMetrics( +private class LiveTaskMetrics( val ids: Array[Long], val values: Array[Long], val succeeded: Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 7fd5f7395cdf3..910f2e52fdbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -25,21 +25,17 @@ import scala.collection.mutable.ArrayBuffer import com.fasterxml.jackson.databind.annotation.JsonDeserialize -import org.apache.spark.{JobExecutionStatus, SparkConf} -import org.apache.spark.scheduler.SparkListener -import org.apache.spark.status.{AppStatusPlugin, ElementTrackingStore} +import org.apache.spark.JobExecutionStatus import org.apache.spark.status.KVUtils.KVIndexParam -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.Utils import org.apache.spark.util.kvstore.KVStore /** * Provides a view of a KVStore with methods that make it easy to query SQL-specific state. There's * no state kept in this class, so it's ok to have multiple instances of it in an application. */ -private[sql] class SQLAppStatusStore( +class SQLAppStatusStore( store: KVStore, - listener: Option[SQLAppStatusListener] = None) { + val listener: Option[SQLAppStatusListener] = None) { def executionsList(): Seq[SQLExecutionUIData] = { store.view(classOf[SQLExecutionUIData]).asScala.toSeq @@ -74,48 +70,9 @@ private[sql] class SQLAppStatusStore( def planGraph(executionId: Long): SparkPlanGraph = { store.read(classOf[SparkPlanGraphWrapper], executionId).toSparkPlanGraph() } - -} - -/** - * An AppStatusPlugin for handling the SQL UI and listeners. - */ -private[sql] class SQLAppStatusPlugin extends AppStatusPlugin { - - override def setupListeners( - conf: SparkConf, - store: ElementTrackingStore, - addListenerFn: SparkListener => Unit, - live: Boolean): Unit = { - // For live applications, the listener is installed in [[setupUI]]. This also avoids adding - // the listener when the UI is disabled. Force installation during testing, though. - if (!live || Utils.isTesting) { - val listener = new SQLAppStatusListener(conf, store, live, None) - addListenerFn(listener) - } - } - - override def setupUI(ui: SparkUI): Unit = { - ui.sc match { - case Some(sc) => - // If this is a live application, then install a listener that will enable the SQL - // tab as soon as there's a SQL event posted to the bus. - val listener = new SQLAppStatusListener(sc.conf, - ui.store.store.asInstanceOf[ElementTrackingStore], true, Some(ui)) - sc.listenerBus.addToStatusQueue(listener) - - case _ => - // For a replayed application, only add the tab if the store already contains SQL data. 
- val sqlStore = new SQLAppStatusStore(ui.store.store) - if (sqlStore.executionsCount() > 0) { - new SQLTab(sqlStore, ui) - } - } - } - } -private[sql] class SQLExecutionUIData( +class SQLExecutionUIData( @KVIndexParam val executionId: Long, val description: String, val details: String, @@ -133,10 +90,9 @@ private[sql] class SQLExecutionUIData( * from the SQL listener instance. */ @JsonDeserialize(keyAs = classOf[JLong]) - val metricValues: Map[Long, String] - ) + val metricValues: Map[Long, String]) -private[sql] class SparkPlanGraphWrapper( +class SparkPlanGraphWrapper( @KVIndexParam val executionId: Long, val nodes: Seq[SparkPlanGraphNodeWrapper], val edges: Seq[SparkPlanGraphEdge]) { @@ -147,7 +103,7 @@ private[sql] class SparkPlanGraphWrapper( } -private[sql] class SparkPlanGraphClusterWrapper( +class SparkPlanGraphClusterWrapper( val id: Long, val name: String, val desc: String, @@ -163,7 +119,7 @@ private[sql] class SparkPlanGraphClusterWrapper( } /** Only one of the values should be set. */ -private[sql] class SparkPlanGraphNodeWrapper( +class SparkPlanGraphNodeWrapper( val node: SparkPlanGraphNode, val cluster: SparkPlanGraphClusterWrapper) { @@ -174,7 +130,7 @@ private[sql] class SparkPlanGraphNodeWrapper( } -private[sql] case class SQLPlanMetric( +case class SQLPlanMetric( name: String, accumulatorId: Long, metricType: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLHistoryServerPlugin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLHistoryServerPlugin.scala new file mode 100644 index 0000000000000..522d0cf79bffa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLHistoryServerPlugin.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.ui + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler.SparkListener +import org.apache.spark.status.{AppHistoryServerPlugin, ElementTrackingStore} +import org.apache.spark.ui.SparkUI + +class SQLHistoryServerPlugin extends AppHistoryServerPlugin { + override def createListeners(conf: SparkConf, store: ElementTrackingStore): Seq[SparkListener] = { + Seq(new SQLAppStatusListener(conf, store, live = false)) + } + + override def setupUI(ui: SparkUI): Unit = { + val sqlStatusStore = new SQLAppStatusStore(ui.store.store) + if (sqlStatusStore.executionsCount() > 0) { + new SQLTab(sqlStatusStore, ui) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 3e479faed72ac..baea4ceebf8e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -28,11 +28,12 @@ import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.CacheManager +import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab} import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.status.ElementTrackingStore import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -82,6 +83,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { */ val cacheManager: CacheManager = new CacheManager + /** + * A status store to query SQL status/metrics of this Spark application, based on SQL-specific + * [[org.apache.spark.scheduler.SparkListenerEvent]]s. + */ + val statusStore: SQLAppStatusStore = { + val kvStore = sparkContext.statusStore.store.asInstanceOf[ElementTrackingStore] + val listener = new SQLAppStatusListener(sparkContext.conf, kvStore, live = true) + sparkContext.listenerBus.addToStatusQueue(listener) + val statusStore = new SQLAppStatusStore(kvStore, Some(listener)) + sparkContext.ui.foreach(new SQLTab(statusStore, _)) + statusStore + } + /** * A catalog that interacts with external systems. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d588af3e19dde..fc3483379c817 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -33,14 +33,6 @@ import org.apache.spark.util.{AccumulatorContext, JsonProtocol} class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext { import testImplicits._ - private def statusStore: SQLAppStatusStore = { - new SQLAppStatusStore(sparkContext.statusStore.store) - } - - private def currentExecutionIds(): Set[Long] = { - statusStore.executionsList.map(_.executionId).toSet - } - /** * Generates a `DataFrame` by filling randomly generated bytes for hash collision. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index d89c4b14619fa..122d28798136f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -31,17 +31,14 @@ import org.apache.spark.util.Utils trait SQLMetricsTestUtils extends SQLTestUtils { - import testImplicits._ - private def statusStore: SQLAppStatusStore = { - new SQLAppStatusStore(sparkContext.statusStore.store) - } - - private def currentExecutionIds(): Set[Long] = { + protected def currentExecutionIds(): Set[Long] = { statusStore.executionsList.map(_.executionId).toSet } + protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore + /** * Get execution metrics for the SQL execution and verify metrics values. * @@ -57,7 +54,6 @@ trait SQLMetricsTestUtils extends SQLTestUtils { assert(executionIds.size == 1) val executionId = executionIds.head - val executionData = statusStore.execution(executionId).get val executedNode = statusStore.planGraph(executionId).nodes.head val metricsNames = Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala similarity index 71% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 932950687942c..5ebbeb4a7cb40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -41,7 +41,8 @@ import org.apache.spark.status.config._ import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} import org.apache.spark.util.kvstore.InMemoryStore -class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { + +class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils { import testImplicits._ override protected def sparkConf = { @@ -61,21 +62,21 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest properties } - private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = new StageInfo( - stageId = stageId, - attemptId = attemptId, - // The following fields are not used in tests - name = "", - numTasks = 0, - rddInfos = Nil, - parentIds = Nil, - details = "" - ) + private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = { + new StageInfo(stageId = stageId, + attemptId = attemptId, + // The following fields are not used in tests + name = "", + numTasks = 0, + rddInfos = Nil, + parentIds = Nil, + details = "") + } private def createTaskInfo( taskId: Int, attemptNumber: Int, - accums: Map[Long, Long] = Map()): TaskInfo = { + accums: Map[Long, Long] = Map.empty): TaskInfo = { val info = new TaskInfo( taskId = taskId, attemptNumber = attemptNumber, @@ -99,29 +100,37 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest }.toSeq } - /** Return the shared SQL store from the active SparkSession. 
*/ - private def statusStore: SQLAppStatusStore = - new SQLAppStatusStore(spark.sparkContext.statusStore.store) - - /** - * Runs a test with a temporary SQLAppStatusStore tied to a listener bus. Events can be sent to - * the listener bus to update the store, and all data will be cleaned up at the end of the test. - */ - private def sqlStoreTest(name: String) - (fn: (SQLAppStatusStore, SparkListenerBus) => Unit): Unit = { - test(name) { - val conf = sparkConf - val store = new ElementTrackingStore(new InMemoryStore(), conf) - val bus = new ReplayListenerBus() - val listener = new SQLAppStatusListener(conf, store, true) - bus.addListener(listener) - store.close(false) - val sqlStore = new SQLAppStatusStore(store, Some(listener)) - fn(sqlStore, bus) + private def assertJobs( + exec: Option[SQLExecutionUIData], + running: Seq[Int] = Nil, + completed: Seq[Int] = Nil, + failed: Seq[Int] = Nil): Unit = { + val actualRunning = new ListBuffer[Int]() + val actualCompleted = new ListBuffer[Int]() + val actualFailed = new ListBuffer[Int]() + + exec.get.jobs.foreach { case (jobId, jobStatus) => + jobStatus match { + case JobExecutionStatus.RUNNING => actualRunning += jobId + case JobExecutionStatus.SUCCEEDED => actualCompleted += jobId + case JobExecutionStatus.FAILED => actualFailed += jobId + case _ => fail(s"Unexpected status $jobStatus") + } } + + assert(actualRunning.sorted === running) + assert(actualCompleted.sorted === completed) + assert(actualFailed.sorted === failed) } - sqlStoreTest("basic") { (store, bus) => + private def createStatusStore(): SQLAppStatusStore = { + val conf = sparkContext.conf + val store = new ElementTrackingStore(new InMemoryStore, conf) + val listener = new SQLAppStatusListener(conf, store, live = true) + new SQLAppStatusStore(store, Some(listener)) + } + + test("basic") { def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { assert(actual.size == expected.size) expected.foreach { case (id, value) => @@ -135,6 +144,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest } } + val statusStore = createStatusStore() + val listener = statusStore.listener.get + val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -147,7 +159,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest (id, accumulatorValue) }.toMap - bus.postToAll(SparkListenerSQLExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", @@ -155,7 +167,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - bus.postToAll(SparkListenerJobStart( + listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq( @@ -163,45 +175,45 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest createStageInfo(1, 0) ), createProperties(executionId))) - bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 0))) + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) - assert(store.executionMetrics(0).isEmpty) + assert(statusStore.executionMetrics(executionId).isEmpty) - bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)), (1L, 0, 0, 
createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) // Driver accumulator updates don't belong to this execution should be filtered and no // exception will be thrown. - bus.postToAll(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) + listener.onOtherEvent(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) - bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)), (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates.mapValues(_ * 2))) ))) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 3)) // Retrying a stage should reset the metrics - bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) - bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) (0L, 0, 1, createAccumulatorInfos(accumulatorUpdates)), (1L, 0, 1, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) // Ignore the task end for the first attempt - bus.postToAll(SparkListenerTaskEnd( + listener.onTaskEnd(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 0, taskType = "", @@ -209,17 +221,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 100)), null)) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) // Finish two tasks - bus.postToAll(SparkListenerTaskEnd( + listener.onTaskEnd(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 1, taskType = "", reason = null, createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 2)), null)) - bus.postToAll(SparkListenerTaskEnd( + listener.onTaskEnd(SparkListenerTaskEnd( stageId = 0, stageAttemptId = 1, taskType = "", @@ -227,28 +239,28 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), null)) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 5)) // Summit a new stage - bus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 0))) + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) - bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq( + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) (0L, 1, 0, createAccumulatorInfos(accumulatorUpdates)), (1L, 1, 0, createAccumulatorInfos(accumulatorUpdates)) ))) - 
checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 7)) // Finish two tasks - bus.postToAll(SparkListenerTaskEnd( + listener.onTaskEnd(SparkListenerTaskEnd( stageId = 1, stageAttemptId = 0, taskType = "", reason = null, createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 3)), null)) - bus.postToAll(SparkListenerTaskEnd( + listener.onTaskEnd(SparkListenerTaskEnd( stageId = 1, stageAttemptId = 0, taskType = "", @@ -256,127 +268,137 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), null)) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 11)) - assertJobs(store.execution(0), running = Seq(0)) + assertJobs(statusStore.execution(executionId), running = Seq(0)) - bus.postToAll(SparkListenerJobEnd( + listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - bus.postToAll(SparkListenerSQLExecutionEnd( + listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - assertJobs(store.execution(0), completed = Seq(0)) - checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) + assertJobs(statusStore.execution(executionId), completed = Seq(0)) + + checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 11)) } - sqlStoreTest("onExecutionEnd happens before onJobEnd(JobSucceeded)") { (store, bus) => + test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { + val statusStore = createStatusStore() + val listener = statusStore.listener.get + val executionId = 0 val df = createTestDataFrame - bus.postToAll(SparkListenerSQLExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - bus.postToAll(SparkListenerJobStart( + listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - bus.postToAll(SparkListenerSQLExecutionEnd( + listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - bus.postToAll(SparkListenerJobEnd( + listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobSucceeded )) - assertJobs(store.execution(0), completed = Seq(0)) + assertJobs(statusStore.execution(executionId), completed = Seq(0)) } - sqlStoreTest("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { (store, bus) => + test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { + val statusStore = createStatusStore() + val listener = statusStore.listener.get + val executionId = 0 val df = createTestDataFrame - bus.postToAll(SparkListenerSQLExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - bus.postToAll(SparkListenerJobStart( + listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - bus.postToAll(SparkListenerJobEnd( + listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = 
System.currentTimeMillis(), JobSucceeded )) - bus.postToAll(SparkListenerJobStart( + listener.onJobStart(SparkListenerJobStart( jobId = 1, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - bus.postToAll(SparkListenerSQLExecutionEnd( + listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - bus.postToAll(SparkListenerJobEnd( + listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), JobSucceeded )) - assertJobs(store.execution(0), completed = Seq(0, 1)) + assertJobs(statusStore.execution(executionId), completed = Seq(0, 1)) } - sqlStoreTest("onExecutionEnd happens before onJobEnd(JobFailed)") { (store, bus) => + test("onExecutionEnd happens before onJobEnd(JobFailed)") { + val statusStore = createStatusStore() + val listener = statusStore.listener.get + val executionId = 0 val df = createTestDataFrame - bus.postToAll(SparkListenerSQLExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), System.currentTimeMillis())) - bus.postToAll(SparkListenerJobStart( + listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - bus.postToAll(SparkListenerSQLExecutionEnd( + listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) - bus.postToAll(SparkListenerJobEnd( + listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), JobFailed(new RuntimeException("Oops")) )) - assertJobs(store.execution(0), failed = Seq(0)) + assertJobs(statusStore.execution(executionId), failed = Seq(0)) } test("SPARK-11126: no memory leak when running non SQL jobs") { - val previousStageNumber = statusStore.executionsList().size + val listener = spark.sharedState.statusStore.listener.get + // At the beginning of this test case, there should be no live data in the listener. + assert(listener.noLiveData()) spark.sparkContext.parallelize(1 to 10).foreach(i => ()) spark.sparkContext.listenerBus.waitUntilEmpty(10000) - // listener should ignore the non SQL stage - assert(statusStore.executionsList().size == previousStageNumber) - - spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) - spark.sparkContext.listenerBus.waitUntilEmpty(10000) - // listener should save the SQL stage - assert(statusStore.executionsList().size == previousStageNumber + 1) + // Listener should ignore the non-SQL stages, as the stage data are only removed when SQL + // execution ends, which will not be triggered for non-SQL jobs. + assert(listener.noLiveData()) } test("driver side SQL metrics") { + val statusStore = spark.sharedState.statusStore val oldCount = statusStore.executionsList().size - val expectedAccumValue = 12345L + + val expectedAccumValue = 12345 val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan @@ -387,7 +409,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest physicalPlan.execute().collect() } - while (statusStore.executionsList().size < oldCount) { + // Wait until the new execution is started and being tracked. 
+ while (statusStore.executionsCount() < oldCount) { Thread.sleep(100) } @@ -405,30 +428,6 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTest assert(metrics(driverMetric.id) === expectedValue) } - private def assertJobs( - exec: Option[SQLExecutionUIData], - running: Seq[Int] = Nil, - completed: Seq[Int] = Nil, - failed: Seq[Int] = Nil): Unit = { - - val actualRunning = new ListBuffer[Int]() - val actualCompleted = new ListBuffer[Int]() - val actualFailed = new ListBuffer[Int]() - - exec.get.jobs.foreach { case (jobId, jobStatus) => - jobStatus match { - case JobExecutionStatus.RUNNING => actualRunning += jobId - case JobExecutionStatus.SUCCEEDED => actualCompleted += jobId - case JobExecutionStatus.FAILED => actualFailed += jobId - case _ => fail(s"Unexpected status $jobStatus") - } - } - - assert(actualRunning.toSeq.sorted === running) - assert(actualCompleted.toSeq.sorted === completed) - assert(actualFailed.toSeq.sorted === failed) - } - test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { val event = SparkListenerDriverAccumUpdates(1L, Seq((2L, 3L))) val json = JsonProtocol.sparkEventToJson(event) @@ -494,7 +493,7 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe } -class SQLListenerMemoryLeakSuite extends SparkFunSuite { +class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite { test("no memory leak") { val conf = new SparkConf() @@ -522,9 +521,10 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { } } sc.listenerBus.waitUntilEmpty(10000) - - val statusStore = new SQLAppStatusStore(sc.statusStore.store) - assert(statusStore.executionsList().size <= 50) + val statusStore = spark.sharedState.statusStore + assert(statusStore.executionsCount() <= 50) + // No live data should be left behind after all executions end. + assert(statusStore.listener.get.noLiveData()) } } } From 4e107fdb7463a67d9c77c4a3434dfe70c72982f4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 21 Dec 2017 09:18:27 -0800 Subject: [PATCH 477/712] [SPARK-22822][TEST] Basic tests for WindowFrameCoercion and DecimalPrecision ## What changes were proposed in this pull request? Test coverage for `WindowFrameCoercion` and `DecimalPrecision`; this is a sub-task of [SPARK-22722](https://issues.apache.org/jira/browse/SPARK-22722). ## How was this patch tested? N/A Author: Yuming Wang Closes #20008 from wangyum/SPARK-22822.
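For context, a minimal illustrative sketch of the kind of coercion these new SQL test files exercise; the result types noted in the comments assume Spark SQL's default DecimalPrecision and WindowFrameCoercion rules rather than being copied from the golden files:

CREATE TEMPORARY VIEW t AS SELECT 1;
-- DecimalPrecision: the integral operand is widened to decimal first, so the result gains one digit of precision.
SELECT cast(1 as int) + cast(1 as decimal(10, 0)) FROM t;    -- result type: decimal(11, 0)
SELECT cast(1 as bigint) + cast(1 as decimal(20, 0)) FROM t; -- result type: decimal(21, 0)
-- WindowFrameCoercion: the integer RANGE frame bound is cast to the type of the ORDER BY expression (here decimal(10, 0)).
SELECT count(*) OVER (ORDER BY cast(1 as decimal(10, 0)) RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t;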
--- .../expressions/windowExpressions.scala | 4 +- .../typeCoercion/native/decimalPrecision.sql | 1448 +++ .../native/windowFrameCoercion.sql | 44 + .../native/decimalPrecision.sql.out | 9514 +++++++++++++++++ .../native/windowFrameCoercion.sql.out | 206 + .../sql-tests/results/window.sql.out | 2 +- 6 files changed, 11215 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalPrecision.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/windowFrameCoercion.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index e11e3a105f597..220cc4f885d7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -251,8 +251,8 @@ case class SpecifiedWindowFrame( TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") case e: Expression if !frameType.inputType.acceptsType(e.dataType) => TypeCheckFailure( - s"The data type of the $location bound '${e.dataType} does not match " + - s"the expected data type '${frameType.inputType}'.") + s"The data type of the $location bound '${e.dataType}' does not match " + + s"the expected data type '${frameType.inputType.simpleString}'.") case _ => TypeCheckSuccess } diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalPrecision.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalPrecision.sql new file mode 100644 index 0000000000000..8b04864b18ce3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalPrecision.sql @@ -0,0 +1,1448 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT cast(1 as tinyint) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as bigint) FROM t; +SELECT cast(1 
as decimal(5, 0)) + cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) - cast(1 as decimal(10, 0)) FROM t; +SELECT 
cast(1 as float) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) - 
cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017*12*11 09:30:00' as date) * 
cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) * cast('2017*12*11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) * cast('2017*12*11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) * cast('2017*12*11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) * cast('2017*12*11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) / 
cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 
0)) / cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) / cast('2017/12/11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) / cast('2017/12/11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) / cast('2017/12/11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) / cast('2017/12/11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) % 
cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast('1' as 
binary) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast(1 as smallint), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast(1 as smallint), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast(1 as smallint), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as smallint), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast(1 as int), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast(1 as int), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast(1 as int), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as int), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast(1 as bigint), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast(1 as bigint), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast(1 as bigint), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as bigint), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast(1 as float), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast(1 as float), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast(1 as float), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as float), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast(1 as double), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast(1 as double), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast(1 as double), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as double), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast('1' as binary), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast('1' as binary), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast('1' as binary), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast('1' as binary), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(20, 0))) FROM t; + 
+SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(3, 0))) FROM t; +SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(5, 0))) FROM t; +SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(20, 0))) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as tinyint)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as tinyint)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as tinyint)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as tinyint)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as smallint)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as smallint)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as smallint)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as smallint)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as int)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as int)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as int)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as int)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as bigint)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as bigint)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as bigint)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as bigint)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as float)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as float)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as float)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as float)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as double)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as double)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as double)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as double)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as decimal(10, 0))) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as string)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as string)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as string)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as string)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast('1' as binary)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast('1' as binary)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast('1' as binary)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast('1' as binary)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as boolean)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as boolean)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as boolean)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as boolean)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT pmod(cast(1 as decimal(5, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; + +SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t; +SELECT pmod(cast(1 
as decimal(5, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t; +SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t; +SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT cast(1 as tinyint) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as int) FROM t; 
+SELECT cast(1 as decimal(10, 0)) = cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) <=> 
cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as 
decimal(20, 0)) <=> cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) < cast(1 as decimal(20, 0)) FROM t; + +SELECT 
cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as 
decimal(3, 0)) < cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as smallint) FROM t; + +SELECT 
cast(1 as decimal(3, 0)) <= cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) > cast(1 as 
decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as 
decimal(20, 0)) > cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 
09:30:00.0' as timestamp) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) 
FROM t; + +SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t; + +SELECT cast(1 as tinyint) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as tinyint) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as tinyint) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as tinyint) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as smallint) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as smallint) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as smallint) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as smallint) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as int) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as int) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as int) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as int) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as bigint) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as bigint) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as bigint) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as bigint) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as float) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as float) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as float) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as float) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as double) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as double) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as double) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as double) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('1' as binary) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast('1' as binary) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast('1' as binary) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast('1' as binary) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(3, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(5, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(20, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as tinyint) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as tinyint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as smallint) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 
as smallint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as int) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as int) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as int) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as int) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as bigint) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as bigint) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as float) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as float) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as float) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as float) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as double) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as double) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as double) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as double) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as decimal(10, 0)) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as string) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as string) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as string) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as string) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast('1' as binary) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast('1' as binary) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast('1' as binary) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast('1' as binary) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast(1 as boolean) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast(1 as boolean) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t; + +SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t; +SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/windowFrameCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/windowFrameCoercion.sql new file mode 100644 index 0000000000000..5cd3538757499 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/windowFrameCoercion.sql @@ -0,0 +1,44 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as tinyint)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as smallint)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as int)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as bigint)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as float)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as double)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as decimal(10, 0))) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as timestamp)) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00' as date)) FROM t; + +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as tinyint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as smallint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as int) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as bigint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as float) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as double) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as decimal(10, 0)) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as timestamp) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00' as date) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out new file mode 100644 index 0000000000000..ebc8201ed5a1d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out @@ -0,0 +1,9514 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 1145 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT cast(1 as tinyint) + cast(1 as decimal(3, 0)) 
FROM t +-- !query 1 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) + CAST(1 AS DECIMAL(3,0))):decimal(4,0)> +-- !query 1 output +2 + + +-- !query 2 +SELECT cast(1 as tinyint) + cast(1 as decimal(5, 0)) FROM t +-- !query 2 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 2 output +2 + + +-- !query 3 +SELECT cast(1 as tinyint) + cast(1 as decimal(10, 0)) FROM t +-- !query 3 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 3 output +2 + + +-- !query 4 +SELECT cast(1 as tinyint) + cast(1 as decimal(20, 0)) FROM t +-- !query 4 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 4 output +2 + + +-- !query 5 +SELECT cast(1 as smallint) + cast(1 as decimal(3, 0)) FROM t +-- !query 5 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 5 output +2 + + +-- !query 6 +SELECT cast(1 as smallint) + cast(1 as decimal(5, 0)) FROM t +-- !query 6 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) + CAST(1 AS DECIMAL(5,0))):decimal(6,0)> +-- !query 6 output +2 + + +-- !query 7 +SELECT cast(1 as smallint) + cast(1 as decimal(10, 0)) FROM t +-- !query 7 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 7 output +2 + + +-- !query 8 +SELECT cast(1 as smallint) + cast(1 as decimal(20, 0)) FROM t +-- !query 8 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 8 output +2 + + +-- !query 9 +SELECT cast(1 as int) + cast(1 as decimal(3, 0)) FROM t +-- !query 9 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 9 output +2 + + +-- !query 10 +SELECT cast(1 as int) + cast(1 as decimal(5, 0)) FROM t +-- !query 10 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 10 output +2 + + +-- !query 11 +SELECT cast(1 as int) + cast(1 as decimal(10, 0)) FROM t +-- !query 11 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) + CAST(1 AS DECIMAL(10,0))):decimal(11,0)> +-- !query 11 output +2 + + +-- !query 12 +SELECT cast(1 as int) + cast(1 as decimal(20, 0)) FROM t +-- !query 12 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 12 output +2 + + +-- !query 13 +SELECT cast(1 as bigint) + cast(1 as decimal(3, 0)) FROM t +-- !query 13 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 13 output +2 + + +-- !query 14 +SELECT cast(1 as bigint) + cast(1 as decimal(5, 0)) FROM t +-- !query 14 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 14 output +2 + + +-- !query 15 +SELECT cast(1 as bigint) + cast(1 as decimal(10, 0)) FROM t +-- !query 15 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS 
DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 15 output +2 + + +-- !query 16 +SELECT cast(1 as bigint) + cast(1 as decimal(20, 0)) FROM t +-- !query 16 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) + CAST(1 AS DECIMAL(20,0))):decimal(21,0)> +-- !query 16 output +2 + + +-- !query 17 +SELECT cast(1 as float) + cast(1 as decimal(3, 0)) FROM t +-- !query 17 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 17 output +2.0 + + +-- !query 18 +SELECT cast(1 as float) + cast(1 as decimal(5, 0)) FROM t +-- !query 18 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 18 output +2.0 + + +-- !query 19 +SELECT cast(1 as float) + cast(1 as decimal(10, 0)) FROM t +-- !query 19 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 19 output +2.0 + + +-- !query 20 +SELECT cast(1 as float) + cast(1 as decimal(20, 0)) FROM t +-- !query 20 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) + CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 20 output +2.0 + + +-- !query 21 +SELECT cast(1 as double) + cast(1 as decimal(3, 0)) FROM t +-- !query 21 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 21 output +2.0 + + +-- !query 22 +SELECT cast(1 as double) + cast(1 as decimal(5, 0)) FROM t +-- !query 22 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 22 output +2.0 + + +-- !query 23 +SELECT cast(1 as double) + cast(1 as decimal(10, 0)) FROM t +-- !query 23 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 23 output +2.0 + + +-- !query 24 +SELECT cast(1 as double) + cast(1 as decimal(20, 0)) FROM t +-- !query 24 schema +struct<(CAST(1 AS DOUBLE) + CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 24 output +2.0 + + +-- !query 25 +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(3, 0)) FROM t +-- !query 25 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 25 output +2 + + +-- !query 26 +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(5, 0)) FROM t +-- !query 26 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 26 output +2 + + +-- !query 27 +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t +-- !query 27 schema +struct<(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS DECIMAL(10,0))):decimal(11,0)> +-- !query 27 output +2 + + +-- !query 28 +SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(20, 0)) FROM t +-- !query 28 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 28 output +2 + + +-- !query 29 +SELECT cast('1' as binary) + cast(1 as decimal(3, 0)) FROM t +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 30 +SELECT cast('1' as binary) + cast(1 as decimal(5, 0)) FROM t +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +cannot 
resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 31 +SELECT cast('1' as binary) + cast(1 as decimal(10, 0)) FROM t +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 32 +SELECT cast('1' as binary) + cast(1 as decimal(20, 0)) FROM t +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) + CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 33 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(3, 0)) FROM t +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 34 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(5, 0)) FROM t +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 35 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(10, 0)) FROM t +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 36 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) + cast(1 as decimal(20, 0)) FROM t +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) + CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 37 +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(3, 0)) FROM t +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 38 +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(5, 0)) FROM t +-- !query 38 schema +struct<> +-- !query 38 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(5,0)))' 
(date and decimal(5,0)).; line 1 pos 7 + + +-- !query 39 +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(10, 0)) FROM t +-- !query 39 schema +struct<> +-- !query 39 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 40 +SELECT cast('2017-12-11 09:30:00' as date) + cast(1 as decimal(20, 0)) FROM t +-- !query 40 schema +struct<> +-- !query 40 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) + CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 41 +SELECT cast(1 as decimal(3, 0)) + cast(1 as tinyint) FROM t +-- !query 41 schema +struct<(CAST(1 AS DECIMAL(3,0)) + CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(4,0)> +-- !query 41 output +2 + + +-- !query 42 +SELECT cast(1 as decimal(5, 0)) + cast(1 as tinyint) FROM t +-- !query 42 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0)) + CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 42 output +2 + + +-- !query 43 +SELECT cast(1 as decimal(10, 0)) + cast(1 as tinyint) FROM t +-- !query 43 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 43 output +2 + + +-- !query 44 +SELECT cast(1 as decimal(20, 0)) + cast(1 as tinyint) FROM t +-- !query 44 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 44 output +2 + + +-- !query 45 +SELECT cast(1 as decimal(3, 0)) + cast(1 as smallint) FROM t +-- !query 45 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0)) + CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 45 output +2 + + +-- !query 46 +SELECT cast(1 as decimal(5, 0)) + cast(1 as smallint) FROM t +-- !query 46 schema +struct<(CAST(1 AS DECIMAL(5,0)) + CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(6,0)> +-- !query 46 output +2 + + +-- !query 47 +SELECT cast(1 as decimal(10, 0)) + cast(1 as smallint) FROM t +-- !query 47 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 47 output +2 + + +-- !query 48 +SELECT cast(1 as decimal(20, 0)) + cast(1 as smallint) FROM t +-- !query 48 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 48 output +2 + + +-- !query 49 +SELECT cast(1 as decimal(3, 0)) + cast(1 as int) FROM t +-- !query 49 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 49 output +2 + + +-- !query 50 +SELECT cast(1 as decimal(5, 0)) + cast(1 as int) FROM t +-- !query 50 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) + CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 50 output +2 + + +-- !query 51 +SELECT cast(1 as decimal(10, 0)) + cast(1 as int) FROM t +-- !query 51 schema +struct<(CAST(1 AS DECIMAL(10,0)) + 
CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(11,0)> +-- !query 51 output +2 + + +-- !query 52 +SELECT cast(1 as decimal(20, 0)) + cast(1 as int) FROM t +-- !query 52 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 52 output +2 + + +-- !query 53 +SELECT cast(1 as decimal(3, 0)) + cast(1 as bigint) FROM t +-- !query 53 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 53 output +2 + + +-- !query 54 +SELECT cast(1 as decimal(5, 0)) + cast(1 as bigint) FROM t +-- !query 54 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 54 output +2 + + +-- !query 55 +SELECT cast(1 as decimal(10, 0)) + cast(1 as bigint) FROM t +-- !query 55 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) + CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 55 output +2 + + +-- !query 56 +SELECT cast(1 as decimal(20, 0)) + cast(1 as bigint) FROM t +-- !query 56 schema +struct<(CAST(1 AS DECIMAL(20,0)) + CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(21,0)> +-- !query 56 output +2 + + +-- !query 57 +SELECT cast(1 as decimal(3, 0)) + cast(1 as float) FROM t +-- !query 57 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 57 output +2.0 + + +-- !query 58 +SELECT cast(1 as decimal(5, 0)) + cast(1 as float) FROM t +-- !query 58 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 58 output +2.0 + + +-- !query 59 +SELECT cast(1 as decimal(10, 0)) + cast(1 as float) FROM t +-- !query 59 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 59 output +2.0 + + +-- !query 60 +SELECT cast(1 as decimal(20, 0)) + cast(1 as float) FROM t +-- !query 60 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) + CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 60 output +2.0 + + +-- !query 61 +SELECT cast(1 as decimal(3, 0)) + cast(1 as double) FROM t +-- !query 61 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 61 output +2.0 + + +-- !query 62 +SELECT cast(1 as decimal(5, 0)) + cast(1 as double) FROM t +-- !query 62 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 62 output +2.0 + + +-- !query 63 +SELECT cast(1 as decimal(10, 0)) + cast(1 as double) FROM t +-- !query 63 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 63 output +2.0 + + +-- !query 64 +SELECT cast(1 as decimal(20, 0)) + cast(1 as double) FROM t +-- !query 64 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) + CAST(1 AS DOUBLE)):double> +-- !query 64 output +2.0 + + +-- !query 65 +SELECT cast(1 as decimal(3, 0)) + cast(1 as decimal(10, 0)) FROM t +-- !query 65 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 65 output +2 + + +-- !query 66 +SELECT cast(1 as decimal(5, 0)) + cast(1 as decimal(10, 0)) FROM t +-- !query 66 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 66 output +2 + + +-- !query 67 
+SELECT cast(1 as decimal(10, 0)) + cast(1 as decimal(10, 0)) FROM t +-- !query 67 schema +struct<(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS DECIMAL(10,0))):decimal(11,0)> +-- !query 67 output +2 + + +-- !query 68 +SELECT cast(1 as decimal(20, 0)) + cast(1 as decimal(10, 0)) FROM t +-- !query 68 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) + CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 68 output +2 + + +-- !query 69 +SELECT cast(1 as decimal(3, 0)) + cast(1 as string) FROM t +-- !query 69 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 69 output +2.0 + + +-- !query 70 +SELECT cast(1 as decimal(5, 0)) + cast(1 as string) FROM t +-- !query 70 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 70 output +2.0 + + +-- !query 71 +SELECT cast(1 as decimal(10, 0)) + cast(1 as string) FROM t +-- !query 71 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 71 output +2.0 + + +-- !query 72 +SELECT cast(1 as decimal(20, 0)) + cast(1 as string) FROM t +-- !query 72 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) + CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 72 output +2.0 + + +-- !query 73 +SELECT cast(1 as decimal(3, 0)) + cast('1' as binary) FROM t +-- !query 73 schema +struct<> +-- !query 73 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 74 +SELECT cast(1 as decimal(5, 0)) + cast('1' as binary) FROM t +-- !query 74 schema +struct<> +-- !query 74 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 75 +SELECT cast(1 as decimal(10, 0)) + cast('1' as binary) FROM t +-- !query 75 schema +struct<> +-- !query 75 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 76 +SELECT cast(1 as decimal(20, 0)) + cast('1' as binary) FROM t +-- !query 76 schema +struct<> +-- !query 76 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 77 +SELECT cast(1 as decimal(3, 0)) + cast(1 as boolean) FROM t +-- !query 77 schema +struct<> +-- !query 77 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 78 +SELECT cast(1 as decimal(5, 0)) + cast(1 as boolean) FROM t +-- !query 78 schema +struct<> +-- !query 78 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST(1 AS BOOLEAN))' (decimal(5,0) 
and boolean).; line 1 pos 7 + + +-- !query 79 +SELECT cast(1 as decimal(10, 0)) + cast(1 as boolean) FROM t +-- !query 79 schema +struct<> +-- !query 79 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 80 +SELECT cast(1 as decimal(20, 0)) + cast(1 as boolean) FROM t +-- !query 80 schema +struct<> +-- !query 80 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 81 +SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 81 schema +struct<> +-- !query 81 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 82 +SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 82 schema +struct<> +-- !query 82 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 83 +SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 83 schema +struct<> +-- !query 83 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 84 +SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 84 schema +struct<> +-- !query 84 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 85 +SELECT cast(1 as decimal(3, 0)) + cast('2017-12-11 09:30:00' as date) FROM t +-- !query 85 schema +struct<> +-- !query 85 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 86 +SELECT cast(1 as decimal(5, 0)) + cast('2017-12-11 09:30:00' as date) FROM t +-- !query 86 schema +struct<> +-- !query 86 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 87 +SELECT cast(1 as decimal(10, 0)) + cast('2017-12-11 09:30:00' as date) FROM t +-- !query 87 schema +struct<> +-- 
!query 87 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 88 +SELECT cast(1 as decimal(20, 0)) + cast('2017-12-11 09:30:00' as date) FROM t +-- !query 88 schema +struct<> +-- !query 88 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) + CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 89 +SELECT cast(1 as tinyint) - cast(1 as decimal(3, 0)) FROM t +-- !query 89 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) - CAST(1 AS DECIMAL(3,0))):decimal(4,0)> +-- !query 89 output +0 + + +-- !query 90 +SELECT cast(1 as tinyint) - cast(1 as decimal(5, 0)) FROM t +-- !query 90 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 90 output +0 + + +-- !query 91 +SELECT cast(1 as tinyint) - cast(1 as decimal(10, 0)) FROM t +-- !query 91 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 91 output +0 + + +-- !query 92 +SELECT cast(1 as tinyint) - cast(1 as decimal(20, 0)) FROM t +-- !query 92 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 92 output +0 + + +-- !query 93 +SELECT cast(1 as smallint) - cast(1 as decimal(3, 0)) FROM t +-- !query 93 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 93 output +0 + + +-- !query 94 +SELECT cast(1 as smallint) - cast(1 as decimal(5, 0)) FROM t +-- !query 94 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) - CAST(1 AS DECIMAL(5,0))):decimal(6,0)> +-- !query 94 output +0 + + +-- !query 95 +SELECT cast(1 as smallint) - cast(1 as decimal(10, 0)) FROM t +-- !query 95 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 95 output +0 + + +-- !query 96 +SELECT cast(1 as smallint) - cast(1 as decimal(20, 0)) FROM t +-- !query 96 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 96 output +0 + + +-- !query 97 +SELECT cast(1 as int) - cast(1 as decimal(3, 0)) FROM t +-- !query 97 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 97 output +0 + + +-- !query 98 +SELECT cast(1 as int) - cast(1 as decimal(5, 0)) FROM t +-- !query 98 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 98 output +0 + + +-- !query 99 +SELECT cast(1 as int) - cast(1 as decimal(10, 0)) FROM t +-- !query 99 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) - CAST(1 AS DECIMAL(10,0))):decimal(11,0)> +-- !query 99 output +0 + + +-- !query 100 +SELECT cast(1 as int) - cast(1 as decimal(20, 0)) FROM t +-- !query 100 schema 
+struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 100 output +0 + + +-- !query 101 +SELECT cast(1 as bigint) - cast(1 as decimal(3, 0)) FROM t +-- !query 101 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 101 output +0 + + +-- !query 102 +SELECT cast(1 as bigint) - cast(1 as decimal(5, 0)) FROM t +-- !query 102 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 102 output +0 + + +-- !query 103 +SELECT cast(1 as bigint) - cast(1 as decimal(10, 0)) FROM t +-- !query 103 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 103 output +0 + + +-- !query 104 +SELECT cast(1 as bigint) - cast(1 as decimal(20, 0)) FROM t +-- !query 104 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) - CAST(1 AS DECIMAL(20,0))):decimal(21,0)> +-- !query 104 output +0 + + +-- !query 105 +SELECT cast(1 as float) - cast(1 as decimal(3, 0)) FROM t +-- !query 105 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 105 output +0.0 + + +-- !query 106 +SELECT cast(1 as float) - cast(1 as decimal(5, 0)) FROM t +-- !query 106 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 106 output +0.0 + + +-- !query 107 +SELECT cast(1 as float) - cast(1 as decimal(10, 0)) FROM t +-- !query 107 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 107 output +0.0 + + +-- !query 108 +SELECT cast(1 as float) - cast(1 as decimal(20, 0)) FROM t +-- !query 108 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) - CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 108 output +0.0 + + +-- !query 109 +SELECT cast(1 as double) - cast(1 as decimal(3, 0)) FROM t +-- !query 109 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 109 output +0.0 + + +-- !query 110 +SELECT cast(1 as double) - cast(1 as decimal(5, 0)) FROM t +-- !query 110 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 110 output +0.0 + + +-- !query 111 +SELECT cast(1 as double) - cast(1 as decimal(10, 0)) FROM t +-- !query 111 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 111 output +0.0 + + +-- !query 112 +SELECT cast(1 as double) - cast(1 as decimal(20, 0)) FROM t +-- !query 112 schema +struct<(CAST(1 AS DOUBLE) - CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 112 output +0.0 + + +-- !query 113 +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(3, 0)) FROM t +-- !query 113 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 113 output +0 + + +-- !query 114 +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(5, 0)) FROM t +-- !query 114 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 114 output +0 + + +-- !query 115 +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t +-- !query 115 schema +struct<(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS 
DECIMAL(10,0))):decimal(11,0)> +-- !query 115 output +0 + + +-- !query 116 +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(20, 0)) FROM t +-- !query 116 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 116 output +0 + + +-- !query 117 +SELECT cast('1' as binary) - cast(1 as decimal(3, 0)) FROM t +-- !query 117 schema +struct<> +-- !query 117 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 118 +SELECT cast('1' as binary) - cast(1 as decimal(5, 0)) FROM t +-- !query 118 schema +struct<> +-- !query 118 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 119 +SELECT cast('1' as binary) - cast(1 as decimal(10, 0)) FROM t +-- !query 119 schema +struct<> +-- !query 119 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 120 +SELECT cast('1' as binary) - cast(1 as decimal(20, 0)) FROM t +-- !query 120 schema +struct<> +-- !query 120 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 121 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(3, 0)) FROM t +-- !query 121 schema +struct<> +-- !query 121 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 122 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(5, 0)) FROM t +-- !query 122 schema +struct<> +-- !query 122 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 123 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(10, 0)) FROM t +-- !query 123 schema +struct<> +-- !query 123 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 124 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) - cast(1 as decimal(20, 0)) FROM t +-- !query 124 schema +struct<> +-- !query 124 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) - CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' 
AS TIMESTAMP) - CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 125 +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(3, 0)) FROM t +-- !query 125 schema +struct<> +-- !query 125 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 126 +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(5, 0)) FROM t +-- !query 126 schema +struct<> +-- !query 126 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 127 +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(10, 0)) FROM t +-- !query 127 schema +struct<> +-- !query 127 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 128 +SELECT cast('2017-12-11 09:30:00' as date) - cast(1 as decimal(20, 0)) FROM t +-- !query 128 schema +struct<> +-- !query 128 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) - CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 129 +SELECT cast(1 as decimal(3, 0)) - cast(1 as tinyint) FROM t +-- !query 129 schema +struct<(CAST(1 AS DECIMAL(3,0)) - CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(4,0)> +-- !query 129 output +0 + + +-- !query 130 +SELECT cast(1 as decimal(5, 0)) - cast(1 as tinyint) FROM t +-- !query 130 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(6,0)) - CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 130 output +0 + + +-- !query 131 +SELECT cast(1 as decimal(10, 0)) - cast(1 as tinyint) FROM t +-- !query 131 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 131 output +0 + + +-- !query 132 +SELECT cast(1 as decimal(20, 0)) - cast(1 as tinyint) FROM t +-- !query 132 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 132 output +0 + + +-- !query 133 +SELECT cast(1 as decimal(3, 0)) - cast(1 as smallint) FROM t +-- !query 133 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(6,0)) - CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(6,0))):decimal(6,0)> +-- !query 133 output +0 + + +-- !query 134 +SELECT cast(1 as decimal(5, 0)) - cast(1 as smallint) FROM t +-- !query 134 schema +struct<(CAST(1 AS DECIMAL(5,0)) - CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(6,0)> +-- !query 134 output +0 + + +-- !query 135 +SELECT cast(1 as decimal(10, 0)) - cast(1 as smallint) FROM t +-- !query 135 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 135 
output +0 + + +-- !query 136 +SELECT cast(1 as decimal(20, 0)) - cast(1 as smallint) FROM t +-- !query 136 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 136 output +0 + + +-- !query 137 +SELECT cast(1 as decimal(3, 0)) - cast(1 as int) FROM t +-- !query 137 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 137 output +0 + + +-- !query 138 +SELECT cast(1 as decimal(5, 0)) - cast(1 as int) FROM t +-- !query 138 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) - CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 138 output +0 + + +-- !query 139 +SELECT cast(1 as decimal(10, 0)) - cast(1 as int) FROM t +-- !query 139 schema +struct<(CAST(1 AS DECIMAL(10,0)) - CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(11,0)> +-- !query 139 output +0 + + +-- !query 140 +SELECT cast(1 as decimal(20, 0)) - cast(1 as int) FROM t +-- !query 140 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 140 output +0 + + +-- !query 141 +SELECT cast(1 as decimal(3, 0)) - cast(1 as bigint) FROM t +-- !query 141 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 141 output +0 + + +-- !query 142 +SELECT cast(1 as decimal(5, 0)) - cast(1 as bigint) FROM t +-- !query 142 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 142 output +0 + + +-- !query 143 +SELECT cast(1 as decimal(10, 0)) - cast(1 as bigint) FROM t +-- !query 143 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0)) - CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 143 output +0 + + +-- !query 144 +SELECT cast(1 as decimal(20, 0)) - cast(1 as bigint) FROM t +-- !query 144 schema +struct<(CAST(1 AS DECIMAL(20,0)) - CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(21,0)> +-- !query 144 output +0 + + +-- !query 145 +SELECT cast(1 as decimal(3, 0)) - cast(1 as float) FROM t +-- !query 145 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 145 output +0.0 + + +-- !query 146 +SELECT cast(1 as decimal(5, 0)) - cast(1 as float) FROM t +-- !query 146 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 146 output +0.0 + + +-- !query 147 +SELECT cast(1 as decimal(10, 0)) - cast(1 as float) FROM t +-- !query 147 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 147 output +0.0 + + +-- !query 148 +SELECT cast(1 as decimal(20, 0)) - cast(1 as float) FROM t +-- !query 148 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) - CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 148 output +0.0 + + +-- !query 149 +SELECT cast(1 as decimal(3, 0)) - cast(1 as double) FROM t +-- !query 149 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 149 output +0.0 + + +-- !query 150 +SELECT cast(1 as decimal(5, 0)) - cast(1 as double) FROM t +-- !query 150 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double> 
+-- !query 150 output +0.0 + + +-- !query 151 +SELECT cast(1 as decimal(10, 0)) - cast(1 as double) FROM t +-- !query 151 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 151 output +0.0 + + +-- !query 152 +SELECT cast(1 as decimal(20, 0)) - cast(1 as double) FROM t +-- !query 152 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) - CAST(1 AS DOUBLE)):double> +-- !query 152 output +0.0 + + +-- !query 153 +SELECT cast(1 as decimal(3, 0)) - cast(1 as decimal(10, 0)) FROM t +-- !query 153 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 153 output +0 + + +-- !query 154 +SELECT cast(1 as decimal(5, 0)) - cast(1 as decimal(10, 0)) FROM t +-- !query 154 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(11,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(11,0))):decimal(11,0)> +-- !query 154 output +0 + + +-- !query 155 +SELECT cast(1 as decimal(10, 0)) - cast(1 as decimal(10, 0)) FROM t +-- !query 155 schema +struct<(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS DECIMAL(10,0))):decimal(11,0)> +-- !query 155 output +0 + + +-- !query 156 +SELECT cast(1 as decimal(20, 0)) - cast(1 as decimal(10, 0)) FROM t +-- !query 156 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(21,0)) - CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(21,0))):decimal(21,0)> +-- !query 156 output +0 + + +-- !query 157 +SELECT cast(1 as decimal(3, 0)) - cast(1 as string) FROM t +-- !query 157 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 157 output +0.0 + + +-- !query 158 +SELECT cast(1 as decimal(5, 0)) - cast(1 as string) FROM t +-- !query 158 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 158 output +0.0 + + +-- !query 159 +SELECT cast(1 as decimal(10, 0)) - cast(1 as string) FROM t +-- !query 159 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 159 output +0.0 + + +-- !query 160 +SELECT cast(1 as decimal(20, 0)) - cast(1 as string) FROM t +-- !query 160 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) - CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 160 output +0.0 + + +-- !query 161 +SELECT cast(1 as decimal(3, 0)) - cast('1' as binary) FROM t +-- !query 161 schema +struct<> +-- !query 161 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 162 +SELECT cast(1 as decimal(5, 0)) - cast('1' as binary) FROM t +-- !query 162 schema +struct<> +-- !query 162 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 163 +SELECT cast(1 as decimal(10, 0)) - cast('1' as binary) FROM t +-- !query 163 schema +struct<> +-- !query 163 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 164 +SELECT cast(1 as decimal(20, 0)) - cast('1' as binary) FROM t +-- !query 164 
schema +struct<> +-- !query 164 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 165 +SELECT cast(1 as decimal(3, 0)) - cast(1 as boolean) FROM t +-- !query 165 schema +struct<> +-- !query 165 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 166 +SELECT cast(1 as decimal(5, 0)) - cast(1 as boolean) FROM t +-- !query 166 schema +struct<> +-- !query 166 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 167 +SELECT cast(1 as decimal(10, 0)) - cast(1 as boolean) FROM t +-- !query 167 schema +struct<> +-- !query 167 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 168 +SELECT cast(1 as decimal(20, 0)) - cast(1 as boolean) FROM t +-- !query 168 schema +struct<> +-- !query 168 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 169 +SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 169 schema +struct<> +-- !query 169 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 170 +SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 170 schema +struct<> +-- !query 170 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 171 +SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 171 schema +struct<> +-- !query 171 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 172 +SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 172 schema +struct<> +-- !query 172 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00.0' AS 
TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 173 +SELECT cast(1 as decimal(3, 0)) - cast('2017-12-11 09:30:00' as date) FROM t +-- !query 173 schema +struct<> +-- !query 173 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 174 +SELECT cast(1 as decimal(5, 0)) - cast('2017-12-11 09:30:00' as date) FROM t +-- !query 174 schema +struct<> +-- !query 174 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 175 +SELECT cast(1 as decimal(10, 0)) - cast('2017-12-11 09:30:00' as date) FROM t +-- !query 175 schema +struct<> +-- !query 175 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 176 +SELECT cast(1 as decimal(20, 0)) - cast('2017-12-11 09:30:00' as date) FROM t +-- !query 176 schema +struct<> +-- !query 176 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) - CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 177 +SELECT cast(1 as tinyint) * cast(1 as decimal(3, 0)) FROM t +-- !query 177 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) * CAST(1 AS DECIMAL(3,0))):decimal(7,0)> +-- !query 177 output +1 + + +-- !query 178 +SELECT cast(1 as tinyint) * cast(1 as decimal(5, 0)) FROM t +-- !query 178 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,0)> +-- !query 178 output +1 + + +-- !query 179 +SELECT cast(1 as tinyint) * cast(1 as decimal(10, 0)) FROM t +-- !query 179 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,0)> +-- !query 179 output +1 + + +-- !query 180 +SELECT cast(1 as tinyint) * cast(1 as decimal(20, 0)) FROM t +-- !query 180 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,0)> +-- !query 180 output +1 + + +-- !query 181 +SELECT cast(1 as smallint) * cast(1 as decimal(3, 0)) FROM t +-- !query 181 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(9,0)> +-- !query 181 output +1 + + +-- !query 182 +SELECT cast(1 as smallint) * cast(1 as decimal(5, 0)) FROM t +-- !query 182 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) * CAST(1 AS DECIMAL(5,0))):decimal(11,0)> +-- !query 182 output +1 + + +-- !query 183 +SELECT cast(1 as smallint) * cast(1 as decimal(10, 0)) FROM t +-- !query 183 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,0)> +-- !query 183 output +1 + + +-- !query 184 
+SELECT cast(1 as smallint) * cast(1 as decimal(20, 0)) FROM t +-- !query 184 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,0)> +-- !query 184 output +1 + + +-- !query 185 +SELECT cast(1 as int) * cast(1 as decimal(3, 0)) FROM t +-- !query 185 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(14,0)> +-- !query 185 output +1 + + +-- !query 186 +SELECT cast(1 as int) * cast(1 as decimal(5, 0)) FROM t +-- !query 186 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,0)> +-- !query 186 output +1 + + +-- !query 187 +SELECT cast(1 as int) * cast(1 as decimal(10, 0)) FROM t +-- !query 187 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) * CAST(1 AS DECIMAL(10,0))):decimal(21,0)> +-- !query 187 output +1 + + +-- !query 188 +SELECT cast(1 as int) * cast(1 as decimal(20, 0)) FROM t +-- !query 188 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,0)> +-- !query 188 output +1 + + +-- !query 189 +SELECT cast(1 as bigint) * cast(1 as decimal(3, 0)) FROM t +-- !query 189 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(24,0)> +-- !query 189 output +1 + + +-- !query 190 +SELECT cast(1 as bigint) * cast(1 as decimal(5, 0)) FROM t +-- !query 190 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,0)> +-- !query 190 output +1 + + +-- !query 191 +SELECT cast(1 as bigint) * cast(1 as decimal(10, 0)) FROM t +-- !query 191 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,0)> +-- !query 191 output +1 + + +-- !query 192 +SELECT cast(1 as bigint) * cast(1 as decimal(20, 0)) FROM t +-- !query 192 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) * CAST(1 AS DECIMAL(20,0))):decimal(38,0)> +-- !query 192 output +1 + + +-- !query 193 +SELECT cast(1 as float) * cast(1 as decimal(3, 0)) FROM t +-- !query 193 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 193 output +1.0 + + +-- !query 194 +SELECT cast(1 as float) * cast(1 as decimal(5, 0)) FROM t +-- !query 194 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 194 output +1.0 + + +-- !query 195 +SELECT cast(1 as float) * cast(1 as decimal(10, 0)) FROM t +-- !query 195 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 195 output +1.0 + + +-- !query 196 +SELECT cast(1 as float) * cast(1 as decimal(20, 0)) FROM t +-- !query 196 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) * CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 196 output +1.0 + + +-- !query 197 +SELECT cast(1 as double) * cast(1 as decimal(3, 0)) FROM t +-- !query 197 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 197 output +1.0 + + +-- !query 198 +SELECT cast(1 as double) * cast(1 as decimal(5, 0)) FROM t +-- !query 198 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 198 output +1.0 + + 
+-- !query 199 +SELECT cast(1 as double) * cast(1 as decimal(10, 0)) FROM t +-- !query 199 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 199 output +1.0 + + +-- !query 200 +SELECT cast(1 as double) * cast(1 as decimal(20, 0)) FROM t +-- !query 200 schema +struct<(CAST(1 AS DOUBLE) * CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 200 output +1.0 + + +-- !query 201 +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(3, 0)) FROM t +-- !query 201 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(14,0)> +-- !query 201 output +1 + + +-- !query 202 +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(5, 0)) FROM t +-- !query 202 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,0)> +-- !query 202 output +1 + + +-- !query 203 +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t +-- !query 203 schema +struct<(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS DECIMAL(10,0))):decimal(21,0)> +-- !query 203 output +1 + + +-- !query 204 +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(20, 0)) FROM t +-- !query 204 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,0)> +-- !query 204 output +1 + + +-- !query 205 +SELECT cast('1' as binary) * cast(1 as decimal(3, 0)) FROM t +-- !query 205 schema +struct<> +-- !query 205 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 206 +SELECT cast('1' as binary) * cast(1 as decimal(5, 0)) FROM t +-- !query 206 schema +struct<> +-- !query 206 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 207 +SELECT cast('1' as binary) * cast(1 as decimal(10, 0)) FROM t +-- !query 207 schema +struct<> +-- !query 207 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 208 +SELECT cast('1' as binary) * cast(1 as decimal(20, 0)) FROM t +-- !query 208 schema +struct<> +-- !query 208 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) * CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 209 +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(3, 0)) FROM t +-- !query 209 schema +struct<> +-- !query 209 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 210 +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(5, 0)) FROM t +-- !query 210 schema +struct<> +-- !query 210 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 211 +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(10, 0)) FROM t +-- !query 211 schema +struct<> +-- !query 211 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 212 +SELECT cast('2017*12*11 09:30:00.0' as timestamp) * cast(1 as decimal(20, 0)) FROM t +-- !query 212 schema +struct<> +-- !query 212 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00.0' AS TIMESTAMP) * CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 213 +SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(3, 0)) FROM t +-- !query 213 schema +struct<> +-- !query 213 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 214 +SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(5, 0)) FROM t +-- !query 214 schema +struct<> +-- !query 214 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 215 +SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(10, 0)) FROM t +-- !query 215 schema +struct<> +-- !query 215 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 216 +SELECT cast('2017*12*11 09:30:00' as date) * cast(1 as decimal(20, 0)) FROM t +-- !query 216 schema +struct<> +-- !query 216 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017*12*11 09:30:00' AS DATE) * CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 217 +SELECT cast(1 as decimal(3, 0)) * cast(1 as tinyint) FROM t +-- !query 217 schema +struct<(CAST(1 AS DECIMAL(3,0)) * CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(7,0)> +-- !query 217 output +1 + + +-- !query 218 +SELECT cast(1 as decimal(5, 0)) * cast(1 as tinyint) FROM t +-- !query 218 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) * CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(9,0)> +-- !query 218 output +1 + + +-- !query 219 +SELECT cast(1 as decimal(10, 0)) * cast(1 as tinyint) FROM t +-- !query 219 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * 
CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(14,0)> +-- !query 219 output +1 + + +-- !query 220 +SELECT cast(1 as decimal(20, 0)) * cast(1 as tinyint) FROM t +-- !query 220 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(24,0)> +-- !query 220 output +1 + + +-- !query 221 +SELECT cast(1 as decimal(3, 0)) * cast(1 as smallint) FROM t +-- !query 221 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) * CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,0)> +-- !query 221 output +1 + + +-- !query 222 +SELECT cast(1 as decimal(5, 0)) * cast(1 as smallint) FROM t +-- !query 222 schema +struct<(CAST(1 AS DECIMAL(5,0)) * CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(11,0)> +-- !query 222 output +1 + + +-- !query 223 +SELECT cast(1 as decimal(10, 0)) * cast(1 as smallint) FROM t +-- !query 223 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) * CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,0)> +-- !query 223 output +1 + + +-- !query 224 +SELECT cast(1 as decimal(20, 0)) * cast(1 as smallint) FROM t +-- !query 224 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,0)> +-- !query 224 output +1 + + +-- !query 225 +SELECT cast(1 as decimal(3, 0)) * cast(1 as int) FROM t +-- !query 225 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) * CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,0)> +-- !query 225 output +1 + + +-- !query 226 +SELECT cast(1 as decimal(5, 0)) * cast(1 as int) FROM t +-- !query 226 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) * CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,0)> +-- !query 226 output +1 + + +-- !query 227 +SELECT cast(1 as decimal(10, 0)) * cast(1 as int) FROM t +-- !query 227 schema +struct<(CAST(1 AS DECIMAL(10,0)) * CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(21,0)> +-- !query 227 output +1 + + +-- !query 228 +SELECT cast(1 as decimal(20, 0)) * cast(1 as int) FROM t +-- !query 228 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,0)> +-- !query 228 output +1 + + +-- !query 229 +SELECT cast(1 as decimal(3, 0)) * cast(1 as bigint) FROM t +-- !query 229 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,0)> +-- !query 229 output +1 + + +-- !query 230 +SELECT cast(1 as decimal(5, 0)) * cast(1 as bigint) FROM t +-- !query 230 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,0)> +-- !query 230 output +1 + + +-- !query 231 +SELECT cast(1 as decimal(10, 0)) * cast(1 as bigint) FROM t +-- !query 231 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) * CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,0)> +-- !query 231 output +1 + + +-- !query 232 +SELECT cast(1 as decimal(20, 0)) * cast(1 as bigint) FROM t +-- !query 232 schema +struct<(CAST(1 AS DECIMAL(20,0)) * CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,0)> +-- !query 232 output +1 + + +-- !query 233 +SELECT cast(1 as decimal(3, 0)) * cast(1 as float) FROM t +-- !query 233 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) * 
CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 233 output +1.0 + + +-- !query 234 +SELECT cast(1 as decimal(5, 0)) * cast(1 as float) FROM t +-- !query 234 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 234 output +1.0 + + +-- !query 235 +SELECT cast(1 as decimal(10, 0)) * cast(1 as float) FROM t +-- !query 235 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 235 output +1.0 + + +-- !query 236 +SELECT cast(1 as decimal(20, 0)) * cast(1 as float) FROM t +-- !query 236 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) * CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 236 output +1.0 + + +-- !query 237 +SELECT cast(1 as decimal(3, 0)) * cast(1 as double) FROM t +-- !query 237 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 237 output +1.0 + + +-- !query 238 +SELECT cast(1 as decimal(5, 0)) * cast(1 as double) FROM t +-- !query 238 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 238 output +1.0 + + +-- !query 239 +SELECT cast(1 as decimal(10, 0)) * cast(1 as double) FROM t +-- !query 239 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 239 output +1.0 + + +-- !query 240 +SELECT cast(1 as decimal(20, 0)) * cast(1 as double) FROM t +-- !query 240 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) * CAST(1 AS DOUBLE)):double> +-- !query 240 output +1.0 + + +-- !query 241 +SELECT cast(1 as decimal(3, 0)) * cast(1 as decimal(10, 0)) FROM t +-- !query 241 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,0)> +-- !query 241 output +1 + + +-- !query 242 +SELECT cast(1 as decimal(5, 0)) * cast(1 as decimal(10, 0)) FROM t +-- !query 242 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,0)> +-- !query 242 output +1 + + +-- !query 243 +SELECT cast(1 as decimal(10, 0)) * cast(1 as decimal(10, 0)) FROM t +-- !query 243 schema +struct<(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS DECIMAL(10,0))):decimal(21,0)> +-- !query 243 output +1 + + +-- !query 244 +SELECT cast(1 as decimal(20, 0)) * cast(1 as decimal(10, 0)) FROM t +-- !query 244 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) * CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,0)> +-- !query 244 output +1 + + +-- !query 245 +SELECT cast(1 as decimal(3, 0)) * cast(1 as string) FROM t +-- !query 245 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 245 output +1.0 + + +-- !query 246 +SELECT cast(1 as decimal(5, 0)) * cast(1 as string) FROM t +-- !query 246 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 246 output +1.0 + + +-- !query 247 +SELECT cast(1 as decimal(10, 0)) * cast(1 as string) FROM t +-- !query 247 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 247 output +1.0 + + +-- !query 248 +SELECT cast(1 as decimal(20, 0)) * cast(1 as string) FROM t +-- !query 248 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) * CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 248 output +1.0 + + +-- !query 249 +SELECT cast(1 as decimal(3, 0)) * cast('1' as binary) FROM t +-- !query 249 schema +struct<> +-- 
!query 249 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 250 +SELECT cast(1 as decimal(5, 0)) * cast('1' as binary) FROM t +-- !query 250 schema +struct<> +-- !query 250 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 251 +SELECT cast(1 as decimal(10, 0)) * cast('1' as binary) FROM t +-- !query 251 schema +struct<> +-- !query 251 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 252 +SELECT cast(1 as decimal(20, 0)) * cast('1' as binary) FROM t +-- !query 252 schema +struct<> +-- !query 252 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 253 +SELECT cast(1 as decimal(3, 0)) * cast(1 as boolean) FROM t +-- !query 253 schema +struct<> +-- !query 253 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 254 +SELECT cast(1 as decimal(5, 0)) * cast(1 as boolean) FROM t +-- !query 254 schema +struct<> +-- !query 254 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 255 +SELECT cast(1 as decimal(10, 0)) * cast(1 as boolean) FROM t +-- !query 255 schema +struct<> +-- !query 255 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 256 +SELECT cast(1 as decimal(20, 0)) * cast(1 as boolean) FROM t +-- !query 256 schema +struct<> +-- !query 256 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 257 +SELECT cast(1 as decimal(3, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t +-- !query 257 schema +struct<> +-- !query 257 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 258 +SELECT cast(1 as decimal(5, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t +-- !query 258 schema +struct<> +-- !query 258 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 259 +SELECT cast(1 as decimal(10, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t +-- !query 259 schema +struct<> +-- !query 259 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 260 +SELECT cast(1 as decimal(20, 0)) * cast('2017*12*11 09:30:00.0' as timestamp) FROM t +-- !query 260 schema +struct<> +-- !query 260 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 261 +SELECT cast(1 as decimal(3, 0)) * cast('2017*12*11 09:30:00' as date) FROM t +-- !query 261 schema +struct<> +-- !query 261 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 262 +SELECT cast(1 as decimal(5, 0)) * cast('2017*12*11 09:30:00' as date) FROM t +-- !query 262 schema +struct<> +-- !query 262 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 263 +SELECT cast(1 as decimal(10, 0)) * cast('2017*12*11 09:30:00' as date) FROM t +-- !query 263 schema +struct<> +-- !query 263 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 264 +SELECT cast(1 as decimal(20, 0)) * cast('2017*12*11 09:30:00' as date) FROM t +-- !query 264 schema +struct<> +-- !query 264 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) * CAST('2017*12*11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 265 +SELECT cast(1 as tinyint) / cast(1 as decimal(3, 0)) FROM t +-- !query 265 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) / CAST(1 AS DECIMAL(3,0))):decimal(9,6)> +-- !query 265 output +1 + + +-- !query 266 +SELECT cast(1 as tinyint) / cast(1 as decimal(5, 0)) FROM t +-- !query 266 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,6)> +-- !query 266 output +1 + + +-- !query 267 +SELECT cast(1 as tinyint) / cast(1 as decimal(10, 0)) FROM t +-- !query 267 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS 
DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)> +-- !query 267 output +1 + + +-- !query 268 +SELECT cast(1 as tinyint) / cast(1 as decimal(20, 0)) FROM t +-- !query 268 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,21)> +-- !query 268 output +1 + + +-- !query 269 +SELECT cast(1 as smallint) / cast(1 as decimal(3, 0)) FROM t +-- !query 269 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(11,6)> +-- !query 269 output +1 + + +-- !query 270 +SELECT cast(1 as smallint) / cast(1 as decimal(5, 0)) FROM t +-- !query 270 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) / CAST(1 AS DECIMAL(5,0))):decimal(11,6)> +-- !query 270 output +1 + + +-- !query 271 +SELECT cast(1 as smallint) / cast(1 as decimal(10, 0)) FROM t +-- !query 271 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)> +-- !query 271 output +1 + + +-- !query 272 +SELECT cast(1 as smallint) / cast(1 as decimal(20, 0)) FROM t +-- !query 272 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,21)> +-- !query 272 output +1 + + +-- !query 273 +SELECT cast(1 as int) / cast(1 as decimal(3, 0)) FROM t +-- !query 273 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 273 output +1 + + +-- !query 274 +SELECT cast(1 as int) / cast(1 as decimal(5, 0)) FROM t +-- !query 274 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 274 output +1 + + +-- !query 275 +SELECT cast(1 as int) / cast(1 as decimal(10, 0)) FROM t +-- !query 275 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)> +-- !query 275 output +1 + + +-- !query 276 +SELECT cast(1 as int) / cast(1 as decimal(20, 0)) FROM t +-- !query 276 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)> +-- !query 276 output +1 + + +-- !query 277 +SELECT cast(1 as bigint) / cast(1 as decimal(3, 0)) FROM t +-- !query 277 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(26,6)> +-- !query 277 output +1 + + +-- !query 278 +SELECT cast(1 as bigint) / cast(1 as decimal(5, 0)) FROM t +-- !query 278 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,6)> +-- !query 278 output +1 + + +-- !query 279 +SELECT cast(1 as bigint) / cast(1 as decimal(10, 0)) FROM t +-- !query 279 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,11)> +-- !query 279 output +1 + + +-- !query 280 +SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t +-- !query 280 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)> +-- !query 280 output +1 + + +-- !query 281 +SELECT cast(1 as float) / cast(1 as decimal(3, 0)) FROM t +-- !query 281 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / 
CAST(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) AS DOUBLE)):double> +-- !query 281 output +1.0 + + +-- !query 282 +SELECT cast(1 as float) / cast(1 as decimal(5, 0)) FROM t +-- !query 282 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) AS DOUBLE)):double> +-- !query 282 output +1.0 + + +-- !query 283 +SELECT cast(1 as float) / cast(1 as decimal(10, 0)) FROM t +-- !query 283 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) AS DOUBLE)):double> +-- !query 283 output +1.0 + + +-- !query 284 +SELECT cast(1 as float) / cast(1 as decimal(20, 0)) FROM t +-- !query 284 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) / CAST(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) AS DOUBLE)):double> +-- !query 284 output +1.0 + + +-- !query 285 +SELECT cast(1 as double) / cast(1 as decimal(3, 0)) FROM t +-- !query 285 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 285 output +1.0 + + +-- !query 286 +SELECT cast(1 as double) / cast(1 as decimal(5, 0)) FROM t +-- !query 286 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 286 output +1.0 + + +-- !query 287 +SELECT cast(1 as double) / cast(1 as decimal(10, 0)) FROM t +-- !query 287 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 287 output +1.0 + + +-- !query 288 +SELECT cast(1 as double) / cast(1 as decimal(20, 0)) FROM t +-- !query 288 schema +struct<(CAST(1 AS DOUBLE) / CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 288 output +1.0 + + +-- !query 289 +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(3, 0)) FROM t +-- !query 289 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 289 output +1 + + +-- !query 290 +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(5, 0)) FROM t +-- !query 290 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 290 output +1 + + +-- !query 291 +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t +-- !query 291 schema +struct<(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)> +-- !query 291 output +1 + + +-- !query 292 +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(20, 0)) FROM t +-- !query 292 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)> +-- !query 292 output +1 + + +-- !query 293 +SELECT cast('1' as binary) / cast(1 as decimal(3, 0)) FROM t +-- !query 293 schema +struct<> +-- !query 293 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 294 +SELECT cast('1' as binary) / cast(1 as decimal(5, 0)) FROM t +-- !query 294 schema +struct<> +-- !query 294 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 295 +SELECT cast('1' as binary) / cast(1 as decimal(10, 0)) FROM t +-- !query 295 schema +struct<> +-- !query 295 output +org.apache.spark.sql.AnalysisException 
+cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 296 +SELECT cast('1' as binary) / cast(1 as decimal(20, 0)) FROM t +-- !query 296 schema +struct<> +-- !query 296 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) / CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 297 +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(3, 0)) FROM t +-- !query 297 schema +struct<> +-- !query 297 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 298 +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(5, 0)) FROM t +-- !query 298 schema +struct<> +-- !query 298 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 299 +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(10, 0)) FROM t +-- !query 299 schema +struct<> +-- !query 299 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 300 +SELECT cast('2017/12/11 09:30:00.0' as timestamp) / cast(1 as decimal(20, 0)) FROM t +-- !query 300 schema +struct<> +-- !query 300 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00.0' AS TIMESTAMP) / CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 301 +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(3, 0)) FROM t +-- !query 301 schema +struct<> +-- !query 301 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 302 +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(5, 0)) FROM t +-- !query 302 schema +struct<> +-- !query 302 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 303 +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(10, 0)) FROM t +-- !query 303 schema +struct<> +-- !query 303 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: 
differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 304 +SELECT cast('2017/12/11 09:30:00' as date) / cast(1 as decimal(20, 0)) FROM t +-- !query 304 schema +struct<> +-- !query 304 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017/12/11 09:30:00' AS DATE) / CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 305 +SELECT cast(1 as decimal(3, 0)) / cast(1 as tinyint) FROM t +-- !query 305 schema +struct<(CAST(1 AS DECIMAL(3,0)) / CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(9,6)> +-- !query 305 output +1 + + +-- !query 306 +SELECT cast(1 as decimal(5, 0)) / cast(1 as tinyint) FROM t +-- !query 306 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(11,6)> +-- !query 306 output +1 + + +-- !query 307 +SELECT cast(1 as decimal(10, 0)) / cast(1 as tinyint) FROM t +-- !query 307 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 307 output +1 + + +-- !query 308 +SELECT cast(1 as decimal(20, 0)) / cast(1 as tinyint) FROM t +-- !query 308 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(26,6)> +-- !query 308 output +1 + + +-- !query 309 +SELECT cast(1 as decimal(3, 0)) / cast(1 as smallint) FROM t +-- !query 309 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(9,6)> +-- !query 309 output +1 + + +-- !query 310 +SELECT cast(1 as decimal(5, 0)) / cast(1 as smallint) FROM t +-- !query 310 schema +struct<(CAST(1 AS DECIMAL(5,0)) / CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(11,6)> +-- !query 310 output +1 + + +-- !query 311 +SELECT cast(1 as decimal(10, 0)) / cast(1 as smallint) FROM t +-- !query 311 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(16,6)> +-- !query 311 output +1 + + +-- !query 312 +SELECT cast(1 as decimal(20, 0)) / cast(1 as smallint) FROM t +-- !query 312 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(26,6)> +-- !query 312 output +1 + + +-- !query 313 +SELECT cast(1 as decimal(3, 0)) / cast(1 as int) FROM t +-- !query 313 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)> +-- !query 313 output +1 + + +-- !query 314 +SELECT cast(1 as decimal(5, 0)) / cast(1 as int) FROM t +-- !query 314 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)> +-- !query 314 output +1 + + +-- !query 315 +SELECT cast(1 as decimal(10, 0)) / cast(1 as int) FROM t +-- !query 315 schema +struct<(CAST(1 AS DECIMAL(10,0)) / CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(21,11)> +-- !query 315 output +1 + + +-- !query 316 +SELECT cast(1 as decimal(20, 0)) / cast(1 as int) FROM t +-- !query 316 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS 
DECIMAL(20,0))):decimal(31,11)> +-- !query 316 output +1 + + +-- !query 317 +SELECT cast(1 as decimal(3, 0)) / cast(1 as bigint) FROM t +-- !query 317 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(24,21)> +-- !query 317 output +1 + + +-- !query 318 +SELECT cast(1 as decimal(5, 0)) / cast(1 as bigint) FROM t +-- !query 318 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(26,21)> +-- !query 318 output +1 + + +-- !query 319 +SELECT cast(1 as decimal(10, 0)) / cast(1 as bigint) FROM t +-- !query 319 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(31,21)> +-- !query 319 output +1 + + +-- !query 320 +SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t +-- !query 320 schema +struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,19)> +-- !query 320 output +1 + + +-- !query 321 +SELECT cast(1 as decimal(3, 0)) / cast(1 as float) FROM t +-- !query 321 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 321 output +1.0 + + +-- !query 322 +SELECT cast(1 as decimal(5, 0)) / cast(1 as float) FROM t +-- !query 322 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 322 output +1.0 + + +-- !query 323 +SELECT cast(1 as decimal(10, 0)) / cast(1 as float) FROM t +-- !query 323 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 323 output +1.0 + + +-- !query 324 +SELECT cast(1 as decimal(20, 0)) / cast(1 as float) FROM t +-- !query 324 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) / CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 324 output +1.0 + + +-- !query 325 +SELECT cast(1 as decimal(3, 0)) / cast(1 as double) FROM t +-- !query 325 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 325 output +1.0 + + +-- !query 326 +SELECT cast(1 as decimal(5, 0)) / cast(1 as double) FROM t +-- !query 326 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 326 output +1.0 + + +-- !query 327 +SELECT cast(1 as decimal(10, 0)) / cast(1 as double) FROM t +-- !query 327 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 327 output +1.0 + + +-- !query 328 +SELECT cast(1 as decimal(20, 0)) / cast(1 as double) FROM t +-- !query 328 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) / CAST(1 AS DOUBLE)):double> +-- !query 328 output +1.0 + + +-- !query 329 +SELECT cast(1 as decimal(3, 0)) / cast(1 as decimal(10, 0)) FROM t +-- !query 329 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(14,11)> +-- !query 329 output +1 + + +-- !query 330 +SELECT cast(1 as decimal(5, 0)) / cast(1 as decimal(10, 0)) FROM t +-- !query 330 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(16,11)> +-- !query 330 output +1 + + +-- !query 331 +SELECT cast(1 as decimal(10, 0)) / cast(1 as decimal(10, 0)) FROM t +-- !query 331 schema +struct<(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS DECIMAL(10,0))):decimal(21,11)> +-- !query 331 output +1 + + +-- !query 332 +SELECT cast(1 as 
decimal(20, 0)) / cast(1 as decimal(10, 0)) FROM t +-- !query 332 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(31,11)> +-- !query 332 output +1 + + +-- !query 333 +SELECT cast(1 as decimal(3, 0)) / cast(1 as string) FROM t +-- !query 333 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 333 output +1.0 + + +-- !query 334 +SELECT cast(1 as decimal(5, 0)) / cast(1 as string) FROM t +-- !query 334 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 334 output +1.0 + + +-- !query 335 +SELECT cast(1 as decimal(10, 0)) / cast(1 as string) FROM t +-- !query 335 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 335 output +1.0 + + +-- !query 336 +SELECT cast(1 as decimal(20, 0)) / cast(1 as string) FROM t +-- !query 336 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) / CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 336 output +1.0 + + +-- !query 337 +SELECT cast(1 as decimal(3, 0)) / cast('1' as binary) FROM t +-- !query 337 schema +struct<> +-- !query 337 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 338 +SELECT cast(1 as decimal(5, 0)) / cast('1' as binary) FROM t +-- !query 338 schema +struct<> +-- !query 338 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 339 +SELECT cast(1 as decimal(10, 0)) / cast('1' as binary) FROM t +-- !query 339 schema +struct<> +-- !query 339 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 340 +SELECT cast(1 as decimal(20, 0)) / cast('1' as binary) FROM t +-- !query 340 schema +struct<> +-- !query 340 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 341 +SELECT cast(1 as decimal(3, 0)) / cast(1 as boolean) FROM t +-- !query 341 schema +struct<> +-- !query 341 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 342 +SELECT cast(1 as decimal(5, 0)) / cast(1 as boolean) FROM t +-- !query 342 schema +struct<> +-- !query 342 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 343 +SELECT cast(1 as decimal(10, 0)) / cast(1 as boolean) FROM t +-- !query 343 schema +struct<> +-- !query 343 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 344 +SELECT cast(1 as decimal(20, 0)) / cast(1 as boolean) FROM t +-- !query 344 schema +struct<> +-- !query 344 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 345 +SELECT cast(1 as decimal(3, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t +-- !query 345 schema +struct<> +-- !query 345 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 346 +SELECT cast(1 as decimal(5, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t +-- !query 346 schema +struct<> +-- !query 346 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 347 +SELECT cast(1 as decimal(10, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t +-- !query 347 schema +struct<> +-- !query 347 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 348 +SELECT cast(1 as decimal(20, 0)) / cast('2017/12/11 09:30:00.0' as timestamp) FROM t +-- !query 348 schema +struct<> +-- !query 348 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 349 +SELECT cast(1 as decimal(3, 0)) / cast('2017/12/11 09:30:00' as date) FROM t +-- !query 349 schema +struct<> +-- !query 349 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 350 +SELECT cast(1 as decimal(5, 0)) / cast('2017/12/11 09:30:00' as date) FROM t +-- !query 350 schema +struct<> +-- !query 350 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 351 +SELECT cast(1 as decimal(10, 0)) / cast('2017/12/11 09:30:00' as date) FROM t +-- !query 351 schema +struct<> +-- !query 351 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00' AS 
DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 352 +SELECT cast(1 as decimal(20, 0)) / cast('2017/12/11 09:30:00' as date) FROM t +-- !query 352 schema +struct<> +-- !query 352 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) / CAST('2017/12/11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 353 +SELECT cast(1 as tinyint) % cast(1 as decimal(3, 0)) FROM t +-- !query 353 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) % CAST(1 AS DECIMAL(3,0))):decimal(3,0)> +-- !query 353 output +0 + + +-- !query 354 +SELECT cast(1 as tinyint) % cast(1 as decimal(5, 0)) FROM t +-- !query 354 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(3,0)> +-- !query 354 output +0 + + +-- !query 355 +SELECT cast(1 as tinyint) % cast(1 as decimal(10, 0)) FROM t +-- !query 355 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)> +-- !query 355 output +0 + + +-- !query 356 +SELECT cast(1 as tinyint) % cast(1 as decimal(20, 0)) FROM t +-- !query 356 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(3,0)> +-- !query 356 output +0 + + +-- !query 357 +SELECT cast(1 as smallint) % cast(1 as decimal(3, 0)) FROM t +-- !query 357 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(3,0)> +-- !query 357 output +0 + + +-- !query 358 +SELECT cast(1 as smallint) % cast(1 as decimal(5, 0)) FROM t +-- !query 358 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) % CAST(1 AS DECIMAL(5,0))):decimal(5,0)> +-- !query 358 output +0 + + +-- !query 359 +SELECT cast(1 as smallint) % cast(1 as decimal(10, 0)) FROM t +-- !query 359 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)> +-- !query 359 output +0 + + +-- !query 360 +SELECT cast(1 as smallint) % cast(1 as decimal(20, 0)) FROM t +-- !query 360 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(5,0)> +-- !query 360 output +0 + + +-- !query 361 +SELECT cast(1 as int) % cast(1 as decimal(3, 0)) FROM t +-- !query 361 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)> +-- !query 361 output +0 + + +-- !query 362 +SELECT cast(1 as int) % cast(1 as decimal(5, 0)) FROM t +-- !query 362 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)> +-- !query 362 output +0 + + +-- !query 363 +SELECT cast(1 as int) % cast(1 as decimal(10, 0)) FROM t +-- !query 363 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) % CAST(1 AS DECIMAL(10,0))):decimal(10,0)> +-- !query 363 output +0 + + +-- !query 364 +SELECT cast(1 as int) % cast(1 as decimal(20, 0)) FROM t +-- !query 364 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS 
DECIMAL(20,0))):decimal(10,0)> +-- !query 364 output +0 + + +-- !query 365 +SELECT cast(1 as bigint) % cast(1 as decimal(3, 0)) FROM t +-- !query 365 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(3,0)> +-- !query 365 output +0 + + +-- !query 366 +SELECT cast(1 as bigint) % cast(1 as decimal(5, 0)) FROM t +-- !query 366 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(5,0)> +-- !query 366 output +0 + + +-- !query 367 +SELECT cast(1 as bigint) % cast(1 as decimal(10, 0)) FROM t +-- !query 367 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)> +-- !query 367 output +0 + + +-- !query 368 +SELECT cast(1 as bigint) % cast(1 as decimal(20, 0)) FROM t +-- !query 368 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) % CAST(1 AS DECIMAL(20,0))):decimal(20,0)> +-- !query 368 output +0 + + +-- !query 369 +SELECT cast(1 as float) % cast(1 as decimal(3, 0)) FROM t +-- !query 369 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 369 output +0.0 + + +-- !query 370 +SELECT cast(1 as float) % cast(1 as decimal(5, 0)) FROM t +-- !query 370 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 370 output +0.0 + + +-- !query 371 +SELECT cast(1 as float) % cast(1 as decimal(10, 0)) FROM t +-- !query 371 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 371 output +0.0 + + +-- !query 372 +SELECT cast(1 as float) % cast(1 as decimal(20, 0)) FROM t +-- !query 372 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) % CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 372 output +0.0 + + +-- !query 373 +SELECT cast(1 as double) % cast(1 as decimal(3, 0)) FROM t +-- !query 373 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):double> +-- !query 373 output +0.0 + + +-- !query 374 +SELECT cast(1 as double) % cast(1 as decimal(5, 0)) FROM t +-- !query 374 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):double> +-- !query 374 output +0.0 + + +-- !query 375 +SELECT cast(1 as double) % cast(1 as decimal(10, 0)) FROM t +-- !query 375 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):double> +-- !query 375 output +0.0 + + +-- !query 376 +SELECT cast(1 as double) % cast(1 as decimal(20, 0)) FROM t +-- !query 376 schema +struct<(CAST(1 AS DOUBLE) % CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):double> +-- !query 376 output +0.0 + + +-- !query 377 +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(3, 0)) FROM t +-- !query 377 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)> +-- !query 377 output +0 + + +-- !query 378 +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(5, 0)) FROM t +-- !query 378 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)> +-- !query 378 output +0 + + +-- !query 379 +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t +-- !query 379 schema +struct<(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS DECIMAL(10,0))):decimal(10,0)> +-- !query 379 output +0 + + +-- !query 380 +SELECT cast(1 as decimal(10, 0)) 
% cast(1 as decimal(20, 0)) FROM t +-- !query 380 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)> +-- !query 380 output +0 + + +-- !query 381 +SELECT cast('1' as binary) % cast(1 as decimal(3, 0)) FROM t +-- !query 381 schema +struct<> +-- !query 381 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 382 +SELECT cast('1' as binary) % cast(1 as decimal(5, 0)) FROM t +-- !query 382 schema +struct<> +-- !query 382 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 383 +SELECT cast('1' as binary) % cast(1 as decimal(10, 0)) FROM t +-- !query 383 schema +struct<> +-- !query 383 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 384 +SELECT cast('1' as binary) % cast(1 as decimal(20, 0)) FROM t +-- !query 384 schema +struct<> +-- !query 384 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) % CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 385 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(3, 0)) FROM t +-- !query 385 schema +struct<> +-- !query 385 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 386 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(5, 0)) FROM t +-- !query 386 schema +struct<> +-- !query 386 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 387 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(10, 0)) FROM t +-- !query 387 schema +struct<> +-- !query 387 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 388 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) % cast(1 as decimal(20, 0)) FROM t +-- !query 388 schema +struct<> +-- !query 388 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) % CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 389 
+SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(3, 0)) FROM t +-- !query 389 schema +struct<> +-- !query 389 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 390 +SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(5, 0)) FROM t +-- !query 390 schema +struct<> +-- !query 390 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 391 +SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(10, 0)) FROM t +-- !query 391 schema +struct<> +-- !query 391 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 392 +SELECT cast('2017-12-11 09:30:00' as date) % cast(1 as decimal(20, 0)) FROM t +-- !query 392 schema +struct<> +-- !query 392 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) % CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 393 +SELECT cast(1 as decimal(3, 0)) % cast(1 as tinyint) FROM t +-- !query 393 schema +struct<(CAST(1 AS DECIMAL(3,0)) % CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):decimal(3,0)> +-- !query 393 output +0 + + +-- !query 394 +SELECT cast(1 as decimal(5, 0)) % cast(1 as tinyint) FROM t +-- !query 394 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) % CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):decimal(3,0)> +-- !query 394 output +0 + + +-- !query 395 +SELECT cast(1 as decimal(10, 0)) % cast(1 as tinyint) FROM t +-- !query 395 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):decimal(3,0)> +-- !query 395 output +0 + + +-- !query 396 +SELECT cast(1 as decimal(20, 0)) % cast(1 as tinyint) FROM t +-- !query 396 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):decimal(3,0)> +-- !query 396 output +0 + + +-- !query 397 +SELECT cast(1 as decimal(3, 0)) % cast(1 as smallint) FROM t +-- !query 397 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) % CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):decimal(3,0)> +-- !query 397 output +0 + + +-- !query 398 +SELECT cast(1 as decimal(5, 0)) % cast(1 as smallint) FROM t +-- !query 398 schema +struct<(CAST(1 AS DECIMAL(5,0)) % CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):decimal(5,0)> +-- !query 398 output +0 + + +-- !query 399 +SELECT cast(1 as decimal(10, 0)) % cast(1 as smallint) FROM t +-- !query 399 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):decimal(5,0)> +-- !query 399 output +0 + + +-- !query 400 +SELECT cast(1 as decimal(20, 0)) % cast(1 as smallint) FROM t +-- !query 400 schema 
+struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):decimal(5,0)> +-- !query 400 output +0 + + +-- !query 401 +SELECT cast(1 as decimal(3, 0)) % cast(1 as int) FROM t +-- !query 401 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)> +-- !query 401 output +0 + + +-- !query 402 +SELECT cast(1 as decimal(5, 0)) % cast(1 as int) FROM t +-- !query 402 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) % CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)> +-- !query 402 output +0 + + +-- !query 403 +SELECT cast(1 as decimal(10, 0)) % cast(1 as int) FROM t +-- !query 403 schema +struct<(CAST(1 AS DECIMAL(10,0)) % CAST(CAST(1 AS INT) AS DECIMAL(10,0))):decimal(10,0)> +-- !query 403 output +0 + + +-- !query 404 +SELECT cast(1 as decimal(20, 0)) % cast(1 as int) FROM t +-- !query 404 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)> +-- !query 404 output +0 + + +-- !query 405 +SELECT cast(1 as decimal(3, 0)) % cast(1 as bigint) FROM t +-- !query 405 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(3,0)> +-- !query 405 output +0 + + +-- !query 406 +SELECT cast(1 as decimal(5, 0)) % cast(1 as bigint) FROM t +-- !query 406 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(5,0)> +-- !query 406 output +0 + + +-- !query 407 +SELECT cast(1 as decimal(10, 0)) % cast(1 as bigint) FROM t +-- !query 407 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) % CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):decimal(10,0)> +-- !query 407 output +0 + + +-- !query 408 +SELECT cast(1 as decimal(20, 0)) % cast(1 as bigint) FROM t +-- !query 408 schema +struct<(CAST(1 AS DECIMAL(20,0)) % CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(20,0)> +-- !query 408 output +0 + + +-- !query 409 +SELECT cast(1 as decimal(3, 0)) % cast(1 as float) FROM t +-- !query 409 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 409 output +0.0 + + +-- !query 410 +SELECT cast(1 as decimal(5, 0)) % cast(1 as float) FROM t +-- !query 410 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 410 output +0.0 + + +-- !query 411 +SELECT cast(1 as decimal(10, 0)) % cast(1 as float) FROM t +-- !query 411 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 411 output +0.0 + + +-- !query 412 +SELECT cast(1 as decimal(20, 0)) % cast(1 as float) FROM t +-- !query 412 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) % CAST(CAST(1 AS FLOAT) AS DOUBLE)):double> +-- !query 412 output +0.0 + + +-- !query 413 +SELECT cast(1 as decimal(3, 0)) % cast(1 as double) FROM t +-- !query 413 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 413 output +0.0 + + +-- !query 414 +SELECT cast(1 as decimal(5, 0)) % cast(1 as double) FROM t +-- !query 414 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 414 output +0.0 + + +-- !query 415 +SELECT cast(1 as decimal(10, 0)) % cast(1 as double) FROM t +-- !query 415 
schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 415 output +0.0 + + +-- !query 416 +SELECT cast(1 as decimal(20, 0)) % cast(1 as double) FROM t +-- !query 416 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) % CAST(1 AS DOUBLE)):double> +-- !query 416 output +0.0 + + +-- !query 417 +SELECT cast(1 as decimal(3, 0)) % cast(1 as decimal(10, 0)) FROM t +-- !query 417 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(3,0)> +-- !query 417 output +0 + + +-- !query 418 +SELECT cast(1 as decimal(5, 0)) % cast(1 as decimal(10, 0)) FROM t +-- !query 418 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):decimal(5,0)> +-- !query 418 output +0 + + +-- !query 419 +SELECT cast(1 as decimal(10, 0)) % cast(1 as decimal(10, 0)) FROM t +-- !query 419 schema +struct<(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS DECIMAL(10,0))):decimal(10,0)> +-- !query 419 output +0 + + +-- !query 420 +SELECT cast(1 as decimal(20, 0)) % cast(1 as decimal(10, 0)) FROM t +-- !query 420 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) % CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):decimal(10,0)> +-- !query 420 output +0 + + +-- !query 421 +SELECT cast(1 as decimal(3, 0)) % cast(1 as string) FROM t +-- !query 421 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 421 output +0.0 + + +-- !query 422 +SELECT cast(1 as decimal(5, 0)) % cast(1 as string) FROM t +-- !query 422 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 422 output +0.0 + + +-- !query 423 +SELECT cast(1 as decimal(10, 0)) % cast(1 as string) FROM t +-- !query 423 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 423 output +0.0 + + +-- !query 424 +SELECT cast(1 as decimal(20, 0)) % cast(1 as string) FROM t +-- !query 424 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) % CAST(CAST(1 AS STRING) AS DOUBLE)):double> +-- !query 424 output +0.0 + + +-- !query 425 +SELECT cast(1 as decimal(3, 0)) % cast('1' as binary) FROM t +-- !query 425 schema +struct<> +-- !query 425 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 426 +SELECT cast(1 as decimal(5, 0)) % cast('1' as binary) FROM t +-- !query 426 schema +struct<> +-- !query 426 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 427 +SELECT cast(1 as decimal(10, 0)) % cast('1' as binary) FROM t +-- !query 427 schema +struct<> +-- !query 427 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 428 +SELECT cast(1 as decimal(20, 0)) % cast('1' as binary) FROM t +-- !query 428 schema +struct<> +-- !query 428 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) % 
CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 429 +SELECT cast(1 as decimal(3, 0)) % cast(1 as boolean) FROM t +-- !query 429 schema +struct<> +-- !query 429 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 430 +SELECT cast(1 as decimal(5, 0)) % cast(1 as boolean) FROM t +-- !query 430 schema +struct<> +-- !query 430 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 431 +SELECT cast(1 as decimal(10, 0)) % cast(1 as boolean) FROM t +-- !query 431 schema +struct<> +-- !query 431 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 432 +SELECT cast(1 as decimal(20, 0)) % cast(1 as boolean) FROM t +-- !query 432 schema +struct<> +-- !query 432 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) % CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 433 +SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 433 schema +struct<> +-- !query 433 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 434 +SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 434 schema +struct<> +-- !query 434 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 435 +SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 435 schema +struct<> +-- !query 435 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 436 +SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 436 schema +struct<> +-- !query 436 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 437 +SELECT cast(1 as decimal(3, 0)) % cast('2017-12-11 
09:30:00' as date) FROM t +-- !query 437 schema +struct<> +-- !query 437 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 438 +SELECT cast(1 as decimal(5, 0)) % cast('2017-12-11 09:30:00' as date) FROM t +-- !query 438 schema +struct<> +-- !query 438 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 439 +SELECT cast(1 as decimal(10, 0)) % cast('2017-12-11 09:30:00' as date) FROM t +-- !query 439 schema +struct<> +-- !query 439 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 440 +SELECT cast(1 as decimal(20, 0)) % cast('2017-12-11 09:30:00' as date) FROM t +-- !query 440 schema +struct<> +-- !query 440 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) % CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 441 +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(3, 0))) FROM t +-- !query 441 schema +struct +-- !query 441 output +0 + + +-- !query 442 +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(5, 0))) FROM t +-- !query 442 schema +struct +-- !query 442 output +0 + + +-- !query 443 +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(10, 0))) FROM t +-- !query 443 schema +struct +-- !query 443 output +0 + + +-- !query 444 +SELECT pmod(cast(1 as tinyint), cast(1 as decimal(20, 0))) FROM t +-- !query 444 schema +struct +-- !query 444 output +0 + + +-- !query 445 +SELECT pmod(cast(1 as smallint), cast(1 as decimal(3, 0))) FROM t +-- !query 445 schema +struct +-- !query 445 output +0 + + +-- !query 446 +SELECT pmod(cast(1 as smallint), cast(1 as decimal(5, 0))) FROM t +-- !query 446 schema +struct +-- !query 446 output +0 + + +-- !query 447 +SELECT pmod(cast(1 as smallint), cast(1 as decimal(10, 0))) FROM t +-- !query 447 schema +struct +-- !query 447 output +0 + + +-- !query 448 +SELECT pmod(cast(1 as smallint), cast(1 as decimal(20, 0))) FROM t +-- !query 448 schema +struct +-- !query 448 output +0 + + +-- !query 449 +SELECT pmod(cast(1 as int), cast(1 as decimal(3, 0))) FROM t +-- !query 449 schema +struct +-- !query 449 output +0 + + +-- !query 450 +SELECT pmod(cast(1 as int), cast(1 as decimal(5, 0))) FROM t +-- !query 450 schema +struct +-- !query 450 output +0 + + +-- !query 451 +SELECT pmod(cast(1 as int), cast(1 as decimal(10, 0))) FROM t +-- !query 451 schema +struct +-- !query 451 output +0 + + +-- !query 452 +SELECT pmod(cast(1 as int), cast(1 as decimal(20, 0))) FROM t +-- !query 452 schema +struct +-- !query 452 output +0 + + +-- !query 453 +SELECT pmod(cast(1 as bigint), cast(1 as decimal(3, 0))) FROM t +-- !query 453 schema +struct +-- !query 453 output +0 + + +-- !query 454 +SELECT pmod(cast(1 as bigint), cast(1 as decimal(5, 0))) FROM t +-- !query 454 schema 
+struct +-- !query 454 output +0 + + +-- !query 455 +SELECT pmod(cast(1 as bigint), cast(1 as decimal(10, 0))) FROM t +-- !query 455 schema +struct +-- !query 455 output +0 + + +-- !query 456 +SELECT pmod(cast(1 as bigint), cast(1 as decimal(20, 0))) FROM t +-- !query 456 schema +struct +-- !query 456 output +0 + + +-- !query 457 +SELECT pmod(cast(1 as float), cast(1 as decimal(3, 0))) FROM t +-- !query 457 schema +struct +-- !query 457 output +0.0 + + +-- !query 458 +SELECT pmod(cast(1 as float), cast(1 as decimal(5, 0))) FROM t +-- !query 458 schema +struct +-- !query 458 output +0.0 + + +-- !query 459 +SELECT pmod(cast(1 as float), cast(1 as decimal(10, 0))) FROM t +-- !query 459 schema +struct +-- !query 459 output +0.0 + + +-- !query 460 +SELECT pmod(cast(1 as float), cast(1 as decimal(20, 0))) FROM t +-- !query 460 schema +struct +-- !query 460 output +0.0 + + +-- !query 461 +SELECT pmod(cast(1 as double), cast(1 as decimal(3, 0))) FROM t +-- !query 461 schema +struct +-- !query 461 output +0.0 + + +-- !query 462 +SELECT pmod(cast(1 as double), cast(1 as decimal(5, 0))) FROM t +-- !query 462 schema +struct +-- !query 462 output +0.0 + + +-- !query 463 +SELECT pmod(cast(1 as double), cast(1 as decimal(10, 0))) FROM t +-- !query 463 schema +struct +-- !query 463 output +0.0 + + +-- !query 464 +SELECT pmod(cast(1 as double), cast(1 as decimal(20, 0))) FROM t +-- !query 464 schema +struct +-- !query 464 output +0.0 + + +-- !query 465 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(3, 0))) FROM t +-- !query 465 schema +struct +-- !query 465 output +0 + + +-- !query 466 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(5, 0))) FROM t +-- !query 466 schema +struct +-- !query 466 output +0 + + +-- !query 467 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t +-- !query 467 schema +struct +-- !query 467 output +0 + + +-- !query 468 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(20, 0))) FROM t +-- !query 468 schema +struct +-- !query 468 output +0 + + +-- !query 469 +SELECT pmod(cast('1' as binary), cast(1 as decimal(3, 0))) FROM t +-- !query 469 schema +struct<> +-- !query 469 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 470 +SELECT pmod(cast('1' as binary), cast(1 as decimal(5, 0))) FROM t +-- !query 470 schema +struct<> +-- !query 470 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 471 +SELECT pmod(cast('1' as binary), cast(1 as decimal(10, 0))) FROM t +-- !query 471 schema +struct<> +-- !query 471 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 472 +SELECT pmod(cast('1' as binary), cast(1 as decimal(20, 0))) FROM t +-- !query 472 schema +struct<> +-- !query 472 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in 'pmod(CAST('1' AS BINARY), CAST(1 AS DECIMAL(20,0)))' 
(binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 473 +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(3, 0))) FROM t +-- !query 473 schema +struct<> +-- !query 473 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 474 +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(5, 0))) FROM t +-- !query 474 schema +struct<> +-- !query 474 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 475 +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(10, 0))) FROM t +-- !query 475 schema +struct<> +-- !query 475 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 476 +SELECT pmod(cast('2017-12-11 09:30:00.0' as timestamp), cast(1 as decimal(20, 0))) FROM t +-- !query 476 schema +struct<> +-- !query 476 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 477 +SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(3, 0))) FROM t +-- !query 477 schema +struct<> +-- !query 477 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 478 +SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(5, 0))) FROM t +-- !query 478 schema +struct<> +-- !query 478 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 479 +SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(10, 0))) FROM t +-- !query 479 schema +struct<> +-- !query 479 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 480 +SELECT pmod(cast('2017-12-11 09:30:00' as date), cast(1 as decimal(20, 0))) FROM t +-- !query 480 schema +struct<> +-- !query 480 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST('2017-12-11 09:30:00' AS DATE), CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in 'pmod(CAST('2017-12-11 
09:30:00' AS DATE), CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 481 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as tinyint)) FROM t +-- !query 481 schema +struct +-- !query 481 output +0 + + +-- !query 482 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as tinyint)) FROM t +-- !query 482 schema +struct +-- !query 482 output +0 + + +-- !query 483 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as tinyint)) FROM t +-- !query 483 schema +struct +-- !query 483 output +0 + + +-- !query 484 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as tinyint)) FROM t +-- !query 484 schema +struct +-- !query 484 output +0 + + +-- !query 485 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as smallint)) FROM t +-- !query 485 schema +struct +-- !query 485 output +0 + + +-- !query 486 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as smallint)) FROM t +-- !query 486 schema +struct +-- !query 486 output +0 + + +-- !query 487 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as smallint)) FROM t +-- !query 487 schema +struct +-- !query 487 output +0 + + +-- !query 488 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as smallint)) FROM t +-- !query 488 schema +struct +-- !query 488 output +0 + + +-- !query 489 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as int)) FROM t +-- !query 489 schema +struct +-- !query 489 output +0 + + +-- !query 490 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as int)) FROM t +-- !query 490 schema +struct +-- !query 490 output +0 + + +-- !query 491 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as int)) FROM t +-- !query 491 schema +struct +-- !query 491 output +0 + + +-- !query 492 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as int)) FROM t +-- !query 492 schema +struct +-- !query 492 output +0 + + +-- !query 493 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as bigint)) FROM t +-- !query 493 schema +struct +-- !query 493 output +0 + + +-- !query 494 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as bigint)) FROM t +-- !query 494 schema +struct +-- !query 494 output +0 + + +-- !query 495 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as bigint)) FROM t +-- !query 495 schema +struct +-- !query 495 output +0 + + +-- !query 496 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as bigint)) FROM t +-- !query 496 schema +struct +-- !query 496 output +0 + + +-- !query 497 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as float)) FROM t +-- !query 497 schema +struct +-- !query 497 output +0.0 + + +-- !query 498 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as float)) FROM t +-- !query 498 schema +struct +-- !query 498 output +0.0 + + +-- !query 499 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as float)) FROM t +-- !query 499 schema +struct +-- !query 499 output +0.0 + + +-- !query 500 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as float)) FROM t +-- !query 500 schema +struct +-- !query 500 output +0.0 + + +-- !query 501 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as double)) FROM t +-- !query 501 schema +struct +-- !query 501 output +0.0 + + +-- !query 502 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as double)) FROM t +-- !query 502 schema +struct +-- !query 502 output +0.0 + + +-- !query 503 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as double)) FROM t +-- !query 503 schema +struct +-- !query 503 output +0.0 + + +-- !query 504 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as double)) FROM t +-- !query 504 schema +struct +-- !query 504 output +0.0 + + +-- !query 505 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as decimal(10, 0))) FROM t 
+-- !query 505 schema +struct +-- !query 505 output +0 + + +-- !query 506 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as decimal(10, 0))) FROM t +-- !query 506 schema +struct +-- !query 506 output +0 + + +-- !query 507 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as decimal(10, 0))) FROM t +-- !query 507 schema +struct +-- !query 507 output +0 + + +-- !query 508 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as decimal(10, 0))) FROM t +-- !query 508 schema +struct +-- !query 508 output +0 + + +-- !query 509 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as string)) FROM t +-- !query 509 schema +struct +-- !query 509 output +0.0 + + +-- !query 510 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as string)) FROM t +-- !query 510 schema +struct +-- !query 510 output +0.0 + + +-- !query 511 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as string)) FROM t +-- !query 511 schema +struct +-- !query 511 output +0.0 + + +-- !query 512 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as string)) FROM t +-- !query 512 schema +struct +-- !query 512 output +0.0 + + +-- !query 513 +SELECT pmod(cast(1 as decimal(3, 0)) , cast('1' as binary)) FROM t +-- !query 513 schema +struct<> +-- !query 513 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 514 +SELECT pmod(cast(1 as decimal(5, 0)) , cast('1' as binary)) FROM t +-- !query 514 schema +struct<> +-- !query 514 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 515 +SELECT pmod(cast(1 as decimal(10, 0)), cast('1' as binary)) FROM t +-- !query 515 schema +struct<> +-- !query 515 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 516 +SELECT pmod(cast(1 as decimal(20, 0)), cast('1' as binary)) FROM t +-- !query 516 schema +struct<> +-- !query 516 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('1' AS BINARY))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 517 +SELECT pmod(cast(1 as decimal(3, 0)) , cast(1 as boolean)) FROM t +-- !query 517 schema +struct<> +-- !query 517 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 518 +SELECT pmod(cast(1 as decimal(5, 0)) , cast(1 as boolean)) FROM t +-- !query 518 schema +struct<> +-- !query 518 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 519 +SELECT pmod(cast(1 as decimal(10, 0)), cast(1 as boolean)) FROM t +-- !query 519 schema +struct<> +-- !query 519 output 
+org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 520 +SELECT pmod(cast(1 as decimal(20, 0)), cast(1 as boolean)) FROM t +-- !query 520 schema +struct<> +-- !query 520 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 521 +SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 521 schema +struct<> +-- !query 521 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 522 +SELECT pmod(cast(1 as decimal(5, 0)) , cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 522 schema +struct<> +-- !query 522 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 523 +SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 523 schema +struct<> +-- !query 523 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 524 +SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 524 schema +struct<> +-- !query 524 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 525 +SELECT pmod(cast(1 as decimal(3, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 525 schema +struct<> +-- !query 525 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(3,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 526 +SELECT pmod(cast(1 as decimal(5, 0)) , cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 526 schema +struct<> +-- !query 526 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(5,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 527 +SELECT pmod(cast(1 as decimal(10, 0)), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 527 schema +struct<> +-- !query 527 output 
+org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 528 +SELECT pmod(cast(1 as decimal(20, 0)), cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 528 schema +struct<> +-- !query 528 output +org.apache.spark.sql.AnalysisException +cannot resolve 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in 'pmod(CAST(1 AS DECIMAL(20,0)), CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 529 +SELECT cast(1 as tinyint) = cast(1 as decimal(3, 0)) FROM t +-- !query 529 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) = CAST(1 AS DECIMAL(3,0))):boolean> +-- !query 529 output +true + + +-- !query 530 +SELECT cast(1 as tinyint) = cast(1 as decimal(5, 0)) FROM t +-- !query 530 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 530 output +true + + +-- !query 531 +SELECT cast(1 as tinyint) = cast(1 as decimal(10, 0)) FROM t +-- !query 531 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 531 output +true + + +-- !query 532 +SELECT cast(1 as tinyint) = cast(1 as decimal(20, 0)) FROM t +-- !query 532 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 532 output +true + + +-- !query 533 +SELECT cast(1 as smallint) = cast(1 as decimal(3, 0)) FROM t +-- !query 533 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 533 output +true + + +-- !query 534 +SELECT cast(1 as smallint) = cast(1 as decimal(5, 0)) FROM t +-- !query 534 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) = CAST(1 AS DECIMAL(5,0))):boolean> +-- !query 534 output +true + + +-- !query 535 +SELECT cast(1 as smallint) = cast(1 as decimal(10, 0)) FROM t +-- !query 535 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 535 output +true + + +-- !query 536 +SELECT cast(1 as smallint) = cast(1 as decimal(20, 0)) FROM t +-- !query 536 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 536 output +true + + +-- !query 537 +SELECT cast(1 as int) = cast(1 as decimal(3, 0)) FROM t +-- !query 537 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 537 output +true + + +-- !query 538 +SELECT cast(1 as int) = cast(1 as decimal(5, 0)) FROM t +-- !query 538 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 538 output +true + + +-- !query 539 +SELECT cast(1 as int) = cast(1 as decimal(10, 0)) FROM t +-- !query 539 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 539 output +true + + +-- !query 540 +SELECT cast(1 as int) = cast(1 as decimal(20, 0)) FROM t +-- !query 540 
schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 540 output +true + + +-- !query 541 +SELECT cast(1 as bigint) = cast(1 as decimal(3, 0)) FROM t +-- !query 541 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 541 output +true + + +-- !query 542 +SELECT cast(1 as bigint) = cast(1 as decimal(5, 0)) FROM t +-- !query 542 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 542 output +true + + +-- !query 543 +SELECT cast(1 as bigint) = cast(1 as decimal(10, 0)) FROM t +-- !query 543 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 543 output +true + + +-- !query 544 +SELECT cast(1 as bigint) = cast(1 as decimal(20, 0)) FROM t +-- !query 544 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) = CAST(1 AS DECIMAL(20,0))):boolean> +-- !query 544 output +true + + +-- !query 545 +SELECT cast(1 as float) = cast(1 as decimal(3, 0)) FROM t +-- !query 545 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 545 output +true + + +-- !query 546 +SELECT cast(1 as float) = cast(1 as decimal(5, 0)) FROM t +-- !query 546 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 546 output +true + + +-- !query 547 +SELECT cast(1 as float) = cast(1 as decimal(10, 0)) FROM t +-- !query 547 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 547 output +true + + +-- !query 548 +SELECT cast(1 as float) = cast(1 as decimal(20, 0)) FROM t +-- !query 548 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 548 output +true + + +-- !query 549 +SELECT cast(1 as double) = cast(1 as decimal(3, 0)) FROM t +-- !query 549 schema +struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 549 output +true + + +-- !query 550 +SELECT cast(1 as double) = cast(1 as decimal(5, 0)) FROM t +-- !query 550 schema +struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 550 output +true + + +-- !query 551 +SELECT cast(1 as double) = cast(1 as decimal(10, 0)) FROM t +-- !query 551 schema +struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 551 output +true + + +-- !query 552 +SELECT cast(1 as double) = cast(1 as decimal(20, 0)) FROM t +-- !query 552 schema +struct<(CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 552 output +true + + +-- !query 553 +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(3, 0)) FROM t +-- !query 553 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 553 output +true + + +-- !query 554 +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(5, 0)) FROM t +-- !query 554 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 554 output +true + + +-- !query 555 +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t +-- !query 555 schema +struct<(CAST(1 AS DECIMAL(10,0)) = CAST(1 
AS DECIMAL(10,0))):boolean> +-- !query 555 output +true + + +-- !query 556 +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(20, 0)) FROM t +-- !query 556 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 556 output +true + + +-- !query 557 +SELECT cast('1' as binary) = cast(1 as decimal(3, 0)) FROM t +-- !query 557 schema +struct<> +-- !query 557 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 558 +SELECT cast('1' as binary) = cast(1 as decimal(5, 0)) FROM t +-- !query 558 schema +struct<> +-- !query 558 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 559 +SELECT cast('1' as binary) = cast(1 as decimal(10, 0)) FROM t +-- !query 559 schema +struct<> +-- !query 559 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 560 +SELECT cast('1' as binary) = cast(1 as decimal(20, 0)) FROM t +-- !query 560 schema +struct<> +-- !query 560 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 561 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(3, 0)) FROM t +-- !query 561 schema +struct<> +-- !query 561 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 562 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(5, 0)) FROM t +-- !query 562 schema +struct<> +-- !query 562 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 563 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(10, 0)) FROM t +-- !query 563 schema +struct<> +-- !query 563 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 564 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) = cast(1 as decimal(20, 0)) FROM t +-- !query 564 schema +struct<> +-- !query 564 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS 
TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 565 +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(3, 0)) FROM t +-- !query 565 schema +struct<> +-- !query 565 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 566 +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(5, 0)) FROM t +-- !query 566 schema +struct<> +-- !query 566 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 567 +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(10, 0)) FROM t +-- !query 567 schema +struct<> +-- !query 567 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 568 +SELECT cast('2017-12-11 09:30:00' as date) = cast(1 as decimal(20, 0)) FROM t +-- !query 568 schema +struct<> +-- !query 568 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 569 +SELECT cast(1 as decimal(3, 0)) = cast(1 as tinyint) FROM t +-- !query 569 schema +struct<(CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean> +-- !query 569 output +true + + +-- !query 570 +SELECT cast(1 as decimal(5, 0)) = cast(1 as tinyint) FROM t +-- !query 570 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 570 output +true + + +-- !query 571 +SELECT cast(1 as decimal(10, 0)) = cast(1 as tinyint) FROM t +-- !query 571 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 571 output +true + + +-- !query 572 +SELECT cast(1 as decimal(20, 0)) = cast(1 as tinyint) FROM t +-- !query 572 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 572 output +true + + +-- !query 573 +SELECT cast(1 as decimal(3, 0)) = cast(1 as smallint) FROM t +-- !query 573 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 573 output +true + + +-- !query 574 +SELECT cast(1 as decimal(5, 0)) = cast(1 as smallint) FROM t +-- !query 574 schema +struct<(CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean> +-- !query 574 output +true + + +-- !query 575 +SELECT cast(1 as decimal(10, 0)) = cast(1 as smallint) FROM t +-- !query 575 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 575 output +true + + +-- !query 
576 +SELECT cast(1 as decimal(20, 0)) = cast(1 as smallint) FROM t +-- !query 576 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 576 output +true + + +-- !query 577 +SELECT cast(1 as decimal(3, 0)) = cast(1 as int) FROM t +-- !query 577 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 577 output +true + + +-- !query 578 +SELECT cast(1 as decimal(5, 0)) = cast(1 as int) FROM t +-- !query 578 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 578 output +true + + +-- !query 579 +SELECT cast(1 as decimal(10, 0)) = cast(1 as int) FROM t +-- !query 579 schema +struct<(CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean> +-- !query 579 output +true + + +-- !query 580 +SELECT cast(1 as decimal(20, 0)) = cast(1 as int) FROM t +-- !query 580 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 580 output +true + + +-- !query 581 +SELECT cast(1 as decimal(3, 0)) = cast(1 as bigint) FROM t +-- !query 581 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 581 output +true + + +-- !query 582 +SELECT cast(1 as decimal(5, 0)) = cast(1 as bigint) FROM t +-- !query 582 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 582 output +true + + +-- !query 583 +SELECT cast(1 as decimal(10, 0)) = cast(1 as bigint) FROM t +-- !query 583 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 583 output +true + + +-- !query 584 +SELECT cast(1 as decimal(20, 0)) = cast(1 as bigint) FROM t +-- !query 584 schema +struct<(CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean> +-- !query 584 output +true + + +-- !query 585 +SELECT cast(1 as decimal(3, 0)) = cast(1 as float) FROM t +-- !query 585 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 585 output +true + + +-- !query 586 +SELECT cast(1 as decimal(5, 0)) = cast(1 as float) FROM t +-- !query 586 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 586 output +true + + +-- !query 587 +SELECT cast(1 as decimal(10, 0)) = cast(1 as float) FROM t +-- !query 587 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 587 output +true + + +-- !query 588 +SELECT cast(1 as decimal(20, 0)) = cast(1 as float) FROM t +-- !query 588 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 588 output +true + + +-- !query 589 +SELECT cast(1 as decimal(3, 0)) = cast(1 as double) FROM t +-- !query 589 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 589 output +true + + +-- !query 590 +SELECT cast(1 as decimal(5, 0)) = cast(1 as double) FROM t +-- !query 590 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 590 output +true + + +-- !query 
591 +SELECT cast(1 as decimal(10, 0)) = cast(1 as double) FROM t +-- !query 591 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 591 output +true + + +-- !query 592 +SELECT cast(1 as decimal(20, 0)) = cast(1 as double) FROM t +-- !query 592 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(1 AS DOUBLE)):boolean> +-- !query 592 output +true + + +-- !query 593 +SELECT cast(1 as decimal(3, 0)) = cast(1 as decimal(10, 0)) FROM t +-- !query 593 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 593 output +true + + +-- !query 594 +SELECT cast(1 as decimal(5, 0)) = cast(1 as decimal(10, 0)) FROM t +-- !query 594 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 594 output +true + + +-- !query 595 +SELECT cast(1 as decimal(10, 0)) = cast(1 as decimal(10, 0)) FROM t +-- !query 595 schema +struct<(CAST(1 AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 595 output +true + + +-- !query 596 +SELECT cast(1 as decimal(20, 0)) = cast(1 as decimal(10, 0)) FROM t +-- !query 596 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 596 output +true + + +-- !query 597 +SELECT cast(1 as decimal(3, 0)) = cast(1 as string) FROM t +-- !query 597 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 597 output +true + + +-- !query 598 +SELECT cast(1 as decimal(5, 0)) = cast(1 as string) FROM t +-- !query 598 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 598 output +true + + +-- !query 599 +SELECT cast(1 as decimal(10, 0)) = cast(1 as string) FROM t +-- !query 599 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 599 output +true + + +-- !query 600 +SELECT cast(1 as decimal(20, 0)) = cast(1 as string) FROM t +-- !query 600 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 600 output +true + + +-- !query 601 +SELECT cast(1 as decimal(3, 0)) = cast('1' as binary) FROM t +-- !query 601 schema +struct<> +-- !query 601 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 602 +SELECT cast(1 as decimal(5, 0)) = cast('1' as binary) FROM t +-- !query 602 schema +struct<> +-- !query 602 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 603 +SELECT cast(1 as decimal(10, 0)) = cast('1' as binary) FROM t +-- !query 603 schema +struct<> +-- !query 603 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 604 +SELECT cast(1 as decimal(20, 0)) = cast('1' as binary) FROM t +-- !query 604 schema +struct<> +-- !query 604 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 605 +SELECT cast(1 as decimal(3, 0)) = cast(1 as boolean) FROM t +-- !query 605 schema +struct<(CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(3,0))):boolean> +-- !query 605 output +true + + +-- !query 606 +SELECT cast(1 as decimal(5, 0)) = cast(1 as boolean) FROM t +-- !query 606 schema +struct<(CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(5,0))):boolean> +-- !query 606 output +true + + +-- !query 607 +SELECT cast(1 as decimal(10, 0)) = cast(1 as boolean) FROM t +-- !query 607 schema +struct<(CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(10,0))):boolean> +-- !query 607 output +true + + +-- !query 608 +SELECT cast(1 as decimal(20, 0)) = cast(1 as boolean) FROM t +-- !query 608 schema +struct<(CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(20,0))):boolean> +-- !query 608 output +true + + +-- !query 609 +SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 609 schema +struct<> +-- !query 609 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 610 +SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 610 schema +struct<> +-- !query 610 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 611 +SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 611 schema +struct<> +-- !query 611 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 612 +SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 612 schema +struct<> +-- !query 612 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 613 +SELECT cast(1 as decimal(3, 0)) = cast('2017-12-11 09:30:00' as date) FROM t +-- !query 613 schema +struct<> +-- !query 613 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 614 +SELECT cast(1 as decimal(5, 0)) = cast('2017-12-11 09:30:00' as date) FROM t +-- !query 614 schema +struct<> +-- !query 614 output +org.apache.spark.sql.AnalysisException +cannot resolve 
'(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 615 +SELECT cast(1 as decimal(10, 0)) = cast('2017-12-11 09:30:00' as date) FROM t +-- !query 615 schema +struct<> +-- !query 615 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 616 +SELECT cast(1 as decimal(20, 0)) = cast('2017-12-11 09:30:00' as date) FROM t +-- !query 616 schema +struct<> +-- !query 616 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 617 +SELECT cast(1 as tinyint) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 617 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) <=> CAST(1 AS DECIMAL(3,0))):boolean> +-- !query 617 output +true + + +-- !query 618 +SELECT cast(1 as tinyint) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 618 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 618 output +true + + +-- !query 619 +SELECT cast(1 as tinyint) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 619 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 619 output +true + + +-- !query 620 +SELECT cast(1 as tinyint) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 620 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 620 output +true + + +-- !query 621 +SELECT cast(1 as smallint) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 621 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 621 output +true + + +-- !query 622 +SELECT cast(1 as smallint) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 622 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) <=> CAST(1 AS DECIMAL(5,0))):boolean> +-- !query 622 output +true + + +-- !query 623 +SELECT cast(1 as smallint) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 623 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 623 output +true + + +-- !query 624 +SELECT cast(1 as smallint) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 624 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 624 output +true + + +-- !query 625 +SELECT cast(1 as int) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 625 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 625 output +true + + +-- !query 626 +SELECT cast(1 as int) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 626 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> 
CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 626 output +true + + +-- !query 627 +SELECT cast(1 as int) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 627 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 627 output +true + + +-- !query 628 +SELECT cast(1 as int) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 628 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 628 output +true + + +-- !query 629 +SELECT cast(1 as bigint) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 629 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 629 output +true + + +-- !query 630 +SELECT cast(1 as bigint) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 630 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 630 output +true + + +-- !query 631 +SELECT cast(1 as bigint) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 631 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 631 output +true + + +-- !query 632 +SELECT cast(1 as bigint) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 632 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) <=> CAST(1 AS DECIMAL(20,0))):boolean> +-- !query 632 output +true + + +-- !query 633 +SELECT cast(1 as float) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 633 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 633 output +true + + +-- !query 634 +SELECT cast(1 as float) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 634 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 634 output +true + + +-- !query 635 +SELECT cast(1 as float) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 635 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 635 output +true + + +-- !query 636 +SELECT cast(1 as float) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 636 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 636 output +true + + +-- !query 637 +SELECT cast(1 as double) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 637 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 637 output +true + + +-- !query 638 +SELECT cast(1 as double) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 638 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 638 output +true + + +-- !query 639 +SELECT cast(1 as double) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 639 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 639 output +true + + +-- !query 640 +SELECT cast(1 as double) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 640 schema +struct<(CAST(1 AS DOUBLE) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 640 output +true + + +-- !query 641 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 641 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(3,0)) AS 
DECIMAL(10,0))):boolean> +-- !query 641 output +true + + +-- !query 642 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 642 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 642 output +true + + +-- !query 643 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 643 schema +struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 643 output +true + + +-- !query 644 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 644 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 644 output +true + + +-- !query 645 +SELECT cast('1' as binary) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 645 schema +struct<> +-- !query 645 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 646 +SELECT cast('1' as binary) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 646 schema +struct<> +-- !query 646 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 647 +SELECT cast('1' as binary) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 647 schema +struct<> +-- !query 647 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 648 +SELECT cast('1' as binary) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 648 schema +struct<> +-- !query 648 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <=> CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 649 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 649 schema +struct<> +-- !query 649 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 650 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 650 schema +struct<> +-- !query 650 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 651 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 651 schema +struct<> +-- !query 651 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(10,0)))' due to data type 
mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 652 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 652 schema +struct<> +-- !query 652 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <=> CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 653 +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(3, 0)) FROM t +-- !query 653 schema +struct<> +-- !query 653 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 654 +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(5, 0)) FROM t +-- !query 654 schema +struct<> +-- !query 654 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 655 +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 655 schema +struct<> +-- !query 655 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 656 +SELECT cast('2017-12-11 09:30:00' as date) <=> cast(1 as decimal(20, 0)) FROM t +-- !query 656 schema +struct<> +-- !query 656 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <=> CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 657 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as tinyint) FROM t +-- !query 657 schema +struct<(CAST(1 AS DECIMAL(3,0)) <=> CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean> +-- !query 657 output +true + + +-- !query 658 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as tinyint) FROM t +-- !query 658 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) <=> CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 658 output +true + + +-- !query 659 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as tinyint) FROM t +-- !query 659 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 659 output +true + + +-- !query 660 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as tinyint) FROM t +-- !query 660 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 660 output +true + + +-- !query 661 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as smallint) FROM t +-- !query 661 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) <=> CAST(CAST(CAST(1 AS SMALLINT) AS 
DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 661 output +true + + +-- !query 662 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as smallint) FROM t +-- !query 662 schema +struct<(CAST(1 AS DECIMAL(5,0)) <=> CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean> +-- !query 662 output +true + + +-- !query 663 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as smallint) FROM t +-- !query 663 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 663 output +true + + +-- !query 664 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as smallint) FROM t +-- !query 664 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 664 output +true + + +-- !query 665 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as int) FROM t +-- !query 665 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 665 output +true + + +-- !query 666 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as int) FROM t +-- !query 666 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <=> CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 666 output +true + + +-- !query 667 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as int) FROM t +-- !query 667 schema +struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean> +-- !query 667 output +true + + +-- !query 668 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as int) FROM t +-- !query 668 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 668 output +true + + +-- !query 669 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as bigint) FROM t +-- !query 669 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 669 output +true + + +-- !query 670 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as bigint) FROM t +-- !query 670 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 670 output +true + + +-- !query 671 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as bigint) FROM t +-- !query 671 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <=> CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 671 output +true + + +-- !query 672 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as bigint) FROM t +-- !query 672 schema +struct<(CAST(1 AS DECIMAL(20,0)) <=> CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean> +-- !query 672 output +true + + +-- !query 673 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as float) FROM t +-- !query 673 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 673 output +true + + +-- !query 674 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as float) FROM t +-- !query 674 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 674 output +true + + +-- !query 675 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as float) FROM t +-- !query 675 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 675 output +true + + +-- !query 676 +SELECT 
cast(1 as decimal(20, 0)) <=> cast(1 as float) FROM t +-- !query 676 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <=> CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 676 output +true + + +-- !query 677 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as double) FROM t +-- !query 677 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 677 output +true + + +-- !query 678 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as double) FROM t +-- !query 678 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 678 output +true + + +-- !query 679 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as double) FROM t +-- !query 679 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 679 output +true + + +-- !query 680 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as double) FROM t +-- !query 680 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <=> CAST(1 AS DOUBLE)):boolean> +-- !query 680 output +true + + +-- !query 681 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 681 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 681 output +true + + +-- !query 682 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 682 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 682 output +true + + +-- !query 683 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 683 schema +struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 683 output +true + + +-- !query 684 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as decimal(10, 0)) FROM t +-- !query 684 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <=> CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 684 output +true + + +-- !query 685 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as string) FROM t +-- !query 685 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 685 output +true + + +-- !query 686 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as string) FROM t +-- !query 686 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 686 output +true + + +-- !query 687 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as string) FROM t +-- !query 687 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 687 output +true + + +-- !query 688 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as string) FROM t +-- !query 688 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <=> CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 688 output +true + + +-- !query 689 +SELECT cast(1 as decimal(3, 0)) <=> cast('1' as binary) FROM t +-- !query 689 schema +struct<> +-- !query 689 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <=> CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 690 +SELECT cast(1 as decimal(5, 0)) <=> cast('1' as binary) FROM t +-- !query 690 schema +struct<> +-- !query 690 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS 
DECIMAL(5,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <=> CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 691 +SELECT cast(1 as decimal(10, 0)) <=> cast('1' as binary) FROM t +-- !query 691 schema +struct<> +-- !query 691 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <=> CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 692 +SELECT cast(1 as decimal(20, 0)) <=> cast('1' as binary) FROM t +-- !query 692 schema +struct<> +-- !query 692 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) <=> CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <=> CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 693 +SELECT cast(1 as decimal(3, 0)) <=> cast(1 as boolean) FROM t +-- !query 693 schema +struct<(CAST(1 AS DECIMAL(3,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(3,0))):boolean> +-- !query 693 output +true + + +-- !query 694 +SELECT cast(1 as decimal(5, 0)) <=> cast(1 as boolean) FROM t +-- !query 694 schema +struct<(CAST(1 AS DECIMAL(5,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(5,0))):boolean> +-- !query 694 output +true + + +-- !query 695 +SELECT cast(1 as decimal(10, 0)) <=> cast(1 as boolean) FROM t +-- !query 695 schema +struct<(CAST(1 AS DECIMAL(10,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(10,0))):boolean> +-- !query 695 output +true + + +-- !query 696 +SELECT cast(1 as decimal(20, 0)) <=> cast(1 as boolean) FROM t +-- !query 696 schema +struct<(CAST(1 AS DECIMAL(20,0)) <=> CAST(CAST(1 AS BOOLEAN) AS DECIMAL(20,0))):boolean> +-- !query 696 output +true + + +-- !query 697 +SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 697 schema +struct<> +-- !query 697 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 698 +SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 698 schema +struct<> +-- !query 698 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 699 +SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 699 schema +struct<> +-- !query 699 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 700 +SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 700 schema +struct<> +-- !query 700 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in 
'(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 701 +SELECT cast(1 as decimal(3, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 701 schema +struct<> +-- !query 701 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 702 +SELECT cast(1 as decimal(5, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 702 schema +struct<> +-- !query 702 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 703 +SELECT cast(1 as decimal(10, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 703 schema +struct<> +-- !query 703 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 704 +SELECT cast(1 as decimal(20, 0)) <=> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 704 schema +struct<> +-- !query 704 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <=> CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 705 +SELECT cast(1 as tinyint) < cast(1 as decimal(3, 0)) FROM t +-- !query 705 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) < CAST(1 AS DECIMAL(3,0))):boolean> +-- !query 705 output +false + + +-- !query 706 +SELECT cast(1 as tinyint) < cast(1 as decimal(5, 0)) FROM t +-- !query 706 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 706 output +false + + +-- !query 707 +SELECT cast(1 as tinyint) < cast(1 as decimal(10, 0)) FROM t +-- !query 707 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 707 output +false + + +-- !query 708 +SELECT cast(1 as tinyint) < cast(1 as decimal(20, 0)) FROM t +-- !query 708 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 708 output +false + + +-- !query 709 +SELECT cast(1 as smallint) < cast(1 as decimal(3, 0)) FROM t +-- !query 709 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 709 output +false + + +-- !query 710 +SELECT cast(1 as smallint) < cast(1 as decimal(5, 0)) FROM t +-- !query 710 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) < CAST(1 AS DECIMAL(5,0))):boolean> +-- !query 710 output +false + + +-- !query 711 +SELECT cast(1 as smallint) < cast(1 as decimal(10, 0)) FROM t +-- !query 711 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS 
DECIMAL(10,0))):boolean> +-- !query 711 output +false + + +-- !query 712 +SELECT cast(1 as smallint) < cast(1 as decimal(20, 0)) FROM t +-- !query 712 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 712 output +false + + +-- !query 713 +SELECT cast(1 as int) < cast(1 as decimal(3, 0)) FROM t +-- !query 713 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 713 output +false + + +-- !query 714 +SELECT cast(1 as int) < cast(1 as decimal(5, 0)) FROM t +-- !query 714 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 714 output +false + + +-- !query 715 +SELECT cast(1 as int) < cast(1 as decimal(10, 0)) FROM t +-- !query 715 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) < CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 715 output +false + + +-- !query 716 +SELECT cast(1 as int) < cast(1 as decimal(20, 0)) FROM t +-- !query 716 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 716 output +false + + +-- !query 717 +SELECT cast(1 as bigint) < cast(1 as decimal(3, 0)) FROM t +-- !query 717 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 717 output +false + + +-- !query 718 +SELECT cast(1 as bigint) < cast(1 as decimal(5, 0)) FROM t +-- !query 718 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 718 output +false + + +-- !query 719 +SELECT cast(1 as bigint) < cast(1 as decimal(10, 0)) FROM t +-- !query 719 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 719 output +false + + +-- !query 720 +SELECT cast(1 as bigint) < cast(1 as decimal(20, 0)) FROM t +-- !query 720 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) < CAST(1 AS DECIMAL(20,0))):boolean> +-- !query 720 output +false + + +-- !query 721 +SELECT cast(1 as float) < cast(1 as decimal(3, 0)) FROM t +-- !query 721 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 721 output +false + + +-- !query 722 +SELECT cast(1 as float) < cast(1 as decimal(5, 0)) FROM t +-- !query 722 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 722 output +false + + +-- !query 723 +SELECT cast(1 as float) < cast(1 as decimal(10, 0)) FROM t +-- !query 723 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 723 output +false + + +-- !query 724 +SELECT cast(1 as float) < cast(1 as decimal(20, 0)) FROM t +-- !query 724 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) < CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 724 output +false + + +-- !query 725 +SELECT cast(1 as double) < cast(1 as decimal(3, 0)) FROM t +-- !query 725 schema +struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 725 output +false + + +-- !query 726 +SELECT cast(1 as double) < cast(1 as decimal(5, 0)) FROM t +-- !query 726 schema +struct<(CAST(1 AS DOUBLE) < 
CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 726 output +false + + +-- !query 727 +SELECT cast(1 as double) < cast(1 as decimal(10, 0)) FROM t +-- !query 727 schema +struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 727 output +false + + +-- !query 728 +SELECT cast(1 as double) < cast(1 as decimal(20, 0)) FROM t +-- !query 728 schema +struct<(CAST(1 AS DOUBLE) < CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 728 output +false + + +-- !query 729 +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(3, 0)) FROM t +-- !query 729 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 729 output +false + + +-- !query 730 +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(5, 0)) FROM t +-- !query 730 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 730 output +false + + +-- !query 731 +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t +-- !query 731 schema +struct<(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 731 output +false + + +-- !query 732 +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(20, 0)) FROM t +-- !query 732 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 732 output +false + + +-- !query 733 +SELECT cast('1' as binary) < cast(1 as decimal(3, 0)) FROM t +-- !query 733 schema +struct<> +-- !query 733 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 734 +SELECT cast('1' as binary) < cast(1 as decimal(5, 0)) FROM t +-- !query 734 schema +struct<> +-- !query 734 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 735 +SELECT cast('1' as binary) < cast(1 as decimal(10, 0)) FROM t +-- !query 735 schema +struct<> +-- !query 735 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 736 +SELECT cast('1' as binary) < cast(1 as decimal(20, 0)) FROM t +-- !query 736 schema +struct<> +-- !query 736 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) < CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 737 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(3, 0)) FROM t +-- !query 737 schema +struct<> +-- !query 737 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 738 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(5, 0)) FROM t +-- 
!query 738 schema +struct<> +-- !query 738 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 739 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(10, 0)) FROM t +-- !query 739 schema +struct<> +-- !query 739 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 740 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) < cast(1 as decimal(20, 0)) FROM t +-- !query 740 schema +struct<> +-- !query 740 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) < CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 741 +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(3, 0)) FROM t +-- !query 741 schema +struct<> +-- !query 741 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 742 +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(5, 0)) FROM t +-- !query 742 schema +struct<> +-- !query 742 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 743 +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(10, 0)) FROM t +-- !query 743 schema +struct<> +-- !query 743 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 744 +SELECT cast('2017-12-11 09:30:00' as date) < cast(1 as decimal(20, 0)) FROM t +-- !query 744 schema +struct<> +-- !query 744 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) < CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 745 +SELECT cast(1 as decimal(3, 0)) < cast(1 as tinyint) FROM t +-- !query 745 schema +struct<(CAST(1 AS DECIMAL(3,0)) < CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean> +-- !query 745 output +false + + +-- !query 746 +SELECT cast(1 as decimal(5, 0)) < cast(1 as tinyint) FROM t +-- !query 746 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) < CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 746 output +false + + +-- !query 747 +SELECT cast(1 as decimal(10, 0)) < cast(1 as tinyint) FROM t +-- !query 747 schema 
+struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 747 output +false + + +-- !query 748 +SELECT cast(1 as decimal(20, 0)) < cast(1 as tinyint) FROM t +-- !query 748 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 748 output +false + + +-- !query 749 +SELECT cast(1 as decimal(3, 0)) < cast(1 as smallint) FROM t +-- !query 749 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) < CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 749 output +false + + +-- !query 750 +SELECT cast(1 as decimal(5, 0)) < cast(1 as smallint) FROM t +-- !query 750 schema +struct<(CAST(1 AS DECIMAL(5,0)) < CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean> +-- !query 750 output +false + + +-- !query 751 +SELECT cast(1 as decimal(10, 0)) < cast(1 as smallint) FROM t +-- !query 751 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 751 output +false + + +-- !query 752 +SELECT cast(1 as decimal(20, 0)) < cast(1 as smallint) FROM t +-- !query 752 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 752 output +false + + +-- !query 753 +SELECT cast(1 as decimal(3, 0)) < cast(1 as int) FROM t +-- !query 753 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 753 output +false + + +-- !query 754 +SELECT cast(1 as decimal(5, 0)) < cast(1 as int) FROM t +-- !query 754 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) < CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 754 output +false + + +-- !query 755 +SELECT cast(1 as decimal(10, 0)) < cast(1 as int) FROM t +-- !query 755 schema +struct<(CAST(1 AS DECIMAL(10,0)) < CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean> +-- !query 755 output +false + + +-- !query 756 +SELECT cast(1 as decimal(20, 0)) < cast(1 as int) FROM t +-- !query 756 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 756 output +false + + +-- !query 757 +SELECT cast(1 as decimal(3, 0)) < cast(1 as bigint) FROM t +-- !query 757 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 757 output +false + + +-- !query 758 +SELECT cast(1 as decimal(5, 0)) < cast(1 as bigint) FROM t +-- !query 758 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 758 output +false + + +-- !query 759 +SELECT cast(1 as decimal(10, 0)) < cast(1 as bigint) FROM t +-- !query 759 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) < CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 759 output +false + + +-- !query 760 +SELECT cast(1 as decimal(20, 0)) < cast(1 as bigint) FROM t +-- !query 760 schema +struct<(CAST(1 AS DECIMAL(20,0)) < CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean> +-- !query 760 output +false + + +-- !query 761 +SELECT cast(1 as decimal(3, 0)) < cast(1 as float) FROM t +-- !query 761 schema +struct<(CAST(CAST(1 AS 
DECIMAL(3,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 761 output +false + + +-- !query 762 +SELECT cast(1 as decimal(5, 0)) < cast(1 as float) FROM t +-- !query 762 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 762 output +false + + +-- !query 763 +SELECT cast(1 as decimal(10, 0)) < cast(1 as float) FROM t +-- !query 763 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 763 output +false + + +-- !query 764 +SELECT cast(1 as decimal(20, 0)) < cast(1 as float) FROM t +-- !query 764 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) < CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 764 output +false + + +-- !query 765 +SELECT cast(1 as decimal(3, 0)) < cast(1 as double) FROM t +-- !query 765 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 765 output +false + + +-- !query 766 +SELECT cast(1 as decimal(5, 0)) < cast(1 as double) FROM t +-- !query 766 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 766 output +false + + +-- !query 767 +SELECT cast(1 as decimal(10, 0)) < cast(1 as double) FROM t +-- !query 767 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 767 output +false + + +-- !query 768 +SELECT cast(1 as decimal(20, 0)) < cast(1 as double) FROM t +-- !query 768 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) < CAST(1 AS DOUBLE)):boolean> +-- !query 768 output +false + + +-- !query 769 +SELECT cast(1 as decimal(3, 0)) < cast(1 as decimal(10, 0)) FROM t +-- !query 769 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 769 output +false + + +-- !query 770 +SELECT cast(1 as decimal(5, 0)) < cast(1 as decimal(10, 0)) FROM t +-- !query 770 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 770 output +false + + +-- !query 771 +SELECT cast(1 as decimal(10, 0)) < cast(1 as decimal(10, 0)) FROM t +-- !query 771 schema +struct<(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 771 output +false + + +-- !query 772 +SELECT cast(1 as decimal(20, 0)) < cast(1 as decimal(10, 0)) FROM t +-- !query 772 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) < CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 772 output +false + + +-- !query 773 +SELECT cast(1 as decimal(3, 0)) < cast(1 as string) FROM t +-- !query 773 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 773 output +false + + +-- !query 774 +SELECT cast(1 as decimal(5, 0)) < cast(1 as string) FROM t +-- !query 774 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 774 output +false + + +-- !query 775 +SELECT cast(1 as decimal(10, 0)) < cast(1 as string) FROM t +-- !query 775 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 775 output +false + + +-- !query 776 +SELECT cast(1 as decimal(20, 0)) < cast(1 as string) FROM t +-- !query 776 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) < CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 776 output +false + + +-- !query 777 +SELECT cast(1 as decimal(3, 0)) < cast('1' 
as binary) FROM t +-- !query 777 schema +struct<> +-- !query 777 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 778 +SELECT cast(1 as decimal(5, 0)) < cast('1' as binary) FROM t +-- !query 778 schema +struct<> +-- !query 778 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 779 +SELECT cast(1 as decimal(10, 0)) < cast('1' as binary) FROM t +-- !query 779 schema +struct<> +-- !query 779 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 780 +SELECT cast(1 as decimal(20, 0)) < cast('1' as binary) FROM t +-- !query 780 schema +struct<> +-- !query 780 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 781 +SELECT cast(1 as decimal(3, 0)) < cast(1 as boolean) FROM t +-- !query 781 schema +struct<> +-- !query 781 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 782 +SELECT cast(1 as decimal(5, 0)) < cast(1 as boolean) FROM t +-- !query 782 schema +struct<> +-- !query 782 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 783 +SELECT cast(1 as decimal(10, 0)) < cast(1 as boolean) FROM t +-- !query 783 schema +struct<> +-- !query 783 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 784 +SELECT cast(1 as decimal(20, 0)) < cast(1 as boolean) FROM t +-- !query 784 schema +struct<> +-- !query 784 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 785 +SELECT cast(1 as decimal(3, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 785 schema +struct<> +-- !query 785 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 786 +SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 786 schema +struct<> 
+-- !query 786 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 787 +SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 787 schema +struct<> +-- !query 787 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 788 +SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 788 schema +struct<> +-- !query 788 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 789 +SELECT cast(1 as decimal(3, 0)) < cast('2017-12-11 09:30:00' as date) FROM t +-- !query 789 schema +struct<> +-- !query 789 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 790 +SELECT cast(1 as decimal(5, 0)) < cast('2017-12-11 09:30:00' as date) FROM t +-- !query 790 schema +struct<> +-- !query 790 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 791 +SELECT cast(1 as decimal(10, 0)) < cast('2017-12-11 09:30:00' as date) FROM t +-- !query 791 schema +struct<> +-- !query 791 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 792 +SELECT cast(1 as decimal(20, 0)) < cast('2017-12-11 09:30:00' as date) FROM t +-- !query 792 schema +struct<> +-- !query 792 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) < CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 793 +SELECT cast(1 as tinyint) <= cast(1 as decimal(3, 0)) FROM t +-- !query 793 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) <= CAST(1 AS DECIMAL(3,0))):boolean> +-- !query 793 output +true + + +-- !query 794 +SELECT cast(1 as tinyint) <= cast(1 as decimal(5, 0)) FROM t +-- !query 794 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 794 output +true + + +-- !query 795 +SELECT cast(1 as tinyint) <= cast(1 as decimal(10, 0)) FROM t +-- !query 795 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS 
DECIMAL(3,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 795 output +true + + +-- !query 796 +SELECT cast(1 as tinyint) <= cast(1 as decimal(20, 0)) FROM t +-- !query 796 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 796 output +true + + +-- !query 797 +SELECT cast(1 as smallint) <= cast(1 as decimal(3, 0)) FROM t +-- !query 797 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 797 output +true + + +-- !query 798 +SELECT cast(1 as smallint) <= cast(1 as decimal(5, 0)) FROM t +-- !query 798 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) <= CAST(1 AS DECIMAL(5,0))):boolean> +-- !query 798 output +true + + +-- !query 799 +SELECT cast(1 as smallint) <= cast(1 as decimal(10, 0)) FROM t +-- !query 799 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 799 output +true + + +-- !query 800 +SELECT cast(1 as smallint) <= cast(1 as decimal(20, 0)) FROM t +-- !query 800 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 800 output +true + + +-- !query 801 +SELECT cast(1 as int) <= cast(1 as decimal(3, 0)) FROM t +-- !query 801 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 801 output +true + + +-- !query 802 +SELECT cast(1 as int) <= cast(1 as decimal(5, 0)) FROM t +-- !query 802 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 802 output +true + + +-- !query 803 +SELECT cast(1 as int) <= cast(1 as decimal(10, 0)) FROM t +-- !query 803 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) <= CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 803 output +true + + +-- !query 804 +SELECT cast(1 as int) <= cast(1 as decimal(20, 0)) FROM t +-- !query 804 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 804 output +true + + +-- !query 805 +SELECT cast(1 as bigint) <= cast(1 as decimal(3, 0)) FROM t +-- !query 805 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 805 output +true + + +-- !query 806 +SELECT cast(1 as bigint) <= cast(1 as decimal(5, 0)) FROM t +-- !query 806 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 806 output +true + + +-- !query 807 +SELECT cast(1 as bigint) <= cast(1 as decimal(10, 0)) FROM t +-- !query 807 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 807 output +true + + +-- !query 808 +SELECT cast(1 as bigint) <= cast(1 as decimal(20, 0)) FROM t +-- !query 808 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) <= CAST(1 AS DECIMAL(20,0))):boolean> +-- !query 808 output +true + + +-- !query 809 +SELECT cast(1 as float) <= cast(1 as decimal(3, 0)) FROM t +-- !query 809 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= 
CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 809 output +true + + +-- !query 810 +SELECT cast(1 as float) <= cast(1 as decimal(5, 0)) FROM t +-- !query 810 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 810 output +true + + +-- !query 811 +SELECT cast(1 as float) <= cast(1 as decimal(10, 0)) FROM t +-- !query 811 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 811 output +true + + +-- !query 812 +SELECT cast(1 as float) <= cast(1 as decimal(20, 0)) FROM t +-- !query 812 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 812 output +true + + +-- !query 813 +SELECT cast(1 as double) <= cast(1 as decimal(3, 0)) FROM t +-- !query 813 schema +struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 813 output +true + + +-- !query 814 +SELECT cast(1 as double) <= cast(1 as decimal(5, 0)) FROM t +-- !query 814 schema +struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 814 output +true + + +-- !query 815 +SELECT cast(1 as double) <= cast(1 as decimal(10, 0)) FROM t +-- !query 815 schema +struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 815 output +true + + +-- !query 816 +SELECT cast(1 as double) <= cast(1 as decimal(20, 0)) FROM t +-- !query 816 schema +struct<(CAST(1 AS DOUBLE) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 816 output +true + + +-- !query 817 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(3, 0)) FROM t +-- !query 817 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 817 output +true + + +-- !query 818 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(5, 0)) FROM t +-- !query 818 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 818 output +true + + +-- !query 819 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t +-- !query 819 schema +struct<(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 819 output +true + + +-- !query 820 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(20, 0)) FROM t +-- !query 820 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 820 output +true + + +-- !query 821 +SELECT cast('1' as binary) <= cast(1 as decimal(3, 0)) FROM t +-- !query 821 schema +struct<> +-- !query 821 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 822 +SELECT cast('1' as binary) <= cast(1 as decimal(5, 0)) FROM t +-- !query 822 schema +struct<> +-- !query 822 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 823 +SELECT cast('1' as binary) <= cast(1 as decimal(10, 0)) FROM t +-- !query 823 schema +struct<> +-- !query 823 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) 
<= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 824 +SELECT cast('1' as binary) <= cast(1 as decimal(20, 0)) FROM t +-- !query 824 schema +struct<> +-- !query 824 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) <= CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 825 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(3, 0)) FROM t +-- !query 825 schema +struct<> +-- !query 825 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 826 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(5, 0)) FROM t +-- !query 826 schema +struct<> +-- !query 826 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 827 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(10, 0)) FROM t +-- !query 827 schema +struct<> +-- !query 827 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 828 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <= cast(1 as decimal(20, 0)) FROM t +-- !query 828 schema +struct<> +-- !query 828 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) <= CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 829 +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(3, 0)) FROM t +-- !query 829 schema +struct<> +-- !query 829 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 830 +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(5, 0)) FROM t +-- !query 830 schema +struct<> +-- !query 830 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 831 +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(10, 0)) FROM t +-- !query 831 schema +struct<> +-- !query 831 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in 
'(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 832 +SELECT cast('2017-12-11 09:30:00' as date) <= cast(1 as decimal(20, 0)) FROM t +-- !query 832 schema +struct<> +-- !query 832 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) <= CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 833 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as tinyint) FROM t +-- !query 833 schema +struct<(CAST(1 AS DECIMAL(3,0)) <= CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean> +-- !query 833 output +true + + +-- !query 834 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as tinyint) FROM t +-- !query 834 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) <= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 834 output +true + + +-- !query 835 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as tinyint) FROM t +-- !query 835 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 835 output +true + + +-- !query 836 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as tinyint) FROM t +-- !query 836 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 836 output +true + + +-- !query 837 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as smallint) FROM t +-- !query 837 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) <= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 837 output +true + + +-- !query 838 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as smallint) FROM t +-- !query 838 schema +struct<(CAST(1 AS DECIMAL(5,0)) <= CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean> +-- !query 838 output +true + + +-- !query 839 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as smallint) FROM t +-- !query 839 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 839 output +true + + +-- !query 840 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as smallint) FROM t +-- !query 840 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 840 output +true + + +-- !query 841 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as int) FROM t +-- !query 841 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 841 output +true + + +-- !query 842 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as int) FROM t +-- !query 842 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 842 output +true + + +-- !query 843 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as int) FROM t +-- !query 843 schema +struct<(CAST(1 AS DECIMAL(10,0)) <= CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean> +-- !query 843 output +true + + +-- !query 844 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as int) FROM t +-- !query 844 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 844 output 
+true + + +-- !query 845 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as bigint) FROM t +-- !query 845 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 845 output +true + + +-- !query 846 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as bigint) FROM t +-- !query 846 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 846 output +true + + +-- !query 847 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as bigint) FROM t +-- !query 847 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) <= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 847 output +true + + +-- !query 848 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as bigint) FROM t +-- !query 848 schema +struct<(CAST(1 AS DECIMAL(20,0)) <= CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean> +-- !query 848 output +true + + +-- !query 849 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as float) FROM t +-- !query 849 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 849 output +true + + +-- !query 850 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as float) FROM t +-- !query 850 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 850 output +true + + +-- !query 851 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as float) FROM t +-- !query 851 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 851 output +true + + +-- !query 852 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as float) FROM t +-- !query 852 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 852 output +true + + +-- !query 853 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as double) FROM t +-- !query 853 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 853 output +true + + +-- !query 854 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as double) FROM t +-- !query 854 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 854 output +true + + +-- !query 855 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as double) FROM t +-- !query 855 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 855 output +true + + +-- !query 856 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as double) FROM t +-- !query 856 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <= CAST(1 AS DOUBLE)):boolean> +-- !query 856 output +true + + +-- !query 857 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as decimal(10, 0)) FROM t +-- !query 857 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 857 output +true + + +-- !query 858 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as decimal(10, 0)) FROM t +-- !query 858 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 858 output +true + + +-- !query 859 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as decimal(10, 0)) FROM t +-- !query 859 schema +struct<(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 859 output +true + + +-- !query 860 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as 
decimal(10, 0)) FROM t +-- !query 860 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) <= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 860 output +true + + +-- !query 861 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as string) FROM t +-- !query 861 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 861 output +true + + +-- !query 862 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as string) FROM t +-- !query 862 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 862 output +true + + +-- !query 863 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as string) FROM t +-- !query 863 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 863 output +true + + +-- !query 864 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as string) FROM t +-- !query 864 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) <= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 864 output +true + + +-- !query 865 +SELECT cast(1 as decimal(3, 0)) <= cast('1' as binary) FROM t +-- !query 865 schema +struct<> +-- !query 865 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 866 +SELECT cast(1 as decimal(5, 0)) <= cast('1' as binary) FROM t +-- !query 866 schema +struct<> +-- !query 866 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 867 +SELECT cast(1 as decimal(10, 0)) <= cast('1' as binary) FROM t +-- !query 867 schema +struct<> +-- !query 867 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 868 +SELECT cast(1 as decimal(20, 0)) <= cast('1' as binary) FROM t +-- !query 868 schema +struct<> +-- !query 868 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 869 +SELECT cast(1 as decimal(3, 0)) <= cast(1 as boolean) FROM t +-- !query 869 schema +struct<> +-- !query 869 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 870 +SELECT cast(1 as decimal(5, 0)) <= cast(1 as boolean) FROM t +-- !query 870 schema +struct<> +-- !query 870 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 871 +SELECT cast(1 as decimal(10, 0)) <= cast(1 as boolean) FROM t +-- !query 871 schema +struct<> +-- !query 871 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 872 +SELECT cast(1 as decimal(20, 0)) <= cast(1 as boolean) FROM t +-- !query 872 schema +struct<> +-- !query 872 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 873 +SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 873 schema +struct<> +-- !query 873 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 874 +SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 874 schema +struct<> +-- !query 874 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 875 +SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 875 schema +struct<> +-- !query 875 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 876 +SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 876 schema +struct<> +-- !query 876 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 877 +SELECT cast(1 as decimal(3, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 877 schema +struct<> +-- !query 877 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 878 +SELECT cast(1 as decimal(5, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 878 schema +struct<> +-- !query 878 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 879 +SELECT cast(1 as decimal(10, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 879 schema +struct<> +-- !query 879 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) <= 
CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 880 +SELECT cast(1 as decimal(20, 0)) <= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 880 schema +struct<> +-- !query 880 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) <= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 881 +SELECT cast(1 as tinyint) > cast(1 as decimal(3, 0)) FROM t +-- !query 881 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) > CAST(1 AS DECIMAL(3,0))):boolean> +-- !query 881 output +false + + +-- !query 882 +SELECT cast(1 as tinyint) > cast(1 as decimal(5, 0)) FROM t +-- !query 882 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 882 output +false + + +-- !query 883 +SELECT cast(1 as tinyint) > cast(1 as decimal(10, 0)) FROM t +-- !query 883 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 883 output +false + + +-- !query 884 +SELECT cast(1 as tinyint) > cast(1 as decimal(20, 0)) FROM t +-- !query 884 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 884 output +false + + +-- !query 885 +SELECT cast(1 as smallint) > cast(1 as decimal(3, 0)) FROM t +-- !query 885 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 885 output +false + + +-- !query 886 +SELECT cast(1 as smallint) > cast(1 as decimal(5, 0)) FROM t +-- !query 886 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) > CAST(1 AS DECIMAL(5,0))):boolean> +-- !query 886 output +false + + +-- !query 887 +SELECT cast(1 as smallint) > cast(1 as decimal(10, 0)) FROM t +-- !query 887 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 887 output +false + + +-- !query 888 +SELECT cast(1 as smallint) > cast(1 as decimal(20, 0)) FROM t +-- !query 888 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 888 output +false + + +-- !query 889 +SELECT cast(1 as int) > cast(1 as decimal(3, 0)) FROM t +-- !query 889 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 889 output +false + + +-- !query 890 +SELECT cast(1 as int) > cast(1 as decimal(5, 0)) FROM t +-- !query 890 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 890 output +false + + +-- !query 891 +SELECT cast(1 as int) > cast(1 as decimal(10, 0)) FROM t +-- !query 891 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) > CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 891 output +false + + +-- !query 892 +SELECT cast(1 as int) > cast(1 as decimal(20, 0)) FROM t +-- !query 892 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) > CAST(CAST(1 
AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 892 output +false + + +-- !query 893 +SELECT cast(1 as bigint) > cast(1 as decimal(3, 0)) FROM t +-- !query 893 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 893 output +false + + +-- !query 894 +SELECT cast(1 as bigint) > cast(1 as decimal(5, 0)) FROM t +-- !query 894 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 894 output +false + + +-- !query 895 +SELECT cast(1 as bigint) > cast(1 as decimal(10, 0)) FROM t +-- !query 895 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 895 output +false + + +-- !query 896 +SELECT cast(1 as bigint) > cast(1 as decimal(20, 0)) FROM t +-- !query 896 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) > CAST(1 AS DECIMAL(20,0))):boolean> +-- !query 896 output +false + + +-- !query 897 +SELECT cast(1 as float) > cast(1 as decimal(3, 0)) FROM t +-- !query 897 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 897 output +false + + +-- !query 898 +SELECT cast(1 as float) > cast(1 as decimal(5, 0)) FROM t +-- !query 898 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 898 output +false + + +-- !query 899 +SELECT cast(1 as float) > cast(1 as decimal(10, 0)) FROM t +-- !query 899 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 899 output +false + + +-- !query 900 +SELECT cast(1 as float) > cast(1 as decimal(20, 0)) FROM t +-- !query 900 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) > CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 900 output +false + + +-- !query 901 +SELECT cast(1 as double) > cast(1 as decimal(3, 0)) FROM t +-- !query 901 schema +struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 901 output +false + + +-- !query 902 +SELECT cast(1 as double) > cast(1 as decimal(5, 0)) FROM t +-- !query 902 schema +struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 902 output +false + + +-- !query 903 +SELECT cast(1 as double) > cast(1 as decimal(10, 0)) FROM t +-- !query 903 schema +struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 903 output +false + + +-- !query 904 +SELECT cast(1 as double) > cast(1 as decimal(20, 0)) FROM t +-- !query 904 schema +struct<(CAST(1 AS DOUBLE) > CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 904 output +false + + +-- !query 905 +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(3, 0)) FROM t +-- !query 905 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 905 output +false + + +-- !query 906 +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(5, 0)) FROM t +-- !query 906 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 906 output +false + + +-- !query 907 +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t +-- !query 907 schema +struct<(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 907 output +false + + +-- !query 908 
+SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(20, 0)) FROM t +-- !query 908 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 908 output +false + + +-- !query 909 +SELECT cast('1' as binary) > cast(1 as decimal(3, 0)) FROM t +-- !query 909 schema +struct<> +-- !query 909 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 910 +SELECT cast('1' as binary) > cast(1 as decimal(5, 0)) FROM t +-- !query 910 schema +struct<> +-- !query 910 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 911 +SELECT cast('1' as binary) > cast(1 as decimal(10, 0)) FROM t +-- !query 911 schema +struct<> +-- !query 911 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 912 +SELECT cast('1' as binary) > cast(1 as decimal(20, 0)) FROM t +-- !query 912 schema +struct<> +-- !query 912 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) > CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 913 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(3, 0)) FROM t +-- !query 913 schema +struct<> +-- !query 913 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 914 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(5, 0)) FROM t +-- !query 914 schema +struct<> +-- !query 914 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 915 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(10, 0)) FROM t +-- !query 915 schema +struct<> +-- !query 915 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 916 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) > cast(1 as decimal(20, 0)) FROM t +-- !query 916 schema +struct<> +-- !query 916 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) > CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; 
line 1 pos 7 + + +-- !query 917 +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(3, 0)) FROM t +-- !query 917 schema +struct<> +-- !query 917 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 918 +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(5, 0)) FROM t +-- !query 918 schema +struct<> +-- !query 918 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 919 +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(10, 0)) FROM t +-- !query 919 schema +struct<> +-- !query 919 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 920 +SELECT cast('2017-12-11 09:30:00' as date) > cast(1 as decimal(20, 0)) FROM t +-- !query 920 schema +struct<> +-- !query 920 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) > CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 921 +SELECT cast(1 as decimal(3, 0)) > cast(1 as tinyint) FROM t +-- !query 921 schema +struct<(CAST(1 AS DECIMAL(3,0)) > CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean> +-- !query 921 output +false + + +-- !query 922 +SELECT cast(1 as decimal(5, 0)) > cast(1 as tinyint) FROM t +-- !query 922 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) > CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 922 output +false + + +-- !query 923 +SELECT cast(1 as decimal(10, 0)) > cast(1 as tinyint) FROM t +-- !query 923 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 923 output +false + + +-- !query 924 +SELECT cast(1 as decimal(20, 0)) > cast(1 as tinyint) FROM t +-- !query 924 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 924 output +false + + +-- !query 925 +SELECT cast(1 as decimal(3, 0)) > cast(1 as smallint) FROM t +-- !query 925 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) > CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 925 output +false + + +-- !query 926 +SELECT cast(1 as decimal(5, 0)) > cast(1 as smallint) FROM t +-- !query 926 schema +struct<(CAST(1 AS DECIMAL(5,0)) > CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean> +-- !query 926 output +false + + +-- !query 927 +SELECT cast(1 as decimal(10, 0)) > cast(1 as smallint) FROM t +-- !query 927 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 927 output +false + + +-- !query 928 +SELECT cast(1 as decimal(20, 0)) > cast(1 as smallint) FROM 
t +-- !query 928 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 928 output +false + + +-- !query 929 +SELECT cast(1 as decimal(3, 0)) > cast(1 as int) FROM t +-- !query 929 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 929 output +false + + +-- !query 930 +SELECT cast(1 as decimal(5, 0)) > cast(1 as int) FROM t +-- !query 930 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) > CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 930 output +false + + +-- !query 931 +SELECT cast(1 as decimal(10, 0)) > cast(1 as int) FROM t +-- !query 931 schema +struct<(CAST(1 AS DECIMAL(10,0)) > CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean> +-- !query 931 output +false + + +-- !query 932 +SELECT cast(1 as decimal(20, 0)) > cast(1 as int) FROM t +-- !query 932 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 932 output +false + + +-- !query 933 +SELECT cast(1 as decimal(3, 0)) > cast(1 as bigint) FROM t +-- !query 933 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 933 output +false + + +-- !query 934 +SELECT cast(1 as decimal(5, 0)) > cast(1 as bigint) FROM t +-- !query 934 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 934 output +false + + +-- !query 935 +SELECT cast(1 as decimal(10, 0)) > cast(1 as bigint) FROM t +-- !query 935 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) > CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 935 output +false + + +-- !query 936 +SELECT cast(1 as decimal(20, 0)) > cast(1 as bigint) FROM t +-- !query 936 schema +struct<(CAST(1 AS DECIMAL(20,0)) > CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean> +-- !query 936 output +false + + +-- !query 937 +SELECT cast(1 as decimal(3, 0)) > cast(1 as float) FROM t +-- !query 937 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 937 output +false + + +-- !query 938 +SELECT cast(1 as decimal(5, 0)) > cast(1 as float) FROM t +-- !query 938 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 938 output +false + + +-- !query 939 +SELECT cast(1 as decimal(10, 0)) > cast(1 as float) FROM t +-- !query 939 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 939 output +false + + +-- !query 940 +SELECT cast(1 as decimal(20, 0)) > cast(1 as float) FROM t +-- !query 940 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) > CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 940 output +false + + +-- !query 941 +SELECT cast(1 as decimal(3, 0)) > cast(1 as double) FROM t +-- !query 941 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 941 output +false + + +-- !query 942 +SELECT cast(1 as decimal(5, 0)) > cast(1 as double) FROM t +-- !query 942 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 942 output +false + + +-- !query 943 +SELECT cast(1 as decimal(10, 0)) > cast(1 as 
double) FROM t +-- !query 943 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 943 output +false + + +-- !query 944 +SELECT cast(1 as decimal(20, 0)) > cast(1 as double) FROM t +-- !query 944 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) > CAST(1 AS DOUBLE)):boolean> +-- !query 944 output +false + + +-- !query 945 +SELECT cast(1 as decimal(3, 0)) > cast(1 as decimal(10, 0)) FROM t +-- !query 945 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 945 output +false + + +-- !query 946 +SELECT cast(1 as decimal(5, 0)) > cast(1 as decimal(10, 0)) FROM t +-- !query 946 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 946 output +false + + +-- !query 947 +SELECT cast(1 as decimal(10, 0)) > cast(1 as decimal(10, 0)) FROM t +-- !query 947 schema +struct<(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 947 output +false + + +-- !query 948 +SELECT cast(1 as decimal(20, 0)) > cast(1 as decimal(10, 0)) FROM t +-- !query 948 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) > CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 948 output +false + + +-- !query 949 +SELECT cast(1 as decimal(3, 0)) > cast(1 as string) FROM t +-- !query 949 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 949 output +false + + +-- !query 950 +SELECT cast(1 as decimal(5, 0)) > cast(1 as string) FROM t +-- !query 950 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 950 output +false + + +-- !query 951 +SELECT cast(1 as decimal(10, 0)) > cast(1 as string) FROM t +-- !query 951 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 951 output +false + + +-- !query 952 +SELECT cast(1 as decimal(20, 0)) > cast(1 as string) FROM t +-- !query 952 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) > CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 952 output +false + + +-- !query 953 +SELECT cast(1 as decimal(3, 0)) > cast('1' as binary) FROM t +-- !query 953 schema +struct<> +-- !query 953 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 954 +SELECT cast(1 as decimal(5, 0)) > cast('1' as binary) FROM t +-- !query 954 schema +struct<> +-- !query 954 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 955 +SELECT cast(1 as decimal(10, 0)) > cast('1' as binary) FROM t +-- !query 955 schema +struct<> +-- !query 955 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 956 +SELECT cast(1 as decimal(20, 0)) > cast('1' as binary) FROM t +-- !query 956 schema +struct<> +-- !query 956 output +org.apache.spark.sql.AnalysisException 
+cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 957 +SELECT cast(1 as decimal(3, 0)) > cast(1 as boolean) FROM t +-- !query 957 schema +struct<> +-- !query 957 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 958 +SELECT cast(1 as decimal(5, 0)) > cast(1 as boolean) FROM t +-- !query 958 schema +struct<> +-- !query 958 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 959 +SELECT cast(1 as decimal(10, 0)) > cast(1 as boolean) FROM t +-- !query 959 schema +struct<> +-- !query 959 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 960 +SELECT cast(1 as decimal(20, 0)) > cast(1 as boolean) FROM t +-- !query 960 schema +struct<> +-- !query 960 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 961 +SELECT cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 961 schema +struct<> +-- !query 961 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 962 +SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 962 schema +struct<> +-- !query 962 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 963 +SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 963 schema +struct<> +-- !query 963 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 964 +SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 964 schema +struct<> +-- !query 964 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 965 +SELECT 
cast(1 as decimal(3, 0)) > cast('2017-12-11 09:30:00' as date) FROM t +-- !query 965 schema +struct<> +-- !query 965 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 966 +SELECT cast(1 as decimal(5, 0)) > cast('2017-12-11 09:30:00' as date) FROM t +-- !query 966 schema +struct<> +-- !query 966 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 967 +SELECT cast(1 as decimal(10, 0)) > cast('2017-12-11 09:30:00' as date) FROM t +-- !query 967 schema +struct<> +-- !query 967 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 968 +SELECT cast(1 as decimal(20, 0)) > cast('2017-12-11 09:30:00' as date) FROM t +-- !query 968 schema +struct<> +-- !query 968 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) > CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 969 +SELECT cast(1 as tinyint) >= cast(1 as decimal(3, 0)) FROM t +-- !query 969 schema +struct<(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) >= CAST(1 AS DECIMAL(3,0))):boolean> +-- !query 969 output +true + + +-- !query 970 +SELECT cast(1 as tinyint) >= cast(1 as decimal(5, 0)) FROM t +-- !query 970 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 970 output +true + + +-- !query 971 +SELECT cast(1 as tinyint) >= cast(1 as decimal(10, 0)) FROM t +-- !query 971 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 971 output +true + + +-- !query 972 +SELECT cast(1 as tinyint) >= cast(1 as decimal(20, 0)) FROM t +-- !query 972 schema +struct<(CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 972 output +true + + +-- !query 973 +SELECT cast(1 as smallint) >= cast(1 as decimal(3, 0)) FROM t +-- !query 973 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 973 output +true + + +-- !query 974 +SELECT cast(1 as smallint) >= cast(1 as decimal(5, 0)) FROM t +-- !query 974 schema +struct<(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) >= CAST(1 AS DECIMAL(5,0))):boolean> +-- !query 974 output +true + + +-- !query 975 +SELECT cast(1 as smallint) >= cast(1 as decimal(10, 0)) FROM t +-- !query 975 schema +struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 975 output +true + + +-- !query 976 +SELECT cast(1 as smallint) >= cast(1 as decimal(20, 0)) FROM t +-- !query 976 schema 
+struct<(CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 976 output +true + + +-- !query 977 +SELECT cast(1 as int) >= cast(1 as decimal(3, 0)) FROM t +-- !query 977 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 977 output +true + + +-- !query 978 +SELECT cast(1 as int) >= cast(1 as decimal(5, 0)) FROM t +-- !query 978 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 978 output +true + + +-- !query 979 +SELECT cast(1 as int) >= cast(1 as decimal(10, 0)) FROM t +-- !query 979 schema +struct<(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) >= CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 979 output +true + + +-- !query 980 +SELECT cast(1 as int) >= cast(1 as decimal(20, 0)) FROM t +-- !query 980 schema +struct<(CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 980 output +true + + +-- !query 981 +SELECT cast(1 as bigint) >= cast(1 as decimal(3, 0)) FROM t +-- !query 981 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 981 output +true + + +-- !query 982 +SELECT cast(1 as bigint) >= cast(1 as decimal(5, 0)) FROM t +-- !query 982 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 982 output +true + + +-- !query 983 +SELECT cast(1 as bigint) >= cast(1 as decimal(10, 0)) FROM t +-- !query 983 schema +struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 983 output +true + + +-- !query 984 +SELECT cast(1 as bigint) >= cast(1 as decimal(20, 0)) FROM t +-- !query 984 schema +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) >= CAST(1 AS DECIMAL(20,0))):boolean> +-- !query 984 output +true + + +-- !query 985 +SELECT cast(1 as float) >= cast(1 as decimal(3, 0)) FROM t +-- !query 985 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 985 output +true + + +-- !query 986 +SELECT cast(1 as float) >= cast(1 as decimal(5, 0)) FROM t +-- !query 986 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 986 output +true + + +-- !query 987 +SELECT cast(1 as float) >= cast(1 as decimal(10, 0)) FROM t +-- !query 987 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 987 output +true + + +-- !query 988 +SELECT cast(1 as float) >= cast(1 as decimal(20, 0)) FROM t +-- !query 988 schema +struct<(CAST(CAST(1 AS FLOAT) AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 988 output +true + + +-- !query 989 +SELECT cast(1 as double) >= cast(1 as decimal(3, 0)) FROM t +-- !query 989 schema +struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE)):boolean> +-- !query 989 output +true + + +-- !query 990 +SELECT cast(1 as double) >= cast(1 as decimal(5, 0)) FROM t +-- !query 990 schema +struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE)):boolean> +-- !query 990 output +true + + +-- !query 991 +SELECT cast(1 as double) >= cast(1 as decimal(10, 0)) 
FROM t +-- !query 991 schema +struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE)):boolean> +-- !query 991 output +true + + +-- !query 992 +SELECT cast(1 as double) >= cast(1 as decimal(20, 0)) FROM t +-- !query 992 schema +struct<(CAST(1 AS DOUBLE) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE)):boolean> +-- !query 992 output +true + + +-- !query 993 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(3, 0)) FROM t +-- !query 993 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 993 output +true + + +-- !query 994 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(5, 0)) FROM t +-- !query 994 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 994 output +true + + +-- !query 995 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t +-- !query 995 schema +struct<(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 995 output +true + + +-- !query 996 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(20, 0)) FROM t +-- !query 996 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 996 output +true + + +-- !query 997 +SELECT cast('1' as binary) >= cast(1 as decimal(3, 0)) FROM t +-- !query 997 schema +struct<> +-- !query 997 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 998 +SELECT cast('1' as binary) >= cast(1 as decimal(5, 0)) FROM t +-- !query 998 schema +struct<> +-- !query 998 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 999 +SELECT cast('1' as binary) >= cast(1 as decimal(10, 0)) FROM t +-- !query 999 schema +struct<> +-- !query 999 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 1000 +SELECT cast('1' as binary) >= cast(1 as decimal(20, 0)) FROM t +-- !query 1000 schema +struct<> +-- !query 1000 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) >= CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 1001 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(3, 0)) FROM t +-- !query 1001 schema +struct<> +-- !query 1001 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 1002 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(5, 0)) FROM t +-- !query 1002 schema +struct<> +-- !query 1002 output +org.apache.spark.sql.AnalysisException +cannot resolve 
'(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 1003 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(10, 0)) FROM t +-- !query 1003 schema +struct<> +-- !query 1003 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 1004 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) >= cast(1 as decimal(20, 0)) FROM t +-- !query 1004 schema +struct<> +-- !query 1004 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) >= CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 1005 +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(3, 0)) FROM t +-- !query 1005 schema +struct<> +-- !query 1005 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 1006 +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(5, 0)) FROM t +-- !query 1006 schema +struct<> +-- !query 1006 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 1007 +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(10, 0)) FROM t +-- !query 1007 schema +struct<> +-- !query 1007 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 1008 +SELECT cast('2017-12-11 09:30:00' as date) >= cast(1 as decimal(20, 0)) FROM t +-- !query 1008 schema +struct<> +-- !query 1008 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) >= CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 1009 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as tinyint) FROM t +-- !query 1009 schema +struct<(CAST(1 AS DECIMAL(3,0)) >= CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0))):boolean> +-- !query 1009 output +true + + +-- !query 1010 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as tinyint) FROM t +-- !query 1010 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) >= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0))):boolean> +-- !query 1010 output +true + + +-- !query 1011 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as tinyint) FROM t +-- !query 1011 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 
AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0))):boolean> +-- !query 1011 output +true + + +-- !query 1012 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as tinyint) FROM t +-- !query 1012 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0))):boolean> +-- !query 1012 output +true + + +-- !query 1013 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as smallint) FROM t +-- !query 1013 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) >= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0))):boolean> +-- !query 1013 output +true + + +-- !query 1014 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as smallint) FROM t +-- !query 1014 schema +struct<(CAST(1 AS DECIMAL(5,0)) >= CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0))):boolean> +-- !query 1014 output +true + + +-- !query 1015 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as smallint) FROM t +-- !query 1015 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0))):boolean> +-- !query 1015 output +true + + +-- !query 1016 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as smallint) FROM t +-- !query 1016 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0))):boolean> +-- !query 1016 output +true + + +-- !query 1017 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as int) FROM t +-- !query 1017 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 1017 output +true + + +-- !query 1018 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as int) FROM t +-- !query 1018 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) >= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 1018 output +true + + +-- !query 1019 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as int) FROM t +-- !query 1019 schema +struct<(CAST(1 AS DECIMAL(10,0)) >= CAST(CAST(1 AS INT) AS DECIMAL(10,0))):boolean> +-- !query 1019 output +true + + +-- !query 1020 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as int) FROM t +-- !query 1020 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 1020 output +true + + +-- !query 1021 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as bigint) FROM t +-- !query 1021 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 1021 output +true + + +-- !query 1022 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as bigint) FROM t +-- !query 1022 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 1022 output +true + + +-- !query 1023 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as bigint) FROM t +-- !query 1023 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) >= CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0))):boolean> +-- !query 1023 output +true + + +-- !query 1024 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as bigint) FROM t +-- !query 1024 schema +struct<(CAST(1 AS DECIMAL(20,0)) >= CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):boolean> +-- !query 1024 output +true + + +-- !query 1025 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as float) FROM t +-- !query 1025 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS 
DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 1025 output +true + + +-- !query 1026 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as float) FROM t +-- !query 1026 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 1026 output +true + + +-- !query 1027 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as float) FROM t +-- !query 1027 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 1027 output +true + + +-- !query 1028 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as float) FROM t +-- !query 1028 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) >= CAST(CAST(1 AS FLOAT) AS DOUBLE)):boolean> +-- !query 1028 output +true + + +-- !query 1029 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as double) FROM t +-- !query 1029 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 1029 output +true + + +-- !query 1030 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as double) FROM t +-- !query 1030 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 1030 output +true + + +-- !query 1031 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as double) FROM t +-- !query 1031 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 1031 output +true + + +-- !query 1032 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as double) FROM t +-- !query 1032 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) >= CAST(1 AS DOUBLE)):boolean> +-- !query 1032 output +true + + +-- !query 1033 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as decimal(10, 0)) FROM t +-- !query 1033 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 1033 output +true + + +-- !query 1034 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as decimal(10, 0)) FROM t +-- !query 1034 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0))):boolean> +-- !query 1034 output +true + + +-- !query 1035 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as decimal(10, 0)) FROM t +-- !query 1035 schema +struct<(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS DECIMAL(10,0))):boolean> +-- !query 1035 output +true + + +-- !query 1036 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as decimal(10, 0)) FROM t +-- !query 1036 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) >= CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0))):boolean> +-- !query 1036 output +true + + +-- !query 1037 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as string) FROM t +-- !query 1037 schema +struct<(CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 1037 output +true + + +-- !query 1038 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as string) FROM t +-- !query 1038 schema +struct<(CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 1038 output +true + + +-- !query 1039 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as string) FROM t +-- !query 1039 schema +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 1039 output +true + + +-- !query 1040 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as string) FROM t +-- !query 1040 schema +struct<(CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) >= CAST(CAST(1 AS STRING) AS DOUBLE)):boolean> +-- !query 1040 output +true + + +-- !query 1041 
+SELECT cast(1 as decimal(3, 0)) >= cast('1' as binary) FROM t +-- !query 1041 schema +struct<> +-- !query 1041 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 1042 +SELECT cast(1 as decimal(5, 0)) >= cast('1' as binary) FROM t +-- !query 1042 schema +struct<> +-- !query 1042 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 1043 +SELECT cast(1 as decimal(10, 0)) >= cast('1' as binary) FROM t +-- !query 1043 schema +struct<> +-- !query 1043 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 1044 +SELECT cast(1 as decimal(20, 0)) >= cast('1' as binary) FROM t +-- !query 1044 schema +struct<> +-- !query 1044 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 1045 +SELECT cast(1 as decimal(3, 0)) >= cast(1 as boolean) FROM t +-- !query 1045 schema +struct<> +-- !query 1045 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST(1 AS BOOLEAN))' (decimal(3,0) and boolean).; line 1 pos 7 + + +-- !query 1046 +SELECT cast(1 as decimal(5, 0)) >= cast(1 as boolean) FROM t +-- !query 1046 schema +struct<> +-- !query 1046 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST(1 AS BOOLEAN))' (decimal(5,0) and boolean).; line 1 pos 7 + + +-- !query 1047 +SELECT cast(1 as decimal(10, 0)) >= cast(1 as boolean) FROM t +-- !query 1047 schema +struct<> +-- !query 1047 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST(1 AS BOOLEAN))' (decimal(10,0) and boolean).; line 1 pos 7 + + +-- !query 1048 +SELECT cast(1 as decimal(20, 0)) >= cast(1 as boolean) FROM t +-- !query 1048 schema +struct<> +-- !query 1048 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST(1 AS BOOLEAN))' (decimal(20,0) and boolean).; line 1 pos 7 + + +-- !query 1049 +SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1049 schema +struct<> +-- !query 1049 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 1050 +SELECT cast(1 as 
decimal(5, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1050 schema +struct<> +-- !query 1050 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 1051 +SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1051 schema +struct<> +-- !query 1051 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 1052 +SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1052 schema +struct<> +-- !query 1052 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 1053 +SELECT cast(1 as decimal(3, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1053 schema +struct<> +-- !query 1053 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 1054 +SELECT cast(1 as decimal(5, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1054 schema +struct<> +-- !query 1054 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 1055 +SELECT cast(1 as decimal(10, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1055 schema +struct<> +-- !query 1055 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 1056 +SELECT cast(1 as decimal(20, 0)) >= cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1056 schema +struct<> +-- !query 1056 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) >= CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 + + +-- !query 1057 +SELECT cast(1 as tinyint) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1057 schema +struct<(NOT (CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) = CAST(1 AS DECIMAL(3,0)))):boolean> +-- !query 1057 output +false + + +-- !query 1058 +SELECT cast(1 as tinyint) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1058 schema +struct<(NOT (CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)))):boolean> +-- !query 1058 
output +false + + +-- !query 1059 +SELECT cast(1 as tinyint) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1059 schema +struct<(NOT (CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1059 output +false + + +-- !query 1060 +SELECT cast(1 as tinyint) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1060 schema +struct<(NOT (CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1060 output +false + + +-- !query 1061 +SELECT cast(1 as smallint) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1061 schema +struct<(NOT (CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)))):boolean> +-- !query 1061 output +false + + +-- !query 1062 +SELECT cast(1 as smallint) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1062 schema +struct<(NOT (CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) = CAST(1 AS DECIMAL(5,0)))):boolean> +-- !query 1062 output +false + + +-- !query 1063 +SELECT cast(1 as smallint) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1063 schema +struct<(NOT (CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1063 output +false + + +-- !query 1064 +SELECT cast(1 as smallint) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1064 schema +struct<(NOT (CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1064 output +false + + +-- !query 1065 +SELECT cast(1 as int) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1065 schema +struct<(NOT (CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1065 output +false + + +-- !query 1066 +SELECT cast(1 as int) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1066 schema +struct<(NOT (CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1066 output +false + + +-- !query 1067 +SELECT cast(1 as int) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1067 schema +struct<(NOT (CAST(CAST(1 AS INT) AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0)))):boolean> +-- !query 1067 output +false + + +-- !query 1068 +SELECT cast(1 as int) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1068 schema +struct<(NOT (CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1068 output +false + + +-- !query 1069 +SELECT cast(1 as bigint) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1069 schema +struct<(NOT (CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1069 output +false + + +-- !query 1070 +SELECT cast(1 as bigint) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1070 schema +struct<(NOT (CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1070 output +false + + +-- !query 1071 +SELECT cast(1 as bigint) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1071 schema +struct<(NOT (CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1071 output +false + + +-- !query 1072 +SELECT cast(1 as bigint) <> cast(1 as decimal(20, 0)) FROM t +-- !query 
1072 schema +struct<(NOT (CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) = CAST(1 AS DECIMAL(20,0)))):boolean> +-- !query 1072 output +false + + +-- !query 1073 +SELECT cast(1 as float) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1073 schema +struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE))):boolean> +-- !query 1073 output +false + + +-- !query 1074 +SELECT cast(1 as float) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1074 schema +struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE))):boolean> +-- !query 1074 output +false + + +-- !query 1075 +SELECT cast(1 as float) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1075 schema +struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 1075 output +false + + +-- !query 1076 +SELECT cast(1 as float) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1076 schema +struct<(NOT (CAST(CAST(1 AS FLOAT) AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE))):boolean> +-- !query 1076 output +false + + +-- !query 1077 +SELECT cast(1 as double) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1077 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE))):boolean> +-- !query 1077 output +false + + +-- !query 1078 +SELECT cast(1 as double) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1078 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE))):boolean> +-- !query 1078 output +false + + +-- !query 1079 +SELECT cast(1 as double) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1079 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> +-- !query 1079 output +false + + +-- !query 1080 +SELECT cast(1 as double) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1080 schema +struct<(NOT (CAST(1 AS DOUBLE) = CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE))):boolean> +-- !query 1080 output +false + + +-- !query 1081 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1081 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1081 output +false + + +-- !query 1082 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1082 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1082 output +false + + +-- !query 1083 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1083 schema +struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0)))):boolean> +-- !query 1083 output +false + + +-- !query 1084 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1084 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1084 output +false + + +-- !query 1085 +SELECT cast('1' as binary) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1085 schema +struct<> +-- !query 1085 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(3,0)))' (binary and decimal(3,0)).; line 1 pos 7 + + +-- !query 1086 +SELECT cast('1' as binary) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1086 schema +struct<> +-- !query 1086 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' 
AS BINARY) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(5,0)))' (binary and decimal(5,0)).; line 1 pos 7 + + +-- !query 1087 +SELECT cast('1' as binary) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1087 schema +struct<> +-- !query 1087 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(10,0)))' (binary and decimal(10,0)).; line 1 pos 7 + + +-- !query 1088 +SELECT cast('1' as binary) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1088 schema +struct<> +-- !query 1088 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('1' AS BINARY) = CAST(1 AS DECIMAL(20,0)))' (binary and decimal(20,0)).; line 1 pos 7 + + +-- !query 1089 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1089 schema +struct<> +-- !query 1089 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(3,0)))' (timestamp and decimal(3,0)).; line 1 pos 7 + + +-- !query 1090 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1090 schema +struct<> +-- !query 1090 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(5,0)))' (timestamp and decimal(5,0)).; line 1 pos 7 + + +-- !query 1091 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1091 schema +struct<> +-- !query 1091 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(10,0)))' (timestamp and decimal(10,0)).; line 1 pos 7 + + +-- !query 1092 +SELECT cast('2017-12-11 09:30:00.0' as timestamp) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1092 schema +struct<> +-- !query 1092 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) = CAST(1 AS DECIMAL(20,0)))' (timestamp and decimal(20,0)).; line 1 pos 7 + + +-- !query 1093 +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(3, 0)) FROM t +-- !query 1093 schema +struct<> +-- !query 1093 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(3,0)))' (date and decimal(3,0)).; line 1 pos 7 + + +-- !query 1094 +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(5, 0)) FROM t +-- !query 1094 schema +struct<> +-- !query 1094 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(5,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS 
DATE) = CAST(1 AS DECIMAL(5,0)))' (date and decimal(5,0)).; line 1 pos 7 + + +-- !query 1095 +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1095 schema +struct<> +-- !query 1095 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(10,0)))' (date and decimal(10,0)).; line 1 pos 7 + + +-- !query 1096 +SELECT cast('2017-12-11 09:30:00' as date) <> cast(1 as decimal(20, 0)) FROM t +-- !query 1096 schema +struct<> +-- !query 1096 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' due to data type mismatch: differing types in '(CAST('2017-12-11 09:30:00' AS DATE) = CAST(1 AS DECIMAL(20,0)))' (date and decimal(20,0)).; line 1 pos 7 + + +-- !query 1097 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as tinyint) FROM t +-- !query 1097 schema +struct<(NOT (CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)))):boolean> +-- !query 1097 output +false + + +-- !query 1098 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as tinyint) FROM t +-- !query 1098 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(5,0)))):boolean> +-- !query 1098 output +false + + +-- !query 1099 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as tinyint) FROM t +-- !query 1099 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1099 output +false + + +-- !query 1100 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as tinyint) FROM t +-- !query 1100 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS TINYINT) AS DECIMAL(3,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1100 output +false + + +-- !query 1101 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as smallint) FROM t +-- !query 1101 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(5,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(5,0)))):boolean> +-- !query 1101 output +false + + +-- !query 1102 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as smallint) FROM t +-- !query 1102 schema +struct<(NOT (CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)))):boolean> +-- !query 1102 output +false + + +-- !query 1103 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as smallint) FROM t +-- !query 1103 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1103 output +false + + +-- !query 1104 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as smallint) FROM t +-- !query 1104 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS SMALLINT) AS DECIMAL(5,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1104 output +false + + +-- !query 1105 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as int) FROM t +-- !query 1105 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1105 output +false + + +-- !query 1106 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as int) FROM t +-- !query 1106 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 
1106 output +false + + +-- !query 1107 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as int) FROM t +-- !query 1107 schema +struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS INT) AS DECIMAL(10,0)))):boolean> +-- !query 1107 output +false + + +-- !query 1108 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as int) FROM t +-- !query 1108 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS INT) AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1108 output +false + + +-- !query 1109 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as bigint) FROM t +-- !query 1109 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1109 output +false + + +-- !query 1110 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as bigint) FROM t +-- !query 1110 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1110 output +false + + +-- !query 1111 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as bigint) FROM t +-- !query 1111 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) = CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1111 output +false + + +-- !query 1112 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as bigint) FROM t +-- !query 1112 schema +struct<(NOT (CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)))):boolean> +-- !query 1112 output +false + + +-- !query 1113 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as float) FROM t +-- !query 1113 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 1113 output +false + + +-- !query 1114 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as float) FROM t +-- !query 1114 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 1114 output +false + + +-- !query 1115 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as float) FROM t +-- !query 1115 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 1115 output +false + + +-- !query 1116 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as float) FROM t +-- !query 1116 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS FLOAT) AS DOUBLE))):boolean> +-- !query 1116 output +false + + +-- !query 1117 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as double) FROM t +-- !query 1117 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 1117 output +false + + +-- !query 1118 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as double) FROM t +-- !query 1118 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 1118 output +false + + +-- !query 1119 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as double) FROM t +-- !query 1119 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 1119 output +false + + +-- !query 1120 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as double) FROM t +-- !query 1120 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(1 AS DOUBLE))):boolean> +-- !query 1120 output +false + + +-- !query 1121 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1121 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DECIMAL(10,0)) = 
CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1121 output +false + + +-- !query 1122 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1122 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DECIMAL(10,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(10,0)))):boolean> +-- !query 1122 output +false + + +-- !query 1123 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1123 schema +struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(1 AS DECIMAL(10,0)))):boolean> +-- !query 1123 output +false + + +-- !query 1124 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as decimal(10, 0)) FROM t +-- !query 1124 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DECIMAL(20,0)) = CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)))):boolean> +-- !query 1124 output +false + + +-- !query 1125 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as string) FROM t +-- !query 1125 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(3,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> +-- !query 1125 output +false + + +-- !query 1126 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as string) FROM t +-- !query 1126 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(5,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> +-- !query 1126 output +false + + +-- !query 1127 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as string) FROM t +-- !query 1127 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> +-- !query 1127 output +false + + +-- !query 1128 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as string) FROM t +-- !query 1128 schema +struct<(NOT (CAST(CAST(1 AS DECIMAL(20,0)) AS DOUBLE) = CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> +-- !query 1128 output +false + + +-- !query 1129 +SELECT cast(1 as decimal(3, 0)) <> cast('1' as binary) FROM t +-- !query 1129 schema +struct<> +-- !query 1129 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('1' AS BINARY))' (decimal(3,0) and binary).; line 1 pos 7 + + +-- !query 1130 +SELECT cast(1 as decimal(5, 0)) <> cast('1' as binary) FROM t +-- !query 1130 schema +struct<> +-- !query 1130 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('1' AS BINARY))' (decimal(5,0) and binary).; line 1 pos 7 + + +-- !query 1131 +SELECT cast(1 as decimal(10, 0)) <> cast('1' as binary) FROM t +-- !query 1131 schema +struct<> +-- !query 1131 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('1' AS BINARY))' (decimal(10,0) and binary).; line 1 pos 7 + + +-- !query 1132 +SELECT cast(1 as decimal(20, 0)) <> cast('1' as binary) FROM t +-- !query 1132 schema +struct<> +-- !query 1132 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('1' AS BINARY))' (decimal(20,0) and binary).; line 1 pos 7 + + +-- !query 1133 +SELECT cast(1 as decimal(3, 0)) <> cast(1 as boolean) FROM t +-- !query 1133 schema +struct<(NOT (CAST(1 AS DECIMAL(3,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(3,0)))):boolean> +-- !query 1133 output +false + + +-- 
!query 1134 +SELECT cast(1 as decimal(5, 0)) <> cast(1 as boolean) FROM t +-- !query 1134 schema +struct<(NOT (CAST(1 AS DECIMAL(5,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(5,0)))):boolean> +-- !query 1134 output +false + + +-- !query 1135 +SELECT cast(1 as decimal(10, 0)) <> cast(1 as boolean) FROM t +-- !query 1135 schema +struct<(NOT (CAST(1 AS DECIMAL(10,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(10,0)))):boolean> +-- !query 1135 output +false + + +-- !query 1136 +SELECT cast(1 as decimal(20, 0)) <> cast(1 as boolean) FROM t +-- !query 1136 schema +struct<(NOT (CAST(1 AS DECIMAL(20,0)) = CAST(CAST(1 AS BOOLEAN) AS DECIMAL(20,0)))):boolean> +-- !query 1136 output +false + + +-- !query 1137 +SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1137 schema +struct<> +-- !query 1137 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(3,0) and timestamp).; line 1 pos 7 + + +-- !query 1138 +SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1138 schema +struct<> +-- !query 1138 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(5,0) and timestamp).; line 1 pos 7 + + +-- !query 1139 +SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1139 schema +struct<> +-- !query 1139 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(10,0) and timestamp).; line 1 pos 7 + + +-- !query 1140 +SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00.0' as timestamp) FROM t +-- !query 1140 schema +struct<> +-- !query 1140 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00.0' AS TIMESTAMP))' (decimal(20,0) and timestamp).; line 1 pos 7 + + +-- !query 1141 +SELECT cast(1 as decimal(3, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1141 schema +struct<> +-- !query 1141 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(3,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(3,0) and date).; line 1 pos 7 + + +-- !query 1142 +SELECT cast(1 as decimal(5, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1142 schema +struct<> +-- !query 1142 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(5,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(5,0) and date).; line 1 pos 7 + + +-- !query 1143 +SELECT cast(1 as decimal(10, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1143 schema +struct<> +-- !query 1143 output +org.apache.spark.sql.AnalysisException 
+cannot resolve '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(10,0) and date).; line 1 pos 7 + + +-- !query 1144 +SELECT cast(1 as decimal(20, 0)) <> cast('2017-12-11 09:30:00' as date) FROM t +-- !query 1144 schema +struct<> +-- !query 1144 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(20,0)) = CAST('2017-12-11 09:30:00' AS DATE))' (decimal(20,0) and date).; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out new file mode 100644 index 0000000000000..5dd257ba6a0bb --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out @@ -0,0 +1,206 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 25 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as tinyint)) FROM t +-- !query 1 schema +struct +-- !query 1 output +1 + + +-- !query 2 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as smallint)) FROM t +-- !query 2 schema +struct +-- !query 2 output +1 + + +-- !query 3 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as int)) FROM t +-- !query 3 schema +struct +-- !query 3 output +1 + + +-- !query 4 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as bigint)) FROM t +-- !query 4 schema +struct +-- !query 4 output +1 + + +-- !query 5 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as float)) FROM t +-- !query 5 schema +struct +-- !query 5 output +1 + + +-- !query 6 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as double)) FROM t +-- !query 6 schema +struct +-- !query 6 output +1 + + +-- !query 7 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as decimal(10, 0))) FROM t +-- !query 7 schema +struct +-- !query 7 output +1 + + +-- !query 8 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string)) FROM t +-- !query 8 schema +struct +-- !query 8 output +1 + + +-- !query 9 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary)) FROM t +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean)) FROM t +-- !query 10 schema +struct +-- !query 10 output +1 + + +-- !query 11 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as timestamp)) FROM t +-- !query 11 schema +struct +-- !query 11 output +1 + + +-- !query 12 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00' as date)) FROM t +-- !query 12 schema +struct +-- !query 12 output +1 + + +-- !query 13 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as tinyint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 13 schema +struct +-- !query 13 output +1 + + +-- !query 14 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as smallint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 14 schema +struct +-- !query 14 output +1 + + +-- !query 15 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as int) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 15 schema +struct +-- !query 
15 output +1 + + +-- !query 16 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as bigint) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 16 schema +struct +-- !query 16 output +1 + + +-- !query 17 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as float) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 17 schema +struct +-- !query 17 output +1 + + +-- !query 18 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as double) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 18 schema +struct +-- !query 18 output +1 + + +-- !query 19 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as decimal(10, 0)) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 19 schema +struct +-- !query 19 output +1 + + +-- !query 20 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'StringType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 + + +-- !query 21 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY 1 ORDER BY CAST('1' AS BINARY) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'BinaryType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 + + +-- !query 22 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'BooleanType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 + + +-- !query 23 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as timestamp) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY 1 ORDER BY CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 + + +-- !query 24 +SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00' as date) DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM t +-- !query 24 schema +struct +-- !query 24 output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 73ad27e5bf8ce..a52e198eb9a8f 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -61,7 +61,7 @@ ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'ROWS BETWEEN CURRENT 
ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType does not match the expected data type 'IntegerType'.; line 1 pos 41 +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType' does not match the expected data type 'int'.; line 1 pos 41 -- !query 4 From fe65361b0579777c360dee1d7f633f28df0c6aeb Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Thu, 21 Dec 2017 09:22:08 -0800 Subject: [PATCH 478/712] [SPARK-22042][FOLLOW-UP][SQL] ReorderJoinPredicates can break when child's partitioning is not decided ## What changes were proposed in this pull request? This is a followup PR of https://github.com/apache/spark/pull/19257 where gatorsmile had left couple comments wrt code style. ## How was this patch tested? Doesn't change any functionality. Will depend on build to see if no checkstyle rules are violated. Author: Tejas Patil Closes #20041 from tejasapatil/followup_19257. --- .../exchange/EnsureRequirements.scala | 82 ++++++++++--------- .../spark/sql/sources/BucketedReadSuite.scala | 4 +- 2 files changed, 44 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 82f0b9f5cd060..c8e236be28b42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -252,54 +252,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { operator.withNewChildren(children) } - /** - * When the physical operators are created for JOIN, the ordering of join keys is based on order - * in which the join keys appear in the user query. That might not match with the output - * partitioning of the join node's children (thus leading to extra sort / shuffle being - * introduced). This rule will change the ordering of the join keys to match with the - * partitioning of the join nodes' children. 
- */ - def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - def reorderJoinKeys( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - leftPartitioning: Partitioning, - rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { - - def reorder(expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - val leftKeysBuffer = ArrayBuffer[Expression]() - val rightKeysBuffer = ArrayBuffer[Expression]() + private def reorder( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + expectedOrderOfKeys: Seq[Expression], + currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val leftKeysBuffer = ArrayBuffer[Expression]() + val rightKeysBuffer = ArrayBuffer[Expression]() - expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) - leftKeysBuffer.append(leftKeys(index)) - rightKeysBuffer.append(rightKeys(index)) - }) - (leftKeysBuffer, rightKeysBuffer) - } + expectedOrderOfKeys.foreach(expression => { + val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + leftKeysBuffer.append(leftKeys(index)) + rightKeysBuffer.append(rightKeys(index)) + }) + (leftKeysBuffer, rightKeysBuffer) + } - if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - leftPartitioning match { - case HashPartitioning(leftExpressions, _) - if leftExpressions.length == leftKeys.length && - leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => - reorder(leftExpressions, leftKeys) + private def reorderJoinKeys( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { + leftPartitioning match { + case HashPartitioning(leftExpressions, _) + if leftExpressions.length == leftKeys.length && + leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => + reorder(leftKeys, rightKeys, leftExpressions, leftKeys) - case _ => rightPartitioning match { - case HashPartitioning(rightExpressions, _) - if rightExpressions.length == rightKeys.length && - rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => - reorder(rightExpressions, rightKeys) + case _ => rightPartitioning match { + case HashPartitioning(rightExpressions, _) + if rightExpressions.length == rightKeys.length && + rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => + reorder(leftKeys, rightKeys, rightExpressions, rightKeys) - case _ => (leftKeys, rightKeys) - } + case _ => (leftKeys, rightKeys) } - } else { - (leftKeys, rightKeys) } + } else { + (leftKeys, rightKeys) } + } + /** + * When the physical operators are created for JOIN, the ordering of join keys is based on order + * in which the join keys appear in the user query. That might not match with the output + * partitioning of the join node's children (thus leading to extra sort / shuffle being + * introduced). This rule will change the ordering of the join keys to match with the + * partitioning of the join nodes' children. 
+ */ + private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { plan.transformUp { case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9025859e91066..fb61fa716b946 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -620,7 +620,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { |) ab |JOIN table2 c |ON ab.i = c.i - |""".stripMargin), + """.stripMargin), sql(""" |SELECT a.i, a.j, a.k, c.i, c.j, c.k |FROM bucketed_table a @@ -628,7 +628,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { |ON a.i = b.i |JOIN table2 c |ON a.i = c.i - |""".stripMargin)) + """.stripMargin)) } } } From 7beb375bf4e8400f830a7fc7ff414634dd6efc78 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 21 Dec 2017 15:37:55 -0800 Subject: [PATCH 479/712] [SPARK-22861][SQL] SQLAppStatusListener handles multi-job executions. When one execution has multiple jobs, we need to append to the set of stages, not replace them on every job. Added unit test and ran existing tests on jenkins Author: Imran Rashid Closes #20047 from squito/SPARK-22861. --- .../execution/ui/SQLAppStatusListener.scala | 2 +- .../ui/SQLAppStatusListenerSuite.scala | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index aa78fa015dbef..2295b8dd5fe36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -87,7 +87,7 @@ class SQLAppStatusListener( } exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) - exec.stages = event.stageIds.toSet + exec.stages ++= event.stageIds.toSet update(exec) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 5ebbeb4a7cb40..7d84f45d36bee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -383,6 +383,49 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with assertJobs(statusStore.execution(executionId), failed = Seq(0)) } + test("handle one execution with multiple jobs") { + val statusStore = createStatusStore() + val listener = statusStore.listener.get + + val executionId = 0 + val df = createTestDataFrame + listener.onOtherEvent(SparkListenerSQLExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) + + var stageId = 0 + def twoStageJob(jobId: Int): Unit = { + val stages = Seq(stageId, stageId + 1).map { id => createStageInfo(id, 0)} + stageId += 2 + listener.onJobStart(SparkListenerJobStart( + jobId = jobId, + time = System.currentTimeMillis(), + stageInfos = stages, + createProperties(executionId))) + stages.foreach { s => + 
listener.onStageSubmitted(SparkListenerStageSubmitted(s)) + listener.onStageCompleted(SparkListenerStageCompleted(s)) + } + listener.onJobEnd(SparkListenerJobEnd( + jobId = jobId, + time = System.currentTimeMillis(), + JobSucceeded + )) + } + // submit two jobs with the same executionId + twoStageJob(0) + twoStageJob(1) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + + assertJobs(statusStore.execution(0), completed = 0 to 1) + assert(statusStore.execution(0).get.stages === (0 to 3).toSet) + } + test("SPARK-11126: no memory leak when running non SQL jobs") { val listener = spark.sharedState.statusStore.listener.get // At the beginning of this test case, there should be no live data in the listener. From 7ab165b7061d9acc26523227076056e94354d204 Mon Sep 17 00:00:00 2001 From: foxish Date: Thu, 21 Dec 2017 17:21:11 -0800 Subject: [PATCH 480/712] [SPARK-22648][K8S] Spark on Kubernetes - Documentation What changes were proposed in this pull request? This PR contains documentation on the usage of Kubernetes scheduler in Spark 2.3, and a shell script to make it easier to build docker images required to use the integration. The changes detailed here are covered by https://github.com/apache/spark/pull/19717 and https://github.com/apache/spark/pull/19468 which have merged already. How was this patch tested? The script has been in use for releases on our fork. Rest is documentation. cc rxin mateiz (shepherd) k8s-big-data SIG members & contributors: foxish ash211 mccheah liyinan926 erikerlandson ssuchter varunkatta kimoonkim tnachen ifilonenko reviewers: vanzin felixcheung jiangxb1987 mridulm TODO: - [x] Add dockerfiles directory to built distribution. (https://github.com/apache/spark/pull/20007) - [x] Change references to docker to instead say "container" (https://github.com/apache/spark/pull/19995) - [x] Update configuration table. - [x] Modify spark.kubernetes.allocation.batch.delay to take time instead of int (#20032) Author: foxish Closes #19946 from foxish/update-k8s-docs. --- docs/_layouts/global.html | 1 + docs/building-spark.md | 6 +- docs/cluster-overview.md | 7 +- docs/configuration.md | 2 + docs/img/k8s-cluster-mode.png | Bin 0 -> 55538 bytes docs/index.md | 3 +- docs/running-on-kubernetes.md | 578 +++++++++++++++++++++++++++++++ docs/running-on-yarn.md | 4 +- docs/submitting-applications.md | 16 + sbin/build-push-docker-images.sh | 68 ++++ 10 files changed, 677 insertions(+), 8 deletions(-) create mode 100644 docs/img/k8s-cluster-mode.png create mode 100644 docs/running-on-kubernetes.md create mode 100755 sbin/build-push-docker-images.sh diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 67b05ecf7a858..e5af5ae4561c7 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -99,6 +99,7 @@
  • Spark Standalone
  • Mesos
  • YARN
+  • Kubernetes
  • diff --git a/docs/building-spark.md b/docs/building-spark.md index 98f7df155456f..c391255a91596 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -49,7 +49,7 @@ To create a Spark distribution like those distributed by the to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: - ./dev/make-distribution.sh --name custom-spark --pip --r --tgz -Psparkr -Phadoop-2.7 -Phive -Phive-thriftserver -Pmesos -Pyarn + ./dev/make-distribution.sh --name custom-spark --pip --r --tgz -Psparkr -Phadoop-2.7 -Phive -Phive-thriftserver -Pmesos -Pyarn -Pkubernetes This will build Spark distribution along with Python pip and R packages. For more information on usage, run `./dev/make-distribution.sh --help` @@ -90,6 +90,10 @@ like ZooKeeper and Hadoop itself. ## Building with Mesos support ./build/mvn -Pmesos -DskipTests clean package + +## Building with Kubernetes support + + ./build/mvn -Pkubernetes -DskipTests clean package ## Building with Kafka 0.8 support diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index c42bb4bb8377e..658e67f99dd71 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,11 +52,8 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -* [Kubernetes (experimental)](https://github.com/apache-spark-on-k8s/spark) -- In addition to the above, -there is experimental support for Kubernetes. Kubernetes is an open-source platform -for providing container-centric infrastructure. Kubernetes support is being actively -developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. -For documentation, refer to that project's README. +* [Kubernetes](running-on-kubernetes.html) -- [Kubernetes](https://kubernetes.io/docs/concepts/overview/what-is-kubernetes/) +is an open-source platform that provides container-centric infrastructure. A third-party project (not supported by the Spark project) exists to add support for [Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. 
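For context on the pages added above: running-on-kubernetes.md and submitting-applications.md describe submitting applications through `spark-submit` with a `k8s://` master URL. The following is only a minimal sketch of such a submission, assuming a 2.3-era distribution built with the `kubernetes` profile and a pre-built container image; the API server address, image names, jar path, and the exact `spark.kubernetes.*` image property names are placeholders (the property names changed around this release) and should be checked against running-on-kubernetes.md rather than taken from this patch.

    # Sketch only; <k8s-apiserver-host>, <k8s-apiserver-port>, <driver-image>,
    # <executor-image>, and the jar path are placeholders, not values from this patch.
    bin/spark-submit \
      --master k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port> \
      --deploy-mode cluster \
      --name spark-pi \
      --class org.apache.spark.examples.SparkPi \
      --conf spark.executor.instances=2 \
      --conf spark.kubernetes.driver.container.image=<driver-image> \
      --conf spark.kubernetes.executor.container.image=<executor-image> \
      local:///opt/spark/examples/jars/spark-examples.jar

Here the `local://` scheme indicates a jar that is already present inside the container image rather than one uploaded from the submitting machine, which matches how the new documentation describes dependency handling for cluster mode on Kubernetes.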
diff --git a/docs/configuration.md b/docs/configuration.md index d70bac134808f..1189aea2aa71f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2376,6 +2376,8 @@ can be found on the pages for each mode: #### [Mesos](running-on-mesos.html#configuration) +#### [Kubernetes](running-on-kubernetes.html#configuration) + #### [Standalone Mode](spark-standalone.html#cluster-launch-scripts) # Environment Variables diff --git a/docs/img/k8s-cluster-mode.png b/docs/img/k8s-cluster-mode.png new file mode 100644 index 0000000000000000000000000000000000000000..12a6288ed58235b497d838350d17b482cff3b609 GIT binary patch literal 55538 zcmbrlbyQSg+c!EWQVI%4Ny$iecPI>9BHi8HIfN)3(hbrvbTch|S|dEW28 zbKbMoS*+Ou!=AnGxPEb6_q~HZ%Za0-y+Z?mKdI zAT167J^lUp*^(ave1d8xq3#F*y~O_eLI5SF;sYNdJ4wolBCo#0K;i+b7sX0~K<`14 zkWWf(^ZQGtUYZ+oNXMF*)8$5emCftcR4zn(_Ng2Dh$wF)@cqg)sbOJXG&FuyjR*!nbfm>HZvypnG(ZPc^_a}w*mY@TzCtbl_Q+{4v@hC=t zErQ~Y?>_>K0Aar+%>VoT!})(dsv`sE=so-M=Pd}N^@_7pJ2VDlI9drS!wv4rzI|IT zg8%~g2EY1u+W-6YZUat$`{}9uVNXsqp9nI0x;&6;f}&7BXyRUoib9lBv=+^oS7L`X zYh|dhoZr1N?&{(u?BxZ4>@wRV>a|Shxj92Qs22GGW)xK2>Wl1cy1z zT6v@~o_i;;QU=h3_iXi0k{`eL+E_$Ol0iz7;K`p}IxJTxlNT1eYR4_s(s~y|YiMBP zz=P?WG(^m+ZP%a%R_N)$#;oEz_njk>ld&o$XB}TV!N2ebog>otDV!S;JK)ZrK4if` z@al5e(hQuPl@%)`m|KdC<%@uwZ zz>12vJ=oo<+3@ItzR>zx@}UQacwRN84$pr3Mj$FW$RT{7A*&Z2TxGgPS2rECzyN)h zv4~Gh=B44{s)x`=y=BmT=fUrtVLO1QORZeAzuzcWDq{Z~BGAdT7v=VB@$?^JzQm&MexWmFPD4X}AK!Db zL=9}nW$?N7B^XUlW(a5SD#i;>H8 zo?XPVX~{&fr|0a&*Ymy@LQ6|)6*|lf99%J)IGY66=DK7CTnK2qRADNwq1evj17vb+ znFJe89CA=n!k9mC;K1^N^;15q>dZ+t{3CTLyR*d?kH$G8i-N-7;84S_a)#C=b4JnP zqJ`L4=hoIsx>ko_1Tv%g5wesCn)LC9pG@Dg=F4D)KuGGE^0K}LDcRc#Mh2Rv8BXps zwIOl_b|Mm6s;X)$$9<$8OlGBEp~xLDDuNZE;id>#;wB7S+zn$OhunrF|5d7w;KRW1uCjc`0&_s z8WxMzdiLbf-f27&nMMHE&g6?ux^ONy%ir@-t*J5K^H+tk z*Lg8v`&QuSek00kh0@Q=>K2F56oW_|K`4|g-wsO!ppBlR`cj4@iQ(Nl2@qh@}kgG@X`r9_38bY$%ePs=4qp{xv6WMOG@Xy~b~^!DsMpHPPTFCzYle%fA> z8de<9HR8Xa92WL1hkWysj92qclhT+b`sc_2|D^ZDuIvdTrqSpVE;(F-=*fp)3BH<# zVQJX5^9S7@Hu-$iO@$v!t+jI>Ts|z+u^nhb4D4OJy&c5FuA&0Jn*L(R!3=hTPm#P3 zA;FYQ=T)Z0qEVv#E;Z$~DMM;_#@v^8CyfGAdpB<;`IpO?^^?N7ONWPZ&d-Z!6vI5A zW`TLk(Y?JS3855Nn}X8iL&YOohOJJb2KM>w<+gj%nEWexoVzE(s%@Uux9c$^mIjTy ze4(!UKj_}7H#F>CyG5y`#>e4Up+hg7p$WSHvZ1L3%in5gFuK$70s^?hDfxtS$#JR3 zDW+7FFXFkaabKRT>^bjsJBy?ubX@;!5^H}*sC;akv$IR3TJhHgNB=lvgtm+ukNAEg z$lDT?fwm^}w_-8L{cC3cXWy>(nV!Y8e}s6XMFeHV_HX}ql#NnQXs3VSv|+u0-y0fEe_ zm*0Dy57|9yuWG%%bai1yDN!0GuON!TrU1%#gHoN!<9PlH!G!GSp33;}mc^7lu{4qJ zdwT+N!EK2>;>EvCJoEE~Z_&}wsJqP>Kg+b6-=9NvIik#pxV+3ccLHkwpUUMclH;;5 zXf!7L{9MSWS<_F7!fdHoqhVr{JTW!ZXKpSWWd=koW%k(CMF<5}rbIP#t$S5T+O1`I zH-?I8Yol;&jj7?*B4w(7)!~D2<>nEVe+?{jM@)y=m2sll_^vygm^V!)A5hM@o10)7 z8k>YrrZ*DCmD$J`0Dn^FIslCt{3Xm!64}SHbf*uf$L;Pe0GmP?mz^SI#z6laTDY3# z!UMN|NrDNJ)^_=Vs6j=~0Y(cvmcuwiIG5Gc<+fQTZ+LL;f;Bv*QHFApLMDUF)Imq* zo5cBsI4NKnN;Irf#bS7jrMw&}Cu~f0ItouuN4>58*_3U+E-OLzb85`4-(4*5(V_5g zK%=}fB9mTbDZ|h&`O%E)=0aTe?Re8Gqf@RT|fbY z!#rAg)(AGorj?Y4m>6mntc0T^2(Uk>QuC9@+kGZYm{8YI*}S?@qmP%M$5?KmU23W; zE$Ux$VH0^PWz1@+iZ?;DU{lqQDpvu?3JVSn3vp=DitCBs;No=g*#)S^l8&xLhbgoR zD+J>A@n?2^=?=tba*r;O_$7&$r={h|a9So|Vpz8?+~;x-<-2xT8uHb?e#G>eO9r?T6}*)a!QKf*EC=s_vV;o?2biDLao@iR1Tvzdj~kv3@_l z3rU<=UPwZXpIAd0Z1wj*g{K&+>Mph1jG?q$9Y2wRK&{9+Wv;$gDYuCV?MG*G;$~ul z%+N5T$)2by7iRx=nr6Ob1@IZc>N8HMEk~Jl0eV1bKLYIf|9spgyiI)$(t3NF&Vjum zyqYSsx^4DL?JMvN=BQN?T;Tov9&q?yj{hVb^#9?g|wYb5kV{bqi>i7^Dai zM;rWU+K(ltW6|sPtnV3cw+{y9`>_f`y!C;=*n+SL{$2V1cSC;J>7_f?ySWoHIw~vv zI_fFJPeDX0Waq#I8UncNfDS`vrIw30Mca--{3CV>XHzIHn?&1-!@fZ{2KKiR)m#Chl1_oy1$It^FMSYaKC2uYE(qmsiV@j z*c^4Sr7!sb8HlApZNWLLo{f3wXqVpeqFNRSkr-d|Q+~rlPS_i{(BN)Vzt!2>KR(k! 
zC;dNMeV)?uL*YxK;;TpSGj_RCH$M;JkE4DpcuqFM!@TpQJv@FabS-Ad^nyrR2!M0= zgdBkH^qb#{y*-ihYFL>Y70?D0Uf-Y3S5t8pqMR41#bJvs`K%OVrtS*meFG=qO~KUe zVsd>xZId(Um~-dAPYM(c#5Sl_cs zuid0xq%zoc*Vx^Ovx)lve3dI+C6z}j%Jl-Q-b`B$wG@ROELWy(I3{fh0wX6b1TQQT z=)^;v2+5Oea-m5WF|6gkUR6E4#%IiMJ@zmAJ$Cy2ssw`p8JN2_PH&v_QeUL`z)v(C zr9E$(&A-%CmR_69HI}H}-==?^N_aDNg5K2V)SXu5MI8|jqsnC`ExUpH`)Ij6q03mD zGzuKSADn=mtalKpd&%V2maMmg8}HRFw_J|MEY6)xQzh@j)QA@aNqFYa((jo?DwROD z<;U?itdH(Xdki(7e$(NkgenufyU2;Hl&@8>p+nVf+A5J}#lxLE5CVuR$n0@xh8SKB ztEe_;hr$}3EvMHBp4@qblBi)XwgNYs;(I;Ryn#kJWc9Rn2X0_D9##_1H{N^sM59@E zQR{QBL!N*&*8aLRntDTq%>tb4=-vIyh!})O41|z8EvJvAnOW<_0X%Q8p8{OBbW^73 zQ*#$yBQs5wC1<@*ZJAJdy#$qE814~r_a4#AFk)a-l9vK5U2_AGWyZ%;m}|7w6}pQo zvPS+|7gH@>q;mh*YCc2H!OqSm31;6e80UcLddOtOx(b!=YUcudeW{}B(N%V2g0bIu z_rzqCk}O~>?c*@8&sVxTS^Hh3g-*5$yg8l2AQ-gzxR24Z+|C!r2B9l>5n7#@~BOJd1=PQV!>nY@^&KP{|5FqAdrX>@T&27f^_Nw!6TQW0qE6&L=S?av;cO_}`N4A;Orv=Zc z>Xwoke!f2s?x;?)2paz#An4pNIj9%&E`XJzwQeb2i^r#jj+6vmugL~wjY$Ph^~cUN zKqfU5#355nQZUyzq&lFTI(0pxqaDp3jivZgI@=eX|D%sj$lWaxHK(PoADxan%j6FD zoxD0r*9cwgZ<@23N>OV|c1Xmp!DJJVZVP7U#oz*mXMf`NBByYZc_U~Er_wZv+VzqK!iOS{KIviQpMF}mmV zO&qyU9PXl8T%qtTRfwX&CM_GJt$bIlbFZ;Q< z*&*x+@Gk-)@wMQ-&sG({HH4d`F{~8&(+sMN-p7<+*1B|S$Lz8x4fHK6G3W^kSc-O# zEGtnfouNasd83j961nyumUGlI2$rR_kxFISF4Cu9UWD14M(Ed2VEA*ju ziun)9nXWyzaTR)v5{xB_Dh>t{f&iPgryrxbC+XA$x|4kZz&bpS{PfnU?D;>MN8x4u zgp%`T6Pu$6ejU^RU32!NyDZZxh-M3H(il$*>2}hNS@U$D?>+T|5Rnu5IwYNb-$FN&HEhFdFNG1`pw-+ zs}xUab6hB_HXi?7@|g-F8e@jb(3e(;`kkFRO}gDA3V=7`N)~pnZf?zl5IUMG7qPtJQWQ5Lx~ z8?`C)0LthxpeRSk|K-(0GoUUOUdF#{SKfSKls0pKDIuEl676f9XHXCo*78MrMt|G9 z)~CxE~2qxop=^LC!$O3%1t~iXxGW z((Pf(MKX7a24>SCSmc5%Zd2cQ%h?TZm9^=7Wox08TD~}%lNhEX?G~6)@edZuGVe{l za_x}AW=h+QKgYk$7e5FWau;_{8h&c`Ptk!bD_hu%VJm3KX=o^TGh1n&RxR9M7OW6B z1Jz02Dv`{ zf<1NSLjK#a?e+*&p}~e`YA2-Niadj}I-NSx6MwD~(T?#NxD)*6_OP<}%n|-dm^O3K zaTMY|TV*JHu*?cAxvbQ^XRbFoTG#^!Ng~Ub;s`odXuTHtjb~tMLx)GG-VaD2HS0=R zuL~9q-_mU;GG-*y)H7EP0y^5mweO%|RTP%fAf!cb873Ec(A#Joo{OVXr94+O59yW7#J^s) zBG*M;9gwdTqm?kgK+t*nz^g6ev7)B*Znu(qUV9%;H8+(RW`$i!Oy#y+w-;-c8gVY) zZ@Eg(c*kgQ=jvTh06_!*JF%f$??RW<(2*M`N=~r(2({{~mam7mJ*AX%IxPr!=O+m! 
zNy?E5oo>m>x5csY)zx8QM0&US2PMX4CkgUMVS9RY9#RN3P!v-*f6*TF<>BK^k4{80kJoTRS4*&afTIo9!tPSFc_TD{%s(n_&;XGE*`!` zI(a#bp%^(A6!aC{7`#4St{i3Y;_R2TD^}FZ21{Lapv@1hMC?V%!LbpSH)CTLd=aH^ z%*N)Ltk#+=!E={5@^>*_WP70Z+wZqQsgX=r7#5!D%^^XJP66c`Hcc@!Axv1TQ$m^O zzKu>|l>N+Pg~xdt-!0=`ukQe-Od!kU45v5#alQJ^SVkv!c~SL5ty#z6f4tJzzTehQ66e>R8AN!U!5sqLFbaI0ahC1-K9#aro(tV{liusJZO z-b-hA)$s<@P%gBjM}>O{r&Qf$j35;L&}=fIZ(qHb z#a79)Pv;Gl*H93QF|qO5NC04P54Uv##J6%g3@{X#X?+O*L^>yx?+>?fNuk(e3CWu0 zUQx$@l45akzQ^PsqUWmH2HLrOA_k-q@v)^3c}@qLD+lG><7mcG9cm;GH{f+1u;Hgqv!mQV$2?RE8$`?cM}f4(s-$4+)IIYB zXGVcJuSdVNYyTf*bI}`th+;#Vu@sH-{>kg*?xiu3!qdC^K?T6GlxaWyq7Z6=ULL5f zCSO_um4)02MX@rkdv5=ium6hD`&e2LHuJtgAa6sel2c1?0Ds9N|2EYGF*gtioJIW* z8{(gfi8&EVr35h5;mb~DS;yPi_uLY%G1F%2HK}p7bgK~Y&cBNM*B#|xPgY{4@mM3>nRgB(egDqzb(|%VyE5TjdvTi-S#pIXhwv}( zK{9&RZ~c!1wr5jiMpw({q$y2_`Lers2owc{9A14w{#WT3wkwz~!ycSn5F#|7Pms4&4Q@@~bR@>Tw@PjR;x zsfeINiBV1hphEMXHOg+P#m;9S-_W3$qq7%YI&C9+SMK+MW$+$j$Dp8214lJ0lE48| zN1lCz2n76XCk=Frs+l9gNtR=7{u&9!)lfU0f7c0sO<5FhaMUfxs?>FE+TT8Z^mYwi zIo%L06nhu*K3GNit(MB`tB<_e@l2M|gN%DWUen8x)X}>K_;z-32>-2i;S<^Zf}3zC z5!Qs8qz&y6?H^*4E$DMpYoJ4Z8&|B6kHy4Enpbj@Z0#p(CY4^6$&(&mH+S;%dMLt{ zTMdm(6k&+|c7q^vKI-ARO zAn*$jgY?fdKd_?atl63lTC7d%7iN?XT4|v+sgU4-z<-xYpbYiaoe0JVbdT-7x^B)7 zn^)}C3G`4!%T=Ys(J7@#Bu)76(mXWVG4ZlS#Y^MqAU^NIlICT-S9;m5bZJWUn;JgA z7=b_nY`l$^TnSt*JB*VEC_K%94rLy0yULa6$%L>E>ofT{+8HWJ@TtqiD0!jM1-x3n z&b`OV+V#9X>91G4e?1wh{hnZ&HLfYkzd{Eu6#k+`{F<2jxrmU5eBo@nuyg{6 zQPf6qmZv=&d1rWT|7dS@0208Dp!!*NujR76*Rj<~5lYY3Gm}WtE;Di?L)<&s4-+Xu zb?fd?bD5j{hHVa5v=+yFJ=K^oxLcMd(kT6xR$n99aUwCw(Mm=QxDjE43?`}wxLGUH_+FRxrY{o}#+*QLv+TH75Mfen3G@{86G)Bo8E5L;vy7v;PH zcBpT-NGVEI?3Rm!M^>>Q@^gBAk{a*dr>}v8RUu%1Oc}a%Uh*S*mBjYvk@;-et-^qt zLMuAa$ME*PJg>j!!nv2Xm)EXhbYb5%o~US_84ELo`nwnHsq}Mi1i2v)cM$z=8~qB+ zhwAKBecR3?WuBPQ3|0bDr0yC;v4i1mR@;@1r&Gz$uof1mcO(p$XrunMDa=#9*t6f~ zdqR8<)@I`8Y$j#pnGeLh69@j0`| z746oE=YD{(<+f8S`?mEJ$;WlAhD5%vmJMJLH?4Vah9guh2vDrtX#c8ATtwj&!Igc~ ze7MX&w8BD*<*aK}_!8pyY;DXQ`Y?;i3vrSZ_9nHoCfQQVeM-5QGh9fCywaRZ4m-Zh zqzoldB~nMGtXBWj1;(LxS#JBlOE1kLG(|)J}+ji{9gU_-$dkfn$Dy#A-W$`Mn3N*je`w^n%B2pZmb4G+|ZOMzv3GdkE(YGSH?glvMq4hXVkP_Ltf}rzj za!{xTF%Zn#5c&=`-fXI8`JnD*0HV{Ai)%&3sP0c1r$J#%jg-!pBdrfJOlM_=@upr} z*UJt==h?*hofcA!1JPZMv#I^VnTsD_`;)b^VGk~RBN5~LTz3(oCk;^B+7ceGR(Vv? 
zC)udKyJ3VE7A9xjc(Jwy1v%jKX{zuOYbAoiZ>(-&Y}&3IWZMZ~-b2%P9)m0c922BN znimTqq|G(Qe&@ZjvrmhdOe&WGkFt@iR{T!3!>57~$9+4nqv6(KxP45pnc`YzIR#Br zh1=^e`f7w1AYYc8x$!Zpo$c0Cf^TAEv6r=L1=L=9klGGxyrWY+!gWcNhwj*Rd-R0@G%g`NQu&xen zG`FMa)%g{1s;LF9Y+oKyW%bLIw40+YnK({hf^YZF>qNF=oI7US1}-yyr(t(@Po-q+ zwXW@76E&m%H7|M{yoO^JwH6H~9>*sE-Tq7d7;6)^e@5GFU&+0kFo#W6Uf=uN8qC)(XH6M54l z2^ozyL!%hWDI>8L3cUdV&)a)utRL5+B0RuFoOO$p528Nre;GSY=Zw6z`5Ey8o?i1U zXrCni1Xu*|gK^f;-`@wriM-n26)CVG2FqH(Cx2ol*Rru-9J)=a^*;j4QvVDcBY)%w z?7-s49%o{Kzu*Y>_Yj$DJEE5N@yS#l+Z?cKNkC{TP!WeF@SFo+gM4>7K*DfSSdL;c z&Zh&t-lqzV>}{kh3)DbGzpNI;&V6Sg$n4RwwgF=6M6}cI!8%SrFf@^_84TrY)CDp@ z&@afe(Ud1Lr%+hs6`9NSy=1SD*5z5o{vGnf%-@NT98a`z|8u@w%}2%5%?3ja(vj50?8@E*q=O@5+~7keZR1unYB zBYR)AE)|9=zum8_nduv8tk0deX;^h5UpC(WQ5Re=WW(A>4>`}+C5E-SUv646VdQ_w z7YE`bOc`!F8XWOsDX|K1$VrTF7ntpp4omi{dFBQHY8%*u(?+S*rffotThYJoNm8hE zGgRKrha?5L)Wrrq-T;Ep;tfBXK9y0w7nnZ#BMd5qmC&^6RSmPwhf8^kOj`H|^c+*4=>j^MM|&SN9;BKCSaL8r5%%bCYWC;$y_(XO61e)6DYxi=Q92IB-nz^+(zp zIow5kNk#WPidkA1vGaO;ba#iB@RS_%?5!SVsCqp;nXS68PUJ#QD2^~jSr3V(Bo?q< zp{^g}Ivn|u^tUZg!Z4d&t>l@>9X7V{3!j8>@)R4;R1RMuSjXr;`%}2)w#)4TfsiiP z899^+6Eyy|tx0u1Z=G;0u~$(|zw-Z>xn8)Yc2DAAQGV@}&Ktq715-uO1OuBckWPG9 z!C7#BQ|TNnS2NfOoe!?R@2~T-?w^4jR8(sPuVOV9@EboZ zy!2tVor6Fs=~Tqv-mc8b-2IU|Pmjsjm8n7XMfex~LU=|RapPtukrBqmj}84pq}T~a2}ZJq6r@?q}V zEDG&|RmSne<1LQ{+Z8v&5se(G-S^jHy#p%OgL#b{barFNu+NckKp zS2|Lh;Z}Tr%=eIDEYa_DajALvjadEZUGFH~|CXO&TfY_W%XDu@2j&vGG;ITe zU8Q4G;fyjPVZB3FNwP9EQsdILXOs5cur3W332qwASlypDXE5}u<#$$-gMZGjpROGs zHiDS$P9?R^?*P=PX2vpV^&1i6W-w`c(O^0s%r2937PPY@QNNApQ%u3$8i$`9Z~idXXDcS+($ zSDFk(o}?;^c67QiL{X}GLJri$?A~7Bv!54pa9$XlI#C2O-^|iD1agWqb3AE}ka0Aa zebw}CcNP<~@JXjfnR|r0zx1br zN9I9sCSDfu*7@KafI_<_LejefQ)SL&A>$np$$b=&_2ylRfM>!ixJyU zhAs!RrTx5%SNx_zFNh4>HU%C`rXNfmv>H^VezB93NdbVX=KiE^%e&M1M#{dkWIqT+ za!h!j)O!58bQ1AFPybhA4?griJ8Ve@P;HY4nPGX^*9;SlK6Qt2TlSgyUn`olpTlSO zu2@9K*dD#kCrk$f4_(XT_d=UY)5r{M3eaO_wz_~c0Rf*#a;&3X_rCThz5}vuZq7>> z*6R6DAbi7ENX6uOMY~JZwVdKh7kZtTwrkAJz{t66E79@IEG8HX#Ct{Cwe>dR;OTqK zIrp2ssn1&l^pU;8;aXmWHExl>bYLd6-Tv9P?V05U1Ju~j|7xum*8V*pxCa~mw(b$M zFtl$AtM=3em+D=M!)jcBJPW;zC~)MvbR{(I8`8D?QKrGrW?)1 z4&#iZ(j^H*HuxDY-{MsDAC+TA!Acteni{zy>M01Y!$fiUKq}CdVLuTB)X~Q!vX1&% z_etY41Mbek@cjPM9vP4A`+Mw7lNe@CsmPSzQ&W7wdoIl2ybR*%$r9`B)y{-Y)N3yi zZP?>cnRDqA^+g70#Ro9J8bB-Urj$cZUtXWQ%bPoG(7t{cd_P-bsW}))YJ75|FtY@K z)QV8FFQrW>e~oq0Y3^26#kL9?Btj~smvM`F7G<16xVxNQLoSW$BcZEkXv3gn%+G&O zgh_udUv)IfGg#uX#%{c;WZeH(2Q%};e|N3`RRNIy+v6KGqKNxl%jYCqx-o-Jmxrro zxq5kNb3zrz=Qg#e>MK5G>|#tw$DkrL;N(9SI`Y{p=|lS6Db_LJsE9WoO{Ej_~-R z3s18e_LTN{cR#UdFX+C~(oT|tua2yatXIw}qzS8ERrw_4u|MLCb9^(e<0kgdi@k)r zQI$X((#xc`EB#DRcR}|FUn3Gr>PI?-#0c8s4z-ZUe06%z;(T|!X?oaQm#zCQw~0n@ zmHxoe6F%AIPzpFfdw4y4Fif_`#CQD&Q1bCx54x(s#xu> z6gE5z^6gpv;hBD}7S)#QF5Kk`1)A7pj+;bM@MNAvqoBqd@+Y0%zNi$gsJ)Z|FxbCc(k_q50bDuL~`&BD&CGHVw?dMCEch!*s zzaRuAY>}mL8L8b-eGe+3uvW_I;ckuQ{KNP-fZx1iS?%j>to_&4Pgc-r@`&U$Uth3X zeh5sg5rlEj6RKYxRuj|U$TO{5MhEvUjlZR5bF&hNM`d0Pzpnl)lmXxt-q19yX?1}4a`YvB<>?;0hlm2yzA?!(YNW5kE3h_+8J*Wad z`yH+0J}!;i3(%_&C(u1U< zhk=@m77C)8iHH-_-DWPrc!NT;BwrD|#Gac0#9VLtfLXFSTP)zvm6zYh)vnJCO(#9@ zu>Oq03|Q6WrLwCrqOI&kXo9VW#peU(diq+Rhzt0p6%yQ6-{B)HuGHbY&iRbfYl3Cj z&f?NvDPbP3NUF*PCNDckgQ&UhSwLZ!Us@A;Rz1XQh*m!kU(mIBY##3>%oZOWvmU>R|npPQ!)FluxOb*TJ|U=r|1yaG9*|$$5#Ce4qF~GoQ^u{ z+C{n1J5vmi3Ezt-er`zj=dJ3Ui>mwUsk=oTaP&nBG+#~tsR{!WW=vB#8#TnMUzN;8 z%U2_kzTCc29lI@e!F@QqUhuo!B(3^ULmtp~FVexwkM`lV;@_{-vPZT^2dRWLwS>hO zfaQI$K2t3^;##B6!-ksLGIJO{cDD8=2r{_=`VcS zB_qNJyvaiUIH=|h9Rnv|FslXB_WC8`$oAgv{rY~nff+f5i!XQ9$iv$~oHPX2I;-eaJbqy_&$)Zu@@Bz33Ksp-RFF%+X|FwKcTmM3Dma= z?dn7H{ghNBMQnl!o!Kk+unv{iTb`+)iJk!8-Ktqu5ZsZJB}HY+w0{?y_Iqg#p>@U& 
z;AEde@NMr?&P8`5`CPSNV|U51*(IN+!%-bgx@{{sX5r}A_MKe32s>x2eiY7j-ud+` zOeq%wQ=M;vge#-R^|gz*DQHIjZG#<}9=#7$#jGQCo|Pp(f2y#}{H1UzwG=T~n}DUH zPMuUNT@2b*&jvse_)19=*%LMr88d4|46@>_f{tpYTr9dEUaff@ZFR{kY0 z=ox~FX%jD{0^hMVH;|F5t62#z5^&(+rUZEq$>ZwL_kJNKvEVayQ4iGW$h>^F2~@ov zv!#9Jp#DCL-fn#Xx>ua~Dy~4>`ouq0*R^U3Ejs>@zfL&M0qjk5utCD*l- z(eN3a+fef0qMjPk>3tHtQT?h}$3Io}P@n`aT*>?V=;$iUDoGp3cy{n7bl@`tQ-Rpt zjz=swv>=I_2$zln?F^%1HjXipwD2}^E7E&@?UY{qej1jp9>>~rw4_Roci6Bg4-Aus z7U6r3e1|8_vhvWEQ?L#hF=C9ud8J`vjh?VKn z1F6|GXx9Hca)g<`#5xnpNu4r3^U%zFEA>=$p3kQ$=qWypmHYW!t%dCEJyY5B0rG(N za!(>JeU`3Vo6Y;ccH;`uZuxHGme0rT$PkE%>B@vi`WW@xC@x{zDF5+6f=NyZUD;v> z)*|t*x7Z(S*`z;0q{pM}Y4yUSe~6iiQA|5wSl$)^>UOE=$R@yReS7bzo9n4*C6?B? zesMPmcDHpcZfw$0w3ELyekxcTzcO*e%6nDiL+}cD*V)(135ZO}7TP!X5Vl$AaYc8l zOMBchtWRD!2c{<}K_9TuHZ8F1IQ8~5li8t(u9e=JY!oT1wtUt{KUj0j&XjzzvcAloycDDzoz6>#X~^GEPqq=tD4!;T3J)Y8^9k2CBX$mRR20A4s zm>E0|=N`l&q4Oi&ZI!bLxx=x?EUlg21Iyhj2rc7yUy&u;U3WB)5;4G4TyHJo^$0bY zX;#8fuiqlHy{Rw-?IW;O&X!o|FjPPh{=C_nPc7batm45Je$$=FudLCuq-1PHG+LLe zn`JkwFo%cIvNfMu@yBJur`fO2w_!6yh;c^d$UdxHuzHTqcJ|bia;{mk=D~-<6U~tI zj>OAd-#{?W|06K?BW5qXyz~Z@cM;cyBhm7Qq~-30(z&37?^;?1DgVqDBT-cK|Hl zB4ep(tmTj_-M_R?yGg{oSbU)YDCN zwpvQ!_+AMPj7Lkw{A}L?EC(g_vxt|7SgQE%Z>vLDR5ui8Z!Km^f>R#KqM16V;GY220kr zAofj7e)>>JHO}e}GC#zaBnAl!)#rfc7{mqvD3hCnsYw&J3_Wi8bpm@0zy{d4KjcNd z&3`fdq96hxM}41Vpm%{qx=eEHTXR>@dX8TxGEcCE!OEyx2lV4!>5G#tjL|+;M*i#s zM+91%yl|uclaq7^Q)sqYRmIq)$wC}=4{o5TRMVm~BF1+-k3Uf;RLnhcYqH5GIFm6t$F_YPDzJlkrNwZF$y@77hK)3I|{=s*jLp^o`r$eN91{0#N) zE+#I3J>KX^P74UNisr%>kGe_iX#CtAz zWtTp6vwTHq;*}gQp)=?TnX~6fW0cl5b+mM>9(U3t6~SP2^Hy)%$+YNtNm83MWI!d6 zHL-z8!^*LF1gMO?p%R_5Qt*6K(d12IT>gmq~mlKpZglJRIzpZKi z&t8DPtHdutae8}&Jgnft3EGoa@yW&SIPVtixD}q6(=uGv;vqK5=Ht(M&m*ka?HH;v9w{E|RbQQ3>B}Ejm14!py}l0xQ!*EX*Cq}_2bV#cYo7@DwBe$Mkd#V4>eS{7#8%T0IXQ_ zfv!yUnO=L%1ie)rTmA%@LwLQMoK2p5Z@<3h(6Aex2CU`kuq$n{PnldcsgJMI;?gny z_$dO>YE2o~yWxuK!gnqzF0Ug7VANfpP%;@TUq)Cs9)>}MS%}h!6p?l9n-Roo<&uhq ze+T$S=ee-yiNY;G$9h4|{J=<*;K7{Fb?D67Z!iP**6cje?%NkJm#z*0Q)M{dRyaup zjvNt(jDsqNsiUJag802tjqvxpbmW}X{W1@IZ8?-XPo)?XJ>3T-uW4ytSe7^aJOvydnhS7A-VL*;L&SV@P3zpF4?Hn}4fpgEhlxy3 zEAN7Si4JePQ6@>}$r3+9SUk0EXO~;_Xg--D# z09k9BJ^51;dX`RV-e9zrm(z6`TmA+uL-R#R0goooLS9GJVIsApAV?`BYdk%<30(b7CaYlihgApE|c zit4<4?Az3a3zca7w6cFf%CZ zt%&A+-*ZP1ak>a;fiz`1xA%2PvsP^q76g+MFG7vc^DwJM1`2t-bVR`Tc%wR|9q$gL z(hyYcpMlP$;i+n_{7lsblq}g|F&v{wrdbB_*#`6Bjk9^lGMcLEA6-^orkRpLxKkfw z1?(T01nQ>V;_PD}{dJ!$O`g@X&WK(@w9X*CoNCS-cBND_i%vg@K%KtejDhL(_S7)rWZx+J6vPYI0`HTIQtc;^Ey~?7MwQt5z z;tQK7yW06D+8e+u^hFZsR!zy?zaq3Bhw-an5xer~8;zF2Q{%Wcq5Uq8E6pD)HkTqs z>%FxPbzvEsVqVG2z@0v^Akp1$Y>k6_9wus)1EqySmK3Jnz;#sKv}CG>SwKg2(}WHJ zK27jbZG|cAVZpd)%%?X;Rs^j2*2a5W?0rDTjRI$4@j~yLo%^coB8a$mj=HJ^XySI zsU^ehgxo446@_`HiW;fG7nML{WTE*4c>SNv%7~%scOfx+Od77+!z@Jex546!kROqw zm2z34*wStxb>!AI3Vo8K?f(RP*ZDsAS844Ju;@|hR7%99lQ}joSrC^ZUAf!M1F~;; zcuAycX`5t{QAe^&t5k6QL+WY%^W8A5xwNvLjf8#MS+c;b?4Y5q_s+WFGK``u(2Pz?|V3R5Pb}XM$0w@95yL|vdVRd2v)q) zk@qlj@v*$u4GLBK2hA0#Z2mCnMDO$rKnzT6U+!qZ67|v?j`iQ4<=Hj#6yE>KUCpS~ zv~R7O?7o6&y(rGq{=34S zAF>*Ry1)xGES2tVF|!{BH@>S$w>kg1Tk8+h!d2SywyP1Z5Kc25AsvhYP8sxtx7VO`bE(!zdn^DgNw z^|7yS%{jJ^j}@->7sO*Ddp7{#KK$oF^R~cJiYq2~u%fMdjs_`5^%EN>6>km-7iok< zk`Y#Xmb}DSNjWuDg&$nZy7jpSBU+Y3k4gTq=i&heeF?@GTGp-{(Od%m(e6juCgP;o z(CKmAw6*$5)t?Z!ys0fbocNWo3Eua;1ttR?0CjP^(z>^X7XVI6Fu(vbq zeyp##fjhkrqcT|!-(L=EeEX$i$L4uKezZ8zM-fMl=~94n=ZWN}N$I`2`@?mPkD}Tl zu5t|@>8-{byN-0-x-RW5m(Z=co>c^WJH=@%E3h;g)MViNaTRk05kNpWD&~55?E~I^ z7#Ybns_OEYcpEMMsvEDcM^Q;0i}UMNKvVbe14mUlekYvob?cNz__yBxYJMGlcKy|5QZw?tI-Dw%Auy zAeqIEuW912T?rVQyIX(0Eh7M`Ll2P}--FUx{?z1$MOchdQfY~2=`9SX_j~@;H9&`S 
zmD;NM1%8?10bYv;2W;*141peP%2wqckbSfV1esfcL~ zqN+Y#S$de>z^#^|BXp%*@qf=N7Vx~Zq2--e_|!8!o|`{dDmzU2&2{-x29eKqfS4Ko z#2L%v{}wX31H_h2M#!=5F!{TsiYW3&Eq%5`=6I>HCQ9oW^x4Peg%zLqe48yP2B4^n z-KaEoSyxe?zvLja@Ooy-X<3n8)EbP9qBc3TY~;U;DclzDQe~$)(_>Z`r=&V@#|(`k z&lxZ)CCfS0pJiC)Tnolf#t8(bHYYQ(%a4_>3Uj# z(~x*nx1XQW6pm{ZGk6qFzSdm?{hfp>Af>v_qo+sx-mvdgRU!hJfy|jgW(?kqRg^bn zmzI`?2fRfNqTsc&vSLUg@G5kly8q3crr_s~290z1OKTT1c2Hr)kw!bFaNe@>=z}3@ zEwysJVtBjN>!jPmH%qGMY%h@{Dowxq)V!tV3OYijOwJME1t(kDOCUQ@mG-BszBajg zEyCaXWHkNzz{4qVU^zR6*YGrJh!7ex_nGbUS@8wN?JqRD%(5J=8GevZ+4GR)O<|#u znMs*TEb9e1PM9mEYuLE50I?~ZV&f~izcm4*NU{9|g`+{Pv8vsZNaW}(TW-rg%-mEe zkF7cI3N%TO#WZxIu+DPYW$sq^Ak%Iy^C@#GC>b-6PK3M=L0heh~mt7YytBk3fSpJfLgGY+47^QhY? zpwP7W$$MU>eAC8>swbCi>4)A72qrSf0SC%11$%&E*vda>?E1}N#$n{or=|rm;_x8ue7v6aWE%_ zE)rpi5y4wuw;qj<1fiknyxuTP!~U-^;(X=gyIRS3#HfEPw;<@(6{x`R_&I;*+e%o< z67TnK<(MFNw~Kew_dRIWnj&V8p!eC14D8qAf9zxLN-7p^nOws7d$5FZ*skK?{qB{3 z+K zv^!gr&hb9)sXmjuyBo*zYeWhpLUXJW{1*If!2funX?NAXDb0fMgQvw*TD9Tr zu73u`w_i_ccLPyYE*mDLqz~%_AbbsX0BWNnWSYFIBHuXa_s~>+E%d~QW5W5Plfw1Y z_O{#?(QK<{q(GUpWUbZ3Crr&}`}vCVXk9w?8hzruV((XK{leZJd8ZEVEpi}xU6~b5 z>e_dT#(U)g!aWBbhX!~dZD5o_j8#~`_$e{hhrG1$JOQsJ-QC;+MG%YD0hM}KsFErS6DXGy+M+|F&3?7trjNONng{~Kb?p2_GlATj3Zy^u)mnvr<7Es zxl?UHyIw^}E!_Tj{}r(R_w?x|Bf2U5G)8Wc2hMs5NicC-g1OH=i_2IJV2pqTIAbC* zMA7f8u?zBjejcy2_G7VWtdH_BjA*1?9oc{vD!>80eKo-H6%;`{VQA6TWMl0t(Y99Q zN_j%eL93T_VU6Wu?^c;Q=1$qJlJ~%fhmAcIPq{Pd;X_i!CbQ%z=TyPvnd^-VhxNSw z3ZAa)MLVWnf82fL7N=E^_7;T3Nn?=%vd9xxZLm^t^w96Vfr=TqAP*;wZLD>@Dud_4 z4U$;Nbi;zFPn1?z&Qe1i33NHWwQf;%ax!{GaLkVA91d$|4%NXR{~NZBjj-}AGU_VS z<{UgaOJ|W1<1Zv6sb0%Yl{!M%y}sY?&(&p8!EB0tiG+@c(1j1Di{wb`J>=Rd$5^P8 zN_cspaWMbt7m-9lj znS6iB_J^W$a^1FwlOYZ61Q@s)Gefe ziQzL{-mR;ZfS@t%xtfu~V6Qu*t7v6u2!% z6@_by#M^5`jc>zjN6%QKv++q9VEadMCU(NgN;jsFBi-HG6&O5G?0ly_|jE0GKzbg(B z!}rbNxU*^&WMMcb)t9tKUKDB;XpS5-mY!fdSD#1DyT-usUnwioP^n$N1^fqa%q!IQ z&Bi}shLvf~anY#{;SgcQENf`%D#(f$@&D$PnrB|;ZGbLZp#~oYxes>J_!?~@8ZP}|wwK`6 zQ()klSU@#(Q4%+x*`~5(pKk$#g?T)2Rh2GyU~z<@@n8psRg3-&eIBo8vZp?NQWg?Q z=vOXZ*klFMt&uXAE`f{InVP7$?)w zdz{w&cz4}3)2M|P3tDWs*9fB7jVGG2Y&@qOA~$jGg9{laifT^7cOQ!o*F} zeS`(E_@>D>$^RhD5n6&eUN5|LRiJ_Z76f|4D~qnn9zSA!Key#%Ke|l{tDU3mQ`b$W zlL=bZZQK22P)565>WUn8kj)5lDk2;=sQ&7_wEY#V@ ztY6Lp+%p>YAMh3a0fSMa*tjdJhxG~Gwos0Wpw?mlpMF}3nyi8)(wnxn2FBK?e2>y7 z$-PQRydV(~l(1zS$M(u_xIk9l`5R3e$}F#>pPc!VP}A`Ecuxa%iAdkljfW8}^{i?>%-?l3d#2~9zFlkXYrx`;Ub~bRIonJ@9@68nGzt3@z|+O#Q;$*8)R{w zxz6NzBFGZO_DZbZB18$mmQ^g|%&uxKk&BM7g2q`IwSHU$5HZkIJh&rtwB|U|h*BB4 zWeC{)B$ZCZa_2i#cCY87kR~7`)3wi<^{v}jIJ)64;K`wK{^|1yegzC$(d=r4fqNQSw&iBmU+Qg3E`Dt?;J&1v?#ERx1#TM z>TGPkIIJ<%Qj!->1Nh4)zHL;6GY^+|X){|4$-rSBA=K3~@;?k_NvZ+(r`0RISGq_a z2g+gRPeMV_z(Wu|3j49W2sLkQE*V!o**(ZSibp6Z=SdJ9OuWY+C?YVgqz+zng$IU? 
z$&RT6?j-}+PbuBGQz&Fq#Xk-P+j!d3GXIBAoU0OY3Iy%fyX9AxED?H_uER>9L0C5AM6W zR$uFyY}P3#8t5n*`ZxLnnT~a#&J|*XGarpafblhS?|+iZm1cl@FYj1mu)B$a$;rJx z3u`g;%x*eI{oc03kHDgmk%;3tnfwo@-B17=Wb@lzAxLOrim_|sVClM;1Z`JCzhr|n zOM4@tTZt}bDi$44;Wab$$Aec=ZNddp3%ky;ak~fVP3Yvs?Si7?YkX-{H_2fL?!&7; zB&*4qBEDk%bssg)#Y_cvgcp(GKx6u;E97iiGEIEs+%LEkBi+OmP~Ut1gaL0coq5W0 z(c~!e#f5DV)LyBEeIV_jnKGbo!%rkR{oJw7qHts#(d5?M4Is=c^dz^&UtL#k)x)HSXC__w0iJ;^1`b+SBEYtuB*lAG zwvqn_D7KXjZDcn5a5%~r7FO`l`k8-k=rj)mU?u&@1+V6XbZ|X+lNoo( zChr}mnvK(0;S908S<;y1G_bA6Yq8xq@NEUeneYA>1(?31^18-*PKEQ#MtwV7!$3Xm zrs_ytpe6~S!aWdt)Q9%+)(&m#eP#gPE~X2* z0cb+ad90y*Ton1&?0aJJuW*R(t1*;CK!R;jRO*E2b#9(P?VC*c1-THP|I+@?21Eao z{{1(5hxhiIa*tKqQ~bOZE^Y(|b@;os_?QyNOPM3_;cUDQv$ep3haC}O{nTJDUCDeL zw|Xh!950nTvwLX&49PEyET7R0n^oPcR6RbzpoowNn0WqwIJhtNlTkLs1>Sn~#GT-1 zG>oVFerXD*%ysPj{PsOk9l#A!=StqtmkCeP65(#Nc2BcszcHP#2!4HF86IYXUmL?Z zIo(aKw&PHL9_?F-jAJL1x1dvcph4WmCw0hLf9qDy4cxeLq$;L-{v9M(cBd;n7O#SSG0AlD3oOr7!{+nBbPELCij02erLuk!q_8C4v(EyCW4GbwGk)i6E0Isv5eD$g} zMp>mKS4l(QN$9WU$LSCBnIXvvUmQGT{_W^_0PNu4oDcp(aQeR%8Q`X2rep=xaA|%K zQ{jZ&3umVAFtpvzWgOg@s~!OqskNpoED9dau0LK~RozV3#zn=KMa45;iz-jjMsjfBW&-)Ob*&f?us(0uV2ayXkGLgX4~rmZnunhMDy~RHD{vB`u+MaDf0!TYWRRMgs7ySFDAQwX_vH{BJQ+#>a&pix>EID zF8rnbo`}-yFC4I`tm1~>X6+N26x=V+FwR>%3^~>5II3nD6+y({TYY}QfZa_OPg`4S zeEs1Klu~2GR{&Tk>|I07mRs&dZP4ctLzeWZ#{K`l5r9~A$XkE5VOruM+71`+ z&V=nV_ZDyJTTo`1a+Y5oZU{!x^7NYXFG_5?W#2oh^@HeG%3IRfnCot~M>&U&$a!V4 z$bb6J&pIE-n46Blgd`6ipuHBLDl+{em@V$;rp@%bo;dLJK@X!)Pa8-({idB6AERa4vXuod-7>+p(B(+*Yar@;KY63 z9zec{O}Va>9*?~TeIK5M?i&s55^3bi59n@C z05dPK!<#X_^1%Gv82|(lfJI~MDUC4vhe6c1`iNfCnItLh+i9yy%)@isYmyu{uD;F! z4Bh?`z@wOCEOu!1WhL|lf`e1P6j4KG$JUV>fIJ_N!ZuCfTIHAwWcukvPO|BVsSkJ| z?gJGg$7@dL*(^&^q7AY;=jF+kQ*P#fxDXo_0dtwN9Zi5~{oZ*nt-lI3LZ+SH!@;iO zKz;ijJ|N}u`~)Y{XgApFuNa_G;5pi25D{pd1nXlh>b0ftkyKI;$;SMhF!UGAPoMM zp?I@k`JH~9j63G-oV~0kDfL`&?{n zR`INZkSIWxt3zwvcY2&I3V-$2@ zzdl9kq`gg-8&*e#Rg{KgVE(!xmvx(fx{+vP1K>SCbco>41XyGSMXtNwIo;8db3d=p z%dr=k0uf`!o^4iN6nV|o_jm1jz$%%;UCa?WN?%d+A6iq#^09Tg<7>=APR)oz%c!sN zWl(>lxz}wmHcqKre#N@a<2(BRG$(GNXCVQ$8%yk5ZN872Ne>$F;+tbq%IiR7pinHV z`#k{7TV!`BpiK)@(QR+O7R080LMkGo#3OSbS=_gd;7JC-0irFTn*xYBmUT<0CP8aL zOz5KVCFYm;;)&7ozJ3N6>NlM$mPhYx=a~z0awX!!*`_vSaXeIHf!&0#MY5RfrmF%}U$E8ME^>HcQRwvJwOV z-ENga{dI>+h>~bF0Hg3SZc_ zMXR8-Ha@*pJz=DE?)pQtL)08ZhKvI$t0l@ndkH`<6shZ!zfi2-2>n?1&5=Zc@E@-ESN5VcL$LE1y!&P2w{=95V@Ed8~5 z7=OY)jOF*yn%&aKd$KJh+`nCoQ2~@50{~q=-kfIP*kCN(U=GSmX;srAB*225tFJRxsO= zU&K}2o6$xMYi4PS(-eg4a7OFVRAZ=iQ0SvwtvN;jqSeNf+fPPh-F`b`LRcd92*3jX zN%J=okLdUQ=aQ!#LuS$M)ZYL-na5^-d_>eZ?sW$>Ai?iXS{WlP({D(P6bR&8pH5&lAp$m90)c|4Fv#+R;>y`g!Nwr77|~d)Jj|! 
z{ENK}+hIm%rI=>`9ttG9YAG!)-CrF~7&8_w)?*fbf6u^0g~K5#R>Z~hX%pho!JU$i z22xDClSZNSwqN!x09z}t$uqHS(rs|!&LdN@kcba>T`{~Ei!ePB{3r3(b;;j%W3mgQ z%NzbK#MvLpGXhLw0L4o_I6t$_qqi@Zf)M(EaCiQm(E79MaFS*w>|40?2Yx@$=s})B`pmCbr6##s}2a7%D)waN~ z5&)7eHhvo?VP$rmV^%2Ay95mYZ~_PGS!$(NxE9y}*i=FLGS~R~kSbz^)dRlQw1uMp zPBy1CjAv>@xiPrS9P$2z>06#0&USzDxa0!}`W*(AEv^N~k=Z>Z@S48PK6HM-IZ>Q5Voi}PSA08yVk!$w)+bytJ)eRr0MMK^7IaliK z4cHX0T54=#PTbR&*WF0%Sofpx2oXJEhrreXwXCa<7v(486<`SsZvahm!zUPv3fT~- zjUjZB>|=7^f<^P*l4FCRQXM`!J+Lj|)CeA`;e@fRZ&tvfKIvs!d98`!MyS3L=g~}+ z-uJordiCiS8D;Gy!s$a`sDf4*rXe&ILpG-m26g>(L;|6hbg@^?6#W4AK9lRp9094_C?_q|>PF||STAnM&q zJtI2!dCfx8`90O4Z7>-n4Yo1BYkS2GyTy9|nE+`t(d(@wV%3u?=U6q7x_3m zU1=^W=kEz;!SaPyM!nhh08x6@p16N}o(gAM*p?uiW2iJ*&p{L-?Ms;^8$0@0<@H2FGd1tYjGwPuJ!KB3NVM<22M9>noc6z+7qc-~myw6pN>MzY zRmw4s<4MSQL-pnN1t*1TDsGx~$&}Lq2_ek7NfCjKs$-4_9Y`@z^aNs-AMHFWUtcTp zo;TS3phZK@yWciwRM5Ns33xJzuu2-FtI3^AzyH-LfhD((r0KBCie$>Cy7b&*RL91@ z(?lWqO1P%tnM2nKTUyZ&$uhZ5a;`kGzRnxxAI zGeK2Ex;L*DgSPhv6`TW@JGDbQv#~13$;C$M@m^q&U>hP!wrWRZd2M&n#||9rZesTT zMg_`|=vAdaff%6o=VW$Ht-`da__XFA7|l@&s%|=bKt`@AO0yodG_4a`utZ^;6gQGw zvkxWh91{Z)H^W}TXU>>ZdEMTVht4>i~N`2-P2 z!LeK&G7Cb--6eTeUSyO8%xDeRFbM^+gq)6ox}#Z%a@kVc<}V^>@8nBC0ZGhgj9nDd zbS8~R{PAitCT7*S0RiouB}nV^(h7vD-XHF?gMOPixK>)Bv}X-m{`#L)2RXClYGL=8 znZJ#(Op)ZiH9xUe>u2Yt$qj5&&poOFxSIf_QBrm^mL+H$T%PwLoEMWUc>^^&(lVMO zWB~oV>hJ}M%8TU0OE@?aaPv~g#Hj)#;6vE@rY%4#4GvTYqkxK)R#n+Y>9~CV zW>8)*psNN%04t|&~b{J0^LYWC2!>&q6W`)?pzb)n-nH`yL9UA2Hij0 zE7ck75+3KH&D#H9cw{Ren}z9xH7nI+VG@K}NSjOCh6T;OK4h0t@5kUneY6+GlS6!! zfrh9LO|5*k3X4aSYDIoQ~#&M>qGI zqvVL74Y9%+Z69)?bKw!kcJoT3lEXQ)DN~kpWZ*_CNq<3Cu1q_Ye>nwLu_W>Q`aczB z@vN}Tf^IN(w@4YUN_21}E2CcwMjVT5>fkkjxR`kS((h%m85MH&AJbOI37Z?X!Wm&$ z32))QJmD}H*de?IxxA5dIZ}G5)1X7C@YX^ zy!7BxCrY?%Z$55OmDQMixKI0*Jw0UiU-u5bK7PPg73}nLA;_;*JL^`>qB#<@q;7Hs zG;vY_fdINl+@V55WgO=+#sr~XMHz_r`A>u0Y_^&bUdc>@jQtN$>uOU0L55l!6#cmH zp=>M}pdQq4+^_;!lhPP|*?4bjRA9TM9y_%O4+{Sr;>p(bqch0HziT39x#q^i{-9Pd zM?#ONMY#UCA)GoKe@IQtBs%hk7Ph8Ay4~$p6}a zav`*-n3&iCV-04nQ7!VAF;S?31+88K<+W8R>b#bcy0&)3vZxQ}(vyYbZ^(DLwE*!+!fT#8XcNNPQQ9}PdUe(<`vR#ca zY~KQM+ljEKATUWXHb9zxRh}=;=~(yZJoGt5>g0h`S9nx0Nx~eV|D=?^d`jSk+`Q%= za^4U1rHToGhAI^7$hm^W%?VS(N!=N)#Vtn#sld1q`2>mg^55IqeS%!g%>Owl{iKK+ zOl@xLnN~3~27YVCKO=IQ%0ju=HVU7FLY||PN%PYcvJ(HINS`wJBRvxGw_wiRwWYMS zKZ?o7q6Z3s4o6eR0y{Vb8xe&s1<8o1?!ddQ?M{}tX9l$SRDT?}F3tWO=7>=_WMtIs z`zaD{F7Yo*mm=KwC&@W`6WU|-kr;91a+ya?ros|G{jHeH$#&M>7Rs1t^=vamjK-WM z68G!p9kcG0J~lSG*0&7I(qs|hHFEhiHNTucT{`QhD=y-3D7rONrkVz9qp=ue=41wy z5py7=`1HWg9&HUoPu9EWFsGbRkajWc)}HJ zG9>)bDs*UIH;K&hMMR#V`-ZgKTB!%Cnf$~k!qHu}Uho@_Thi{^kHGYwm=w4OPYgCA zeVezo1orpS2mVR#TJ@_(iq|9lVH##}y{nC>sPU}0>hS#UsFm{dRkg-7*#oEG#dcxC z3)v&OrWRgol6lAe|NbEL|NX#Zq|51P-$4K8vA?DA(ZoucncK;9)O>C1%gSlbv zgZkC7`$el*@0W8T)b3*i1Hzu`jDM7NyAJoiUnST2zf9|ny zcZfL!L$8gC&e^)qHo@Hn^iYsu7({NQ2UM|64(arNAih&JW*xd1_;g+zIEw)h(I@MW zu_D@kI6j+b3god~FMY#<`fPbzsJJ(zO*plHCW7%C1ASl_>1aGgtVu~n5^#un-*lcj zcU6B}D!eTChda_Ofl)qr;6bDdzvn)r@4m6}3Uj)m zrhl%;l|{D|D8RIy=c;h5iYGDd`~BqSuDb37I!n&TZ$^G#_s>`A+z8b8)q{n{Z0{?T z49b7HEU>fMw%@pP6EVXw{=}Z%M#r*5kCOGK`a5ly%6}hG&%giG{p+y|_tIcfJhFm? 
z?9g3GY(@JvdKjm7qM+z(gmM6EoL9l#7AsTN14v) z$%8vp*2{}QwK1qpzW~++Xrmq?AH-^Yt$!}69s>yvpU>Oaz2Vi47O0c6w@la)I@OW< zbsy!KW9}i}I|h&2I$%QTqF3YWWP#ot%oO!I6J)>a*5*n|XMR0#fSk8d$fSUbTyv*J zxxtRIeJ_&GHf+05aBMyf7DfVS3Y%0D7rc=s*{672_Zv12(N*0AV{nO|-_cY42`{_N zmr4fT#(9De?{I!-oR7?@;>;Qf=6IegFk<3lyXDDOOwj$uPNt!?c&_sZ&qaO4`bprM z{T!Nqz4m4%kG@{DY0nyrS;|bMWXL`ns~z*lvDz7bD(rp6ee3S1H)@OoAfAK14R8Ey*V(%1yCA znUXo?w~6EX-Fi*(!R6DNWcje8uQEg5V~)g^fSPKE?i)z)Uyb$&0FkwIT+h(;AJt`0 zst1`89nW(uLU!`%?L}_9E-_RTrJ(-3Ry+GuI21{dFMhBkr%~)c@X{L~S0Kb)&d#yFjGX^?BlYAxYCUZ_7DVtw_r7 zt!37MBMesKhL{QtNTDWIYgj`I#4<57oJE%1Tj75;UA`XE4i(lOT?boteK%uq8)Ed{ zTy}aK9}53^)v{VCT`zyRg_XCwq7t=O;h*sM8KieuDw1{mIQ7np1a!RD3(Eg^pdWUi z01^?d1E-)?T0e%e84TiN|YvoUAy;UAK!A<4)feY|N7tVe++)q?uLINw>qU z^c;{6oWCY!;f4V16c0DR@huNya?_;^W0{o4aj9s|h3ZRZ6wf)Y4PpSnpknNgW%f_! zZ%SVo=)w;~C{5b{I+G}+Wm2wEGRLpL?!qFMCDpUWxgx?yV|+fmvK^}HV7d(BG5UeJ z8KHk#g~^@Fv1iuR&o7wS-Q({}l^pz;JdNVtw>=VD4ZX^iC`sFjBYlLnQSKMYp$cs8 zAD9uZB#zDS$0x^8>N0l+So1mV5Q*ZP>e5P|qcXeJ&0SP9o#tbAGGL%Hn2=OE6auAy zS=g?1B`&-$@Q=2%mq*_%z{m5t?mKk(dI&^YshDh_^UrP1_ci2#&Oq4`q-{k*4~KmM1;tc7U%srRttY>>~-J{v~93h=7%hmU<2;>A-#q zpRWhPv8?F%=YEE3yDOO}MN$VYU5-?o)zlf@0eMZUcZ7nWHFDbeqGvMI*30WFW}LyU z9iGz-|D;6Nx37Sc^(p@?JS58Ecl!A&Q}*ozfDgtL^$ytJER!dam)hv#heSQem%|E-Z**r^6=F7f1lqmImBT-kVp1)l0 zn|Ivxvh1JJYYg*_9fo9>#vmuM^Y2$51Z1~s2Ua&Sm*jJJm!6&uk=$qBJ(y1(ss!FdOyxY+--T}H+4DR@>U}s}q*<_5qhmuQ(Rq^etnGC# zAi|w&%lDA1fHhEQTRX_7e>{PvgBV8?aheD`SF~}KN*WfXw$EkGOGp@s1K$mlDmj6U zt}`hhW_?b|v}CAXq+kEH7mkj-(8Nym0ZW7(g8z6^ka^j>Wx>SkMK{~&Su#t8w4Cx5sShOY|O%5#Y2ZGSO&_+JJ0{Xz-zQG;~Y%0rC?Zel%kq<@vM z1JAP$co9u-rS<>b z1vm@e-bQVWrdwcS*3&U)Y|VToJfF$AW`A7YUI2Os-kS2NAo0lkH@w8)N2!Qv;50(h-H!KiT+ zH4)*)8&JwH6eA<##NZ0f;9Xd@E4Emmpelv(kC_86>0NdTs3FU(c|^1l_Hf8($a0^9 zbxelca1w3Qc`CV`u(U2T5OA8){rXVIK9*kq^u$B)XlGlEGq1a!-#ew19@mg}`+|XR z66Re?O?Iv&U;o^){=S5j{T&$C5llTU!gkT*%|*(iPgj@PC8`&OPw<}Ac+rL_9Ul!K zHi*~UX{_6Ay!V|}l!<0t6zZNyq=Z%vRlHp}QOi$2%_oiXR{l$t3?#BXA6zss zYg>(~=5}eK;SwfAJnvjwb|OY1&M-k^ujv@7s~6+=j}^M8=Et zZk5eQ^Kc|5B-2CEi%_3U>r@BU_=U800$O@Flq}pl{sIXtS%BPL$Y@a~f9AmC;YSJX zSyZPKPPV_Mj2Ohf#AoOw=H$Q(<8SEA|SSj&BO3kqm9icP@ zyq9n#NPmw-&a$lx+EV16>MslFc1wi_LLaNo9FX(0#_V~C=>V?*I5yq~77^!lYA<;w zi;*UyWxkrb5jNx1eP|1|6`;tLhX$3#A<&Bc)|}V%1d#df*-dh2F?VX_L*!l{XnexUj%v_c0mAi8r^Ecs0$TR7|6&WE*gw)-D?TS49 zQJ|j1bW3ubp@KeHuHkX4ZPhH)>3(H&AE<}WF}KPR^yw0294)nfYAc5T_qs}l;nb1- zIp5g=os_2M-HLE{WCEA4ne{-x5J%3)73#ywHSL*I^b2eJU4G2zz)@l3xFxhTIC1f` zsTDO|cC1KqDFbeCAJcN}HLrC3j*Uy?F4T$45nU2vue6A92*G57m_T-qT<|Jz21*gq zw8FGCk~*g@@<+5{BXukCzUsEtq1JQTLHaZefB)&*z0tgmfBLN+0OTF~<{zwPM2Auc z`uP1d?MW-=3xy4EOnn|wzCY{QF-fZ3?Tq<+ar-uDuaZ%nB8-~6cYX1DPOs*e+yFkY zU&-%GZIGAx1JXaQ`IxX$1N^kS#6L`|?X1^UWiMsZs1W(q5yaqcCN#0itp(eov9~wh zh)V$opFD&mr7)KlQ*YGrynBDK0a&ZyUk&9Istgt?TR9`CF?2qj-#M&a+1icI__Uy) zCoLBy_c?NRZanR_By1e=v4)e{nxFgDmAcPg0Q6{hGtwTVJk)e>w5gC5bGzJ!vl0CS zSnAJX2TA0zm@!V!IdFew0@Jp@oqgi#HM#HesJH_^6Pt|M9pFRhZRfGcrvu1&Y1n&f zqvfyjskp2s<$9sN63yq9Ze^=)bbF4+<`xX*_{CeJrA~?sW&{rIEi7eEflyp)qKCyl z!=?4m7tR2{UPQCi%73bJHGbC*$NVBsw{gFiWH^=Tv_@~&bQrxrdC1SJ4+Cfw&sAr8 z^Ech%+iQkZFq5-v2{TS8!nXJ1=N+NyjG(Gn8 zz=7OC0sZ&;@J%05E)UV+s|l52f_Riq%y&MQrTXiW;&4v5wa)l;-6r~bayW;ai# zhw!zo-@zhdu|z&K>X7@5ys5I12jolEWzXY_g^{dH#&N~Y!9QjyIQfM893$#Ah3S74 z7e!%{-1@UrT8{NL8&$XEt6LOT3l1%xdE^iN*>?R}ZIPSdrb2?p)A;0Cqs6DVjN*^Y zmOo@3mVFmN*HG)(8TFBk$9RRjI!%u9vnwk862z z_?zJ3&*Xuj3G1cZ&@7F5=Je?0OyEqik}Z3a%snfAp5*ME*yc~;*6>Y6&a=-x*zlCkul+LD!3RsGFB zcQ3Bjm?3<(K!cb@c)Ye>#kv)5);byg)iWUUJpn-Vqe0J+eL;>*Y7D0-@TW<=B4n_Q z*ayflS(3B!`G>VfA8f2I=c1FEJTIVuFc-6FCeseB&#vY*B*r5vL@EvV+s3FM=*=PNLm;&{jE2(pNqNa~=KTD_<4K4pO 
zVFIX3AeR})dG5gMb_J87GYL#w)~af-2*UHY+@39`UDuwNm}*DQ(zZ~t3+sQ-6SNMl zrtznjXx=-X_tDH{msw2Yn|syZvRwEqkv4?0cQKX{7*ih>*-+R=dh1RB^*Q4^ z{->hu}jg@xP{D)?{_wD{h!>wYg$FKH}pyu8XB>+!y-`U>A=*j5^t`E>Gyp`K;U z@ImW6^ABR~8QIIpU*zHtGigWp)cyE)#%=}{V(vvuaU~W^jrVL5!~pX9s|{Nfa-BTo zEJ#?nP@PQy(LqtELNmoP&P47iXO)h5^xl3o_H~3DZ!zS0I@tl(2QLeT)CyR0@rS*_ znt9r7^XVw(tOJPW?<4*pqyDV_cHT&ubDV9`kqlyNK+Z z%78FQS9)yQ%Qk)hZX7B3d4c*avA8DQ`ZKKh&(gQ}gp&|3`^j2MdIQe=1+Bp_Q7fsX z>(~Y7H}Gvs_|T`EOJ{(W@;)&TeZBk&MbrDI@{%9ucR%WtQSE#Uo4KPpP)nvDraX-0 zdL_&={kR^V9Bksno)P5X#?5~mj-V83R)c5Pr|k^%bVu2Z0Z}XwZeELn3nRYs?Tli& zhraeqDBKVpsR{i82Rn<&<^|701wUHo@@*KP5ZF|7LG}D^m!E!4bb`hM=&$zT6rSp1 zQ*>i0-GQha&&ZZ(apu&X`|(vDc6BHaUtjN*xxc=C1!-<}xoT{X1I8sC(J`h|+%>

    VBrWQe|Z2s?@gcNsVjdRL|@OYngAr(+@p7& z1?ntYj0bPi`d|G-{oG3t$T;_FV$TD-*ftIPb3U2#jIQF!OHlS!vYXfSe1m+JE0trJ z*vR(p%%0b|$!EbxfT|yjhlt%`Teau3%|bhAsL@{~9)^F!Xg0b*cwKAZ`N^Std`pOj z^XE$`*{tH8NNQP$To=$S$OLzvp6-k!f)wY8mQ?R;gVXGRSuLeKtfX$z#}9Q>(w=W( zh!xG~^;--|AQBgC!P9uf5WVm?RDpGQQe(@#~@~& zy%G7leVO~L7BIXryOYX(J=^^r5g!>(<$_&>=Z9)7x}T@6#})o=J3{qF|3jkj0(U(FlM`BJ93o(UISoUqN8AzupLsB$GcEx&R8OvNu}h`OO8oa?&c9@eB) zs9&xO7(&tES!R>*S-l2`R!@bTmwmHVvwAII{oLI*vV z>|$0-eK*anI`|^IeHtVuxg;UGxd!y0q^OR=&BlIezpN?>U*V;m=6l9Z?e=C-y_myzISO&N+HI2bthLIrN7sBnL;K61hlGUoU zF^XH~Bym%3v+kz6>amrtE6^ePDa750;=ZM2DNSvMw8)Yid%j3xAI+*12`WnHT$V94MajqkC3Qiq1{p5{_u`&&xK3zd3QQMeu6Txj@=Sl;0DH(WT84|_e`>XLw%oSvjLic3)v;(p|^igVV+ z3?JBm>X$UoSJ^)P&T^jH>&gx(=n|s;TSi(w(K5{RmEPfS{V9=jmC~UphlkQw$aq!& ze0o6(7No;XWRbEzd80dbK^H5682wdh(?J?XSA{~htI9x?0pqwmMsJt(Iu$aQEo|?& zWI^e&%zF@BKkM)TWn4|`?Qq!P%ej1lYqj2*n+UmJ+7WguD=pGHNR$|{P~@5(B&18piI_ob}CKiOzVB=6~2enJ}i$a{;CiDKW*>mv9=H=-h%TmEKijX(&jpjDU8RrOMT!ajT)eFUm_8s$Cm zgL3+z^wZ^)BXoESP)qe4iz+!EYr%aBs!?a>+ujlQ>6PI6+jH$MPlwiag(i^p)N2hb z7jP0Gy16UZCUR_e`60QTh?Md{T9OpE)G#Xy^){{T?R2DK zfIFgALYjstv_l+}$8V#A^m$Iz0Mu9QAQti>`$aNvT!QQsr{O6R(;P} z0mui^KfzYkRO|6kp*)qHb0&OE#*91$#*oWog?sRn2!He8OZn-Vtenw09^DHVLgMhF z`{lk1u`>cOEyvBrj<2&}d{a9Pq$#X}?Tf_)1dK}2C@8Unv8Jssy={2|4e!&&z>a@| zmM)_d6(cw3(-^9{5y{++1X5bJJJB(S-Rk+{7>@ORp?qDgD-xzcdxO8Z)>g$UGMNGkav@%Id6xuw^h z4B9Odn$ohH{p;!G1Q8&uVU>++b>K776Tlx&1=l8mgpt0IS*Uba5SM`)4kNu&Nr|qe zqQ=KG1FofTpE}}gdWV{ zXa3?qYy?6FDy`XbpKIck?j1SoY;@#qZT+JTi{Tz0M>`*owkM>ikQZE6j$(Ctg28Z# zo!KS64dDrfZIKOPHN*e;Js<5sJhTvc5H^jFADD@!eNe%zGie)T*XT$y7lD4#|1G36?D z*EQx1Ha?hM@Hb9}YjVRnp5$WX1R1CSOx$zeeeS|qj;x{hCi~23?MX9?PP)f|?MHw> z03+u<*J>G?*`Y$kd;(*)TT@gi@pFs|=rl14kzE6yfdD@}Wve^=E(KtBQ?k*})tYV&wFS6slwcvEZ z)=sAs0o;B&cG#5x(&~gKUUbqw+>uKsAs~JEf9t`dkqMz4oa?Dx1UmKy1wZ%+n`-RW zjOKh&#b8=%*P$S0(b4VZ{*Jb~@G>%5)7V1i@b324Xti%C;r?~?WP=jM5k2lrxZW<2 z4ITd2ZUx5n9=X|MZozFL0%n|uQ>ZzuTL_(ywyEfgQ=TTpy}F5sdLuSC*dJC=L?PYD z8n+#$k~8JeGf(xiu4eFLoMVRYvrTV@HS77|hh1DbP`!(W8fy^U$<%jD zs!nRrqB@E1I>V6R(~e=cD`#HhI&RmIIyOw|{?NlYD|o~zoHz0z;7 z8%_ujit0oi=SM4kZ_=L^1~T&olZK!)Yuk@3WJd<=C-13P6LYVp6615-w{0W|D(S(b z9`8^yB?5tjhVC9&wZbG9l-Oxb`5w(q$Y`p`AuszGa$PbuuC{?2bQMEYt9%xI?OmuCG{;Bv&UslKH#vkL2DChK2<|0OSxxW(= zVnOWYAfIH*C%f0HplTpr{Xiv(8+H=D ziDJv<*(Q@`xSd)K8vE4NTtl(Qrs?*oQVlVZ%>aorP7QU_q1>U9hT#SG{lTp3UmA6&#EXNj+Lzy2dt{Yv?9?nj(G?fpOH^7gq40t*o9ovxNX#%ofe3aXie5 zBHuyJ8;gWTk#;AZTCVMK8?@37o$CBZK@IgA;n1`2Aw=ock;y-#B3eaRqBb`z#7DmG zUbg%GY?ap1;Y9KpRPp*#AN(Dq{B2kkeJ$o*<>pv=!@p3eiy!U3CTN0LdB1eWdF3>H z+Zzz=VBfGij!2>h>#2n5r4-b-yxk&S4`6zh1?D`+_nvB<4s9gy=LrZ;KfihL@5ibR z$tNq=e;E|7VgKoVuTmzIFIFs^ESQ8&%Bk_F@>F+1M+>tt!&CHBYKL{KNG|2 z4o=MQT&7cUNG)v$@M1RzR_B;fTYv57x3%2#_|G3lq`Gu`e!%3;=Y~g<%CcZ7>q1O9 zl(@U8PYR!9V5)3?P;}$b^V9N|xbnZib9pc<_IYQWnG=wSeDgdgzt=*?rmwh|tcdj0 zYi>=w*9?n`)}T6Ije-q1S;I$EMeA`pP{DV5l(ZRUYoS3 zLUX@W&8FLy4`Kr=2GoEBY@J{rPG^Te~Btu}1ef%jumIVg>UoAku8i(xwb3WD9V+cgWPTGp^amZ-!k+R=y zAWv9ZRL{qaj%TZ}q!u+JPwbC}?21Y3q4(65#2oydQTbaWBKx&NFEzJs{o>?u!>{h8 zt;Byq*?6I6I#z3FMA3?BV@|3+!QvgTP1Bkg3TX<-GB$D!ru2K?_EL;e-(IJ+IBm&^ zZXhTB^7J@MGj|;}2~C2`!4r<5+b~zBH_b_bx?V5FUpdU94cq?291dW%=CxXbT&k~G z_tZ#J?KT&`am+fd5JJl=6l;!0b$0Hrebip^LhVzt;X;zf1n-b>aBMC-PAfN@cMY4D z@F5E}n{?RR-dm_zm;aKc^#*aA%A&=gamV7x+^>w+%T0P*)nT`})_7F|bb9GVr=M%J z9J>7sbn@;HBZWVXeF!bmM6QEhYuFyXTecp~UX*rT3Dz#Qh95|ATiNS9x>$Y0$uHzh zVi!w_-I2|U8tnS_^4s>wdbjrLftdNemvLMy@3d-(j}k=S*7NKR`7HUOVD@2*t<~(k zKCL1-Lz0->N`o8|LA;a!jAZ=kJhEC;1>j$uHof+?ajyF`?2`Nz>3^8t0p+y*P^xV&u@|7k2S6@X^#mv zsYgv!2MsTa6R!Mq9&pmQa!w&Y?az)WzL3PXaK+HDYmS(+;>76B{(@b5alkOnNAY8O zmBAK18G&)oet~6tJFi#Mp*0qwS)Bz2jy=-;Qx^T|? 
z0tW>Sns8on+~04X+S67ZBKZuTc^-=Id}0Cf6mO><4 zo$mzGP?x(91)liutly|^8!!#-(u-gx_+YgnQuZv8lRIcOk+R*?Ff`PqYma-Fvj>lH7xMa4}-b zcw8=@3Op=?H;R}0(=4JICZ;A2*%0Eg@2v>Ue%Uiwq%O)kB_Bb>JB@}h|Psjq0A*e5Wl zfLHXXCTz~0aw46BpTYPXC5iH(IVv1fA+ksRI$WE&hjtW66ivN+uZuX zR^BO#`Ks^q<&}*q_1oPhiz5WL(utol`jTvqB=o2SJtX$F7w{txo)_i4f6?p5DMVdl zMvmrR(x{6tDl=YLt)NcN{XNe$N$TR+thszTM!(v6;zwjyGL!bKLc*qv#Ou&31$#RIN5FUmB_K zhpL$5=CqTqw3kg_MA4^vhu=}m6Or3%!mO;1b#(tkODOsQ7Vi1*$y1%(gNnHE5Xlv* zMbD1M@v5!pwoxODkeaY?ACb1<2w8VL0l8}akKh!M?~DY_#;kqS>R@>kGXd73WnEZm zQtP9uBx;CYL?L7`xHx$C%|nU*x6i{4$DktHzDoYUwSgFos-ls=i6^&~!V>NGrfeWm zf~1jY!hJq_^C8!q#V0$&5lATtnM4XJEQ>Xm^2L0@@?j5$EP8qWd_K1#7zBvR$*}FXNPaNl@cyIq3qXJ z04bCX&W*PwUgQ&F&D_1mOgf8Cz6tf*?jPe)=rhuEv2YtdzfPIK$q~uk9_E^8m3Y@K zT*^z{ljilvsm%9M1cB*8I`U!7#02kf3e#WfU)f+QbEd{=61<#a_=52kS0OmA6o7`xt=O);uM)rXv_!kZuymPUibx<+t z{%dh)c8J%J@@(VZZB-u7j9%@jrLl^!);>|wj4+7}bFf*7l6%xEM^Ktud~oJL7nyBg zobx~&*LRq!{TLT|Q{wi?LbRklg_deOry#<9>jhkEddnWvrJKxq3Kqmzvhl6nN)>kI$`g&El5jdH?p^JBrYrI>x@%A9a*2Qk1EfB0I{X; zj|0pI}{=_Q^h!+O!FqsI)C#`s{NQHF)Oaf+TIN>74_)Y6ILWBCQF z#a0!y40p?<09QSd6wXE0wUuS7N#4~|k{+u)_w!%NN^$%bgzH}h*kb?pQteUTX=}Y& zmoNB*mvuO~&%L9eFFfTgf?W2QDH_I$G;z6e)!)Bdm6{@wWT?h9-!=h7Z=f;0YEd*s zPZyjPs0j@ZFMlN;R~J@6?q5y!>H!~9+4+w7ZD%n#pp21@lO;lew}U;hecZ(KCeP%! zG4LY~dPZ9oPTDUv;6@-|fgbJaGnw1F1HV51F+U?FRS*V$6_^=lZ*`U9kL8D_)ed^0 zFBV#vLHCbFl|Bj^$>PNNC5z$t>qHgpn3%-YZU*(eqP?)qUP}lrYQTUR^Q=!x9QZ%9 z?VU|Q!o^X*dO^d+TNKG}iusmgNec2JC;IOqdIaHcWN`FCXci>RHD_0A2-%*Rn{R4w zD-zjkc+E>Py%9}r*@y=hIWEtY9?Qgt8$5q(n!yt#@TXE!d5nBEJ&5M9h$P1M5tgy1 zeb0%`artK=4fEj~H1Vj&;NYNkOI&D5%HmK8m)3F8FD*D{*dC12oKVGbiMgWiZ(RBZ z-(37FgEWambRjX%*9M? zJJT^+eJOuu2P&U051cjzQ-ajgWacgtg*{owzp+U|&R-@|0? z+|)psRerd`7w8UjUp{|iy<7gBuW&}_A1L(uuB&_lYDljkL&$}2%dY~P@LPw(q$KII zFYlMvF_bwvY8^)A86hKz!OX)((H4i4eESOK`2g1V z-vPSwzXMqNe+Mv>{|;a?|NEi%<$wSF|9&4VH~HUr|GWONtqHn#!P)g-n=3cTTJJ@ z0nDfON8E+)Fqv*j8$*{Q`?|br_g$PS;v>T?7(XPM$<*E;a4CYi)?jW4Ykcq7KnQhT z!`LwBV0C~$cHf?tZj>)bmTUc09dNY4TqLP1qZ%W(kBY@&aD&G-+}35(Hjq+?$>j3g zeU70Lw@3#ThM_7z_3SnOZ}#4s1Z1j!RdV0g*~04Me5Ydb9cZ4qv3W&_{F$nCa}YSU zs2+VE83Et2L(A+@W2EUt*j8Dhs5r9ikq`F%Fw*LusnCznt!e^D?aA7S{zf^I3e#E8 zi}e((D@FE@ASkI#qiCty8BAGp5_5ikTOwj#I!E-kQiG;GF3}&w6l)5*GgrkeWZI|_ zg`c+iJ?UE2LIt$a3)_ zL*z=shMIO?kwIcL-|moGw8K>t--ew2@)sp8sI6y5$kP~;R@ z7fJgf+jUS zLD$h;e~eO^KIL1;Ohc|@&&G)LXteCm77K@XqJF?locDZS8&*#nTq8H^Ji13B(r2?Y zSIZ>fy-Wm80v=#>lzn$aI@7$4RcMnb0E0-Mbyi4V@0CYIQ)81^zC~HeE!vucKAOpS z1O4803>lj`L@vzhLU^h-2m$HId++e85mX^-ds(s%Jd%chdpQMnqfT1$li+hn&MD)hH-A?}--+^5(qcF8{N&?VwdX!DhYh zg7L9>!&X>Wf1t@*tt{tXbqtw7T<$ipk0?fk-xT9oqLj)Pk>*!!_B1?QKY`;|iOEq> z){;{n&6Z6F-b9nL_LX=P{Y<(idorH{ZE|5?eH+euaC7W( zW)RLkA^1w_A!xAjEBsqq@gYx)$K~TFyDKZ*vUno5smG?n{tl02dX9TBAz@F?eb(Gh z^xTJv`Wv|-T^w(uQk$K4e1#8W^COW{K(GES->&z&rov3wt|iDA$rEo~4tXC#m}GBo zbfWhCZE0!pyhs$@3qaS3i49Af;FIwx?gY%wg64^56t{Fx)7FM)#d@0R9+C>(qz4vT z-!3bprXE4&husW}3-fx%+IIs8@8=!~QTZ*?Vd zd}*?eADT!i_b`yOxv75OJtgCmvo|^6YGr0&PpG~VxENBk*%IH!kUyd+5(IW*cj>u- zw;zInhG6Qcd^_Cl7Fryw!5bH)Zsbk_f7{m z#oxkfjfn)m6@`z8BYJ&2p(hBA)X;aDW3~t>eznGQe`N6xKn zW7qXK8!uQq2#Jg_x?25L7pisBK8BA~QX?LE)PZ~sS^|0s+>k`O{PQk=T`#$UI_iOd8aw!GRsl9XcBOlxo@;5QC8PKor9ND%zk zVDY+!_SyN)y5vafEL*T6av21ul~7*M9<=MMI*RKMgsMXyXYx}CXkY0~zt4eW4=s(5yl z3REwV+~&H37?Uo~(5&y!LK2EoO^mj4 z87@ClL#BQ3WDns(e!uK&dQ)@(rMm^7ddz6#`S5hhv++ zaiP#ZjV9(bBH)!ag1WSDfu%eT{^o;=N5%PF7p01fW^biPHq|_JcX*JOxBp>}Pmb85 z7fDtv58^b{H@+6uAbUxK(YX5+zt#43%D~b}QUD%Pvo6*B`240AE)H+`!ao|76?j4E zFApC>@o&w*8Vc>fiC^5lpZ<2MQjQM@KL44OS> 
zYPI3Y1klFH^BF%4D$MBzt2gBKz2hJ8{`2|GC`ZHOWSBtJA}=IY>~wIt{6WGs?&3O|ndVMk@z@eA?t4Y-qw}dbNZ6o%Wj8m;{#MH8?8u(TQ1#H}Q_@|*C-LdKCsa|?sZjIrf>f@&07FkCZ-n$TprJs_kesFqyz->i}`kYkycPp)L2PGJ&{1&=z`SZdi-yZ;Q zShjFK{~=Rc2%b{^l((6fI88Z>pNzNlV&tKSe|pxF*~#+{!Ps5f-}fc!%T^c@#vjx* z3OV8Kw2So)HggP&D7hf0`~#KR3lWT24)*MkF6pG7Lek<% zsnEGcaly70Fh5qi+S{1!n@Wb((S7CL;1kKq*(3ME13z-OOd-LEmo0BE70;fc?3y<~ z)qy7Rpqa^-!`iu-Q2`Y}%a< zF~%W84xt`$7cwv3MVbl)jWO7E8%_LOaEH-Xcys2=H7RF3J!0hZhXlcIdLsc8)U=~U zl@mNT&KHS0xa@=nxh%Z`4eZ*#GX<-rp@g#?v3m9cM+Z}N2SNbsN~*xR^u z^i|#+Yt&T&cf2?1O-e%{1TXf~GAc$C<&-bqu~)sIEcH1Gg@0^0DIiAjd>e}@utf0| zMCWDeCnBrJN>?}>HI+|38?v+3uI_x#0J`c0vfga&nsn)$R~U*ZS|GfL_}?6- zm_RAupc^gI8_ZydFe*)2Z<|Tun2+TM8nJey2|taki;4v=e{AP$InQVs-rX#&H{!h{-%b_#WQ&v#f=f+>p6BK6BLjZf>n}y@ z!*Ll4Po$%ha636(=qTd%h$^{o(fvbQ`TE;{$E#nCo`mC8Vsshjx zS^^Etb<#hpN9&ViIVF_Y3^gH1FP8wQtA=Q`bv1Z_?Zxl(MNKYA#!u+00J53Sl}4l% zloz}q%Ca)&NG8x)nTCb=N_{+lEb_`sJ)=PK9dl@>jIn%n2J_8`$%QBaT0#>b9*2`q zoXWYUFNnzQkP=2(MJkLP9NO_Fny%Mwzze3?as{leNXA$CO2F~(eTurdcds>P+Y|_4*0>H(sgcua zx8$#{dN@thbGYxDR`UDo12A2^5t@0f)Ze7VZ&4Lz8KrZ#5&()VI~hgv=;z{S0-J#^ z1Q)$woaVUX21q}<1l0+4dNEayj2&UGQKZ{R)WFt+D2^^ zn3^u1q=HnyrkRwp1?hIKR^i$8&!pv;3APO(t?CpAEZk1~OmP0tdVlr-6cvW!DpBAE zd6H2mO2N9)n$Al*(2$K{oB HVW^y$YdO<0!NOalxKC8zpg~&bKHpd{d~3v@X9Z> zn=1ij(#O%wp;}@_Ok4XJ(x0Qi&A?!y^wP3_(3Zr5F4r-nqfe>1%Bkd*A*j7#OC zLQXHgiIK#qQX80<=z9@MAjJU*~m%4zpvuv;$(F>e3 zN*}kUrrr;B7WRLQMOq}4S0!PpXPun~%UL2_AFPF^PN-w>Q^srQ+7AoGxjq{s-UbgV zu3-;fn`f74?*Wgq6bA625s^=hl?v#J4#jNd3u_4{#^SzyFYt7B)=(-{I`>Q%KS$Z%-TrES8nMNWL7lcA ziX-gwu)W^d4>6RQVhT9VZ$Nxz-|Zoc1X8!=iiZiW{J%=*IV->^HjauUFHKs2N53DO(Z zJEQqVXVlcf6x}19gqih^L3v812W279UlmWw-g9^k_Ec@Oe5j&gSk-1{3Vi-$H;F z%k2H@3jBIXiF;@{D{r+FQ|SPA{p9mZxxQdp0kH)!c_CXCQ{!8VxWO5+#XmKDv%r3! z_V59T!RS`TxQZ))Ju4%}spT%1y|-S0z*+nBL}aIs&ATr zLW0UX-8gjVMme7fyE3VhxLaqEGt8*Md~-I;d(>^QJ)D-t-?12*cq@X& z@JRg|b$Y_4Ay=KMaAOl-rOf5nkGP)L8kIzVrc&gHQF2Ph0;a}+MeY9{R->GhR^`(DMk zJWEJKS%t%8CysHfwNp#^;jGmS{!(Xm6#EfbgQk&{=?RSdJp;WB#eApNHrV~{7+_}} zz2VB#HL8c^d`K8Q)`aaMKDfcylGkU?OcB~KR6mKe;g(%lhu%b~n(>5Oyj@VE3?1#A*bt<^FTJ>;}9H@`n_N0yY3FdR(0SaLgRzFKno z`SYhb-POER>%)Aw`(biS%w)aY_Hmk8tz~sZMc>enXH7_iC`p>c}lNL~N_!FH?o`}^}Mr1<*! 
z=I7_TxVU`FuBe@^s;)LLG|b7#`Bm|s_BsL?M#b}Q1T8>N+Vb>xr&OvL5*BuJ0zKXA zcd)mYm6HpI48o?jtygz)a`N(e3P2~ftzW1#da6NJ5D_uDIo-UwJYZ&IJORhm6+0s% zQHC8j37q=s1*~Uw4Cej>49v>r2u%^aNC-;B= zx2N?$AF(~RuNYBRtkv6LV`Oq~Ze=T_inD&!*=amzkLhz_I^o+ z`0kypz5PA-Y9wc>3yy@H|c<-ls3_E^0NauY?f zp@x3j-+C!9x^HXZ2Gqx>INudJX*eQvCpe%o-=E#~)gPtYwzYNEaPy>bERO#+gR-jwQcst6~=+#y5<#7&EUZJ3g)X73&NfpamF^`93(yk0+c5l z8bkcnicqZpWCIds4}+arW?rf$>&1g&6M_EI*X3gG5*9XxJ23!wRh+;C5ezLs-FdY^ zol37P9)0U|0ij=kygeE@Mfrdv29H?$#NNYkYH0kLd4|BC<3LU#0}mgN1$WBHfz$0`zC~!>MKn; zzEeuBd#D@a!7Uh&k_MO3+XHbq!L&I*WGrTOJ>T(1CX9`a=HTMG+A720^=!)DMQ?#~ z?w@S;M{aa8cIHlX)inrN)P#XK7wwBk13b#C^u)WyH%6saO=#`!;1L=AeLllsgWps- z^930#l2j1K9^O8)F+K^8SxViWx|3U1+BTP?sKIUoBLH14t}Uq3(HY`^qyYVkj^R_DgKE;(LgbLl0M|zHJqI!UQ78+* zwhsrUf&F^y_~Bw%KoHdUrqH_?46x&o`yJ(%SCY65@&(W)F8iF++w?17S|DA(aCo~i z4G}sqfhWQxylSl)O7MwLAMZFm0%Cr^``W{yhx-{bK_9ueZQn~Ex>{lKr zuM`=$9QSAHY>0zpVHm!4{7P&Tn+3XZ!LY8%78y_jel0{QM#O`x4oJ7eKY&ySpA-ssbAzXU$3%$Fzpf{JmykTj-0f}eBrPjxBNX?|;I9tQ7ES^lv*zPu#q=x0e;+;I z`U=I#KCm}sj;5Yi-2i!dqgyj+!a|qF*d*V=7;B2_jnotcz4Wk_^9DK%4>i;2u;omg z9aIVL-c4s-$40>d$Z0^__K_(R#j7?v>Ktxa?N9{_PaKWuh|IvvXAu5YpR>i)b&|lr zUuL1E9X>r~^iIOOZl>49w6g>#w~EQTtZC5ZXtfU8n2O9ZPsL}g2^c(!pEA&X3zFC9 zVobxk-4=(bSjtyT-CYW719{pEC6+6q1<8QQNsjFm>p|t&cghK5RbICI4cjgsmvQeJ zTJnfq4JrkoIfC8G3_w?2|e2lX8pVM=L!%t$8&CHv8OA2&B;=^|=@A4>Bs`8X^Vy0+ z)IPVXLm~LVMAp3ka?SesdLp05$Vir|p)lS@WmVM%=aaRidtn$D`&TV8q+4s+9+ELP%38pewaem5Djz`WfR^XMdbF6&f%PABaH434X7j0% zT3q~gf5O^lkt_A-4Fhi4Tfm1h|GW{8f^xT;^IyK7qUgMwyA2&VmC( zj$V5lZLMIr!B=@1rC@XLpzya7=mRNLykr6#IXU&EOcEjZnCt~A@%n5{UETHrf;U53 zD$&^sXiAfgPaNKXe(#8McwT(;CMYv`t7!5DucrOH*JV^`9+wnYFyOTpvotjJWda63 zSjiTf1+&>EAk31FRVaiwMMDTNzx09z#>Lc#z1Z91)&gL$^`E>^K1z?bu4{Opv zYvbmQlJq%3eQ?r>Sysw~41U#F zhzwb9)}3!yB4E$6m2UtxKK|icD1vJsx#(mPBnf`^|6|SdxDyA-!MA5QaOSG~dLzIV zv1{mr|Tg#f*U5;Fz#D+)i-&s3nW42L!6q&t+E5drlgcJ9sRQ2paOux!x4 zcVgDQ$vA%$etnGselijGfI|ghpDx$>e63AMTe65=hx9w6$HW)pE9I)FkFLNOckXzMKmG z-f)_IYT>h?^R%39@HSR*DE(J@i56l8bT0Y?P_x<|s-_8w@VTL#llkbBFP z&A&5w+zH}}X=WX6>ttj3f1RCoP*aH-#c`FYGz(x5G=S7#utX6Aq*~}r0g)CUEGz<% z7NjUeKuSPF7PG)20s=}=F?0keD$+xZfb`y^gc9I=_{M+U%$qm+?thuw8 zlF;*h_gzxepG$1$BB~W`_XqcR1jp3s{C#uqC}ufOKUz@1%+fOD$&-gJ-ma_74DM?A4zhqlGJ9)UdtA-cfxe%kG zqAs6sh56RgGdeu{fiKP3lNX74 z^Wbl_0mIKBW{Vdcm1Vw18zPi6S)F!^WDL>pVEZ8%S!$6~-^IdRi`9abO|MnXq{=}z z;d)2fc^kaOc9yyvd*hWYXLAU0`>-rXxTA9vZhzx%$3tUJ2rn**`9^_2Bn zXNW0=U2>}8U7ECvgF_lY99ajDYO`ZC)h7iEj^2nuAd;SIZoRE6$$v)C6M>7ZL`x%Y zxa5>^-*avsy?0)2YIV*WsLnaG|h6lCtp2m6bg6V*J?G#6w~_wm97I-3C)s z=a+aXvdcj3DZ}NJ6~ZOK*rDd`ZWDz$S#j}+u}0DkeOw8P#k$GagFAAjzRwanII`%C zux>icrk=KtlUMQhstEWKI1JzW_E;$AJ3Z*oJ_d#GlWGw1R74?sYh}^HAeQQ<-R9_!xx%1|adenJ`C74GHtwD% zlBN~tvkqG7Y!hBGP<{U_-r;D9ogy2OLeIBoFh_guB^c`z_*7de>#z>Czi};iEpaMA zlN#Q$6ErZF9ibDfH`+*@GRp*bLP#M1G!=1+y1~$3%I^2AQ(T#4lAoUgPm23AvZ~>1 z_^LJ02Otv|t{_uFZb7&pDPvoT=Qto+`a&Xy+bKfSZY27l4rBw`&P`KtN1+Z3 z<)FF(zlk}e6c%fy!#xP3dTLwn8v5-u=d4b;8OeV1Sv`BV|a+ z7C&S30pZa5P`z(hHA|sU54Kp=*CoB|5`I_BE)Y7S+Jv?Y6#OIW`o@VWN-G;uB)6@T z8>pJXAMS&ejfm;dhkjOYF6{K}2#07SK^{}e)da&#{HSkn>)&yu7YHo<7TcCY*w6^bL-=oA1?2#^oK4-lY!C zU6?e;I!H?rc`7=orO?$=0f*G(FIgpkeV|%Mcwav+Cy5eYF_BFpPKK6L$@kiVf&%+w zj@!&aTVcNO3I%HkytWLf=MW+8VX(UVJ;r*LQ`r{Z4a`NxOLNuOurJvKi;LC!#l-K0^w+QAc(a%0@?WE-=Dv}aQm(#o)OVmThVCo^K?6lmW5IarClP&9 zOq8;v>)LTYHL7h=S&YYz>H#J}M2AbC%1YF|+miA5`IM%Bq3`74Sqdves4)AqSV5}3(WNAOD{>Oz5SySe9hym^x= z`{GF*)9Z@vB`}`SRd1D63iLmlJ2ul^-Xm*Xa%TE`tk;WcCdBPzyz|2FnRYP%cxqbY z18WcTyycNn2?TF-uyXpXDfEG+&z`{Kl^gRJB1+kPKH2Dzbrs$rdn;^Meqo|zGN1f za+p(?%h=-5k)w>=haNU+@-LE!*zs{|hdU}^W0SX@USuRd6$c$d(HzIF%Oo6DDswOT z89sOR2Mn!2RYk8{H5p_N2RoR9e`I$sE4^M8d~uYY=e{Buh7C?`!z)tkb{R(QfWef8 
z=ID~ zxLV>f$H(g}gx@sAJ2h0nI+rn@p&a-n+*q^VjKWm>fO{uMQVz^)1e)%J9aPm;jseO zf>~!>#m;~!r>s7|^PLvMg`Wf4vGcU6zZ*mYbabybWRQ&&^0@=98S>Pd$haKlf)Y41 zSBI)S$X-dap{kC2sIf<9EjM1@5>NQ@2#sseXu-{ym5bw{}k&d*)Q$h z+n$_*ac?Ei*W8rud}Q7F&b`w{kPvYD_t%OdIl0`kjABzsaQCm_R#N5mx48j!&uNH* zV|$@{Zgt0h5f}X+dGEJG`kSI6VU>>7aAH26BrrD8C~-f)5mc9RfYn*f{&OoxEVK1o zyIZ?fy+Y+i6(qYK?xzUbzM*9oLmU8@PU@xvzF9=JdHE<_g_*hoKB)T!59CO2kBHCI zJWefQ65HnQPm~GRh@lpRR_$q36(wD05D@3^JpgA02< z5azUD#B9kanL36Tk{_wN{q2vsmYz46(S==W{m)6tyHxA4`lq?VFC4Y)$!0QFx#C%& zq01XRWy8SmkSCvWKf`@oP-^jKKr=%fhC95f7#L)}GE_qS1k3+|V}HS}0C@jj>;Iy0 z{~LV!pYQ)`eaNVPZm&#EZ^MgpJ>Be$E$g52=~!Fa<)HzJK{FqU6P~{$^;IK3*=cLH ziD8oZiSx>sKfs`VzVx4N@}I8%YZrg<#C$q*rPR-l8~nVJx^{Qqn!1i!u8LLQzW|ZK BjY= 1.6 with access configured to it using +[kubectl](https://kubernetes.io/docs/user-guide/prereqs/). If you do not already have a working Kubernetes cluster, +you may setup a test cluster on your local machine using +[minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). + * We recommend using the latest release of minikube with the DNS addon enabled. +* You must have appropriate permissions to list, create, edit and delete +[pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources +by running `kubectl auth can-i pods`. + * The service account credentials used by the driver pods must be allowed to create pods, services and configmaps. +* You must have [Kubernetes DNS](https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/) configured in your cluster. + +# How it works + +

+  *Figure: Spark cluster components*

+
+spark-submit can be directly used to submit a Spark application to a Kubernetes cluster.
+The submission mechanism works as follows:
+
+* Spark creates a Spark driver running within a [Kubernetes pod](https://kubernetes.io/docs/concepts/workloads/pods/pod/).
+* The driver creates executors, which also run within Kubernetes pods, connects to them, and executes application code.
+* When the application completes, the executor pods terminate and are cleaned up, but the driver pod persists
+logs and remains in "completed" state in the Kubernetes API until it's eventually garbage collected or manually cleaned up.
+
+Note that in the completed state, the driver pod does *not* use any computational or memory resources.
+
+The driver and executor pod scheduling is handled by Kubernetes. It will be possible to affect Kubernetes scheduling
+decisions for driver and executor pods using advanced primitives like
+[node selectors](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector)
+and [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity)
+in a future release.
+
+# Submitting Applications to Kubernetes
+
+## Docker Images
+
+Kubernetes requires users to supply images that can be deployed into containers within pods. The images are built to
+be run in a container runtime environment that Kubernetes supports. Docker is a container runtime environment that is
+frequently used with Kubernetes. With Spark 2.3, there are Dockerfiles provided in the runnable distribution that can
+be customized and built for your usage.
+
+You may build these Docker images from source.
+There is a script, `sbin/build-push-docker-images.sh`, that you can use to build and push
+customized Spark distribution images consisting of all the above components.
+
+Example usage is:
+
+    ./sbin/build-push-docker-images.sh -r <repo> -t my-tag build
+    ./sbin/build-push-docker-images.sh -r <repo> -t my-tag push
+
+Docker files are under the `kubernetes/dockerfiles/` directory and can be customized further before
+building using the supplied script, or manually.
+
+## Cluster Mode
+
+To launch Spark Pi in cluster mode:
+
+{% highlight bash %}
+$ bin/spark-submit \
+    --master k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port> \
+    --deploy-mode cluster \
+    --name spark-pi \
+    --class org.apache.spark.examples.SparkPi \
+    --conf spark.executor.instances=5 \
+    --conf spark.kubernetes.driver.docker.image=<driver-image> \
+    --conf spark.kubernetes.executor.docker.image=<executor-image> \
+    local:///path/to/examples.jar
+{% endhighlight %}
+
+The Spark master, specified either via passing the `--master` command line argument to `spark-submit` or by setting
+`spark.master` in the application's configuration, must be a URL with the format `k8s://<api_server_url>`. Prefixing the
+master string with `k8s://` will cause the Spark application to launch on the Kubernetes cluster, with the API server
+being contacted at `api_server_url`. If no HTTP protocol is specified in the URL, it defaults to `https`. For example,
+setting the master to `k8s://example.com:443` is equivalent to setting it to `k8s://https://example.com:443`, but to
+connect without TLS on a different port, the master would be set to `k8s://http://example.com:8080`.
+
+In Kubernetes mode, the Spark application name that is specified by `spark.app.name` or the `--name` argument to
+`spark-submit` is used by default to name the Kubernetes resources created like drivers and executors. So, application names
+must consist of lower case alphanumeric characters, `-`, and `.` and must start and end with an alphanumeric character.
+
+If you have a Kubernetes cluster set up, one way to discover the apiserver URL is by executing `kubectl cluster-info`.
+
+```bash
+kubectl cluster-info
+Kubernetes master is running at http://127.0.0.1:6443
+```
+
+In the above example, the specific Kubernetes cluster can be used with spark-submit by specifying
+`--master k8s://http://127.0.0.1:6443` as an argument to spark-submit. Additionally, it is also possible to use the
+authenticating proxy, `kubectl proxy`, to communicate with the Kubernetes API.
+
+The local proxy can be started by:
+
+```bash
+kubectl proxy
+```
+
+If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0.1:8001` can be used as the argument to
+spark-submit. Finally, notice that in the above example we specify a jar with a specific URI with a scheme of `local://`.
+This URI is the location of the example jar that is already in the Docker image.
+
+## Dependency Management
+
+If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to
+by their appropriate remote URIs. Also, application dependencies can be pre-mounted into custom-built Docker images.
+Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the
+`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles.
+
+## Introspection and Debugging
+
+These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and
+take actions.
+
+### Accessing Logs
+
+Logs can be accessed using the Kubernetes API and the `kubectl` CLI. When a Spark application is running, it's possible
+to stream logs from the application using:
+
+```bash
+kubectl -n=<namespace> logs -f <driver-pod-name>
+```
+
+The same logs can also be accessed through the
+[Kubernetes dashboard](https://kubernetes.io/docs/tasks/access-application-cluster/web-ui-dashboard/) if installed on
+the cluster.
+
+### Accessing Driver UI
+
+The UI associated with any application can be accessed locally using
+[`kubectl port-forward`](https://kubernetes.io/docs/tasks/access-application-cluster/port-forward-access-application-cluster/#forward-a-local-port-to-a-port-on-the-pod).
+
+```bash
+kubectl port-forward <driver-pod-name> 4040:4040
+```
+
+Then, the Spark driver UI can be accessed on `http://localhost:4040`.
+
+### Debugging
+
+There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the
+connection is refused for a different reason, the submission logic should indicate the error encountered. However, if there
+are errors during the running of the application, often, the best way to investigate may be through the Kubernetes CLI.
+
+To get some basic information about the scheduling decisions made around the driver pod, you can run:
+
+```bash
+kubectl describe pod <spark-driver-pod>
+```
+
+If the pod has encountered a runtime error, the status can be probed further using:
+
+```bash
+kubectl logs <spark-driver-pod>
+```
+
+Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up
+the entire Spark application, including all executors, the associated service, etc. The driver pod can be thought of
+as the Kubernetes representation of the Spark application.
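+
+As a concrete sketch of that cleanup step (the namespace and pod name below are placeholders, and the
+`spark-role=driver` label used for filtering is an assumption about the bookkeeping labels Spark adds, not something
+documented here):
+
+```bash
+# List the application's pods; the driver pod name is also reported by spark-submit.
+kubectl -n=<namespace> get pods
+
+# Optionally narrow the listing to driver pods only (label name assumed, see note above).
+kubectl -n=<namespace> get pods -l spark-role=driver
+
+# Deleting the driver pod tears down the whole application, including its executors.
+kubectl -n=<namespace> delete pod <driver-pod-name>
+```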
+
+## Kubernetes Features
+
+### Namespaces
+
+Kubernetes has the concept of [namespaces](https://kubernetes.io/docs/concepts/overview/working-with-objects/namespaces/).
+Namespaces are ways to divide cluster resources between multiple users (via resource quota). Spark on Kubernetes can
+use namespaces to launch Spark applications. This is controlled through the `spark.kubernetes.namespace` configuration.
+
+Kubernetes allows using [ResourceQuota](https://kubernetes.io/docs/concepts/policy/resource-quotas/) to set limits on
+resources, number of objects, etc. on individual namespaces. Namespaces and ResourceQuota can be used in combination by
+administrators to control sharing and resource allocation in a Kubernetes cluster running Spark applications.
+
+### RBAC
+
+In Kubernetes clusters with [RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) enabled, users can configure
+Kubernetes RBAC roles and service accounts used by the various Spark on Kubernetes components to access the Kubernetes
+API server.
+
+The Spark driver pod uses a Kubernetes service account to access the Kubernetes API server to create and watch executor
+pods. The service account used by the driver pod must have the appropriate permission for the driver to be able to do
+its work. Specifically, at minimum, the service account must be granted a
+[`Role` or `ClusterRole`](https://kubernetes.io/docs/admin/authorization/rbac/#role-and-clusterrole) that allows driver
+pods to create pods and services. By default, the driver pod is automatically assigned the `default` service account in
+the namespace specified by `spark.kubernetes.namespace`, if no service account is specified when the pod gets created.
+
+Depending on the version and setup of Kubernetes deployed, this `default` service account may or may not have the role
+that allows driver pods to create pods and services under the default Kubernetes
+[RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) policies. Sometimes users may need to specify a custom
+service account that has the right role granted. Spark on Kubernetes supports specifying a custom service account to
+be used by the driver pod through the configuration property
+`spark.kubernetes.authenticate.driver.serviceAccountName=<service account name>`. For example, to make the driver pod
+use the `spark` service account, a user simply adds the following option to the `spark-submit` command:
+
+```
+--conf spark.kubernetes.authenticate.driver.serviceAccountName=spark
+```
+
+To create a custom service account, a user can use the `kubectl create serviceaccount` command. For example, the
+following command creates a service account named `spark`:
+
+```bash
+kubectl create serviceaccount spark
+```
+
+To grant a service account a `Role` or `ClusterRole`, a `RoleBinding` or `ClusterRoleBinding` is needed. To create
+a `RoleBinding` or `ClusterRoleBinding`, a user can use the `kubectl create rolebinding` (or `clusterrolebinding`
+for `ClusterRoleBinding`) command. For example, the following command creates a `ClusterRoleBinding` named `spark-role`
+that grants the `edit` `ClusterRole` to the `spark` service account created above:
+
+```bash
+kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default
+```
+
+Note that a `Role` can only be used to grant access to resources (like pods) within a single namespace, whereas a
+`ClusterRole` can be used to grant access to cluster-scoped resources (like nodes) as well as namespaced resources
+(like pods) across all namespaces.
For Spark on Kubernetes, since the driver always creates executor pods in the +same namespace, a `Role` is sufficient, although users may use a `ClusterRole` instead. For more information on +RBAC authorization and how to configure Kubernetes service accounts for pods, please refer to +[Using RBAC Authorization](https://kubernetes.io/docs/admin/authorization/rbac/) and +[Configure Service Accounts for Pods](https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/). + +## Client Mode + +Client mode is not currently supported. + +## Future Work + +There are several Spark on Kubernetes features that are currently being incubated in a fork - +[apache-spark-on-k8s/spark](https://github.com/apache-spark-on-k8s/spark), which are expected to eventually make it into +future versions of the spark-kubernetes integration. + +Some of these include: + +* PySpark +* R +* Dynamic Executor Scaling +* Local File Dependency Management +* Spark Application Management +* Job Queues and Resource Management + +You can refer to the [documentation](https://apache-spark-on-k8s.github.io/userdocs/) if you want to try these features +and provide feedback to the development team. + +# Configuration + +See the [configuration page](configuration.html) for information on Spark configurations. The following configurations are +specific to Spark on Kubernetes. + +#### Spark Properties + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property Name | Default | Meaning
    spark.kubernetes.namespacedefault + The namespace that will be used for running the driver and executor pods. +
    spark.kubernetes.driver.container.image(none) + Container image to use for the driver. + This is usually of the form `example.com/repo/spark-driver:v1.0.0`. + This configuration is required and must be provided by the user. +
    spark.kubernetes.executor.container.image(none) + Container image to use for the executors. + This is usually of the form `example.com/repo/spark-executor:v1.0.0`. + This configuration is required and must be provided by the user. +
    spark.kubernetes.container.image.pullPolicyIfNotPresent + Container image pull policy used when pulling images within Kubernetes. +
    spark.kubernetes.allocation.batch.size5 + Number of pods to launch at once in each round of executor pod allocation. +
    spark.kubernetes.allocation.batch.delay1s + Time to wait between each round of executor pod allocation. Specifying values less than 1 second may lead to + excessive CPU usage on the Spark driver. +
    spark.kubernetes.authenticate.submission.caCertFile(none) + Path to the CA cert file for connecting to the Kubernetes API server over TLS when starting the driver. This file + must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide + a scheme). +
    spark.kubernetes.authenticate.submission.clientKeyFile(none) + Path to the client key file for authenticating against the Kubernetes API server when starting the driver. This file + must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide + a scheme). +
    spark.kubernetes.authenticate.submission.clientCertFile(none) + Path to the client cert file for authenticating against the Kubernetes API server when starting the driver. This + file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not + provide a scheme). +
    spark.kubernetes.authenticate.submission.oauthToken(none) + OAuth token to use when authenticating against the Kubernetes API server when starting the driver. Note + that unlike the other authentication options, this is expected to be the exact string value of the token to use for + the authentication. +
    spark.kubernetes.authenticate.submission.oauthTokenFile(none) + Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server when starting the driver. + This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not + provide a scheme). +
    spark.kubernetes.authenticate.driver.caCertFile(none) + Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting + executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). +
    spark.kubernetes.authenticate.driver.clientKeyFile(none) + Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting + executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). If this is specified, it is highly + recommended to set up TLS for the driver submission server, as this value is sensitive information that would be + passed to the driver pod in plaintext otherwise. +
    spark.kubernetes.authenticate.driver.clientCertFile(none) + Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when + requesting executors. This file must be located on the submitting machine's disk, and will be uploaded to the + driver pod. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). +
    spark.kubernetes.authenticate.driver.oauthToken(none) + OAuth token to use when authenticating against the Kubernetes API server from the driver pod when + requesting executors. Note that unlike the other authentication options, this must be the exact string value of + the token to use for the authentication. This token value is uploaded to the driver pod. If this is specified, it is + highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would + be passed to the driver pod in plaintext otherwise. +
    spark.kubernetes.authenticate.driver.oauthTokenFile(none) + Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server from the driver pod when + requesting executors. Note that unlike the other authentication options, this file must contain the exact string value of + the token to use for the authentication. This token value is uploaded to the driver pod. If this is specified, it is + highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would + be passed to the driver pod in plaintext otherwise. +
    spark.kubernetes.authenticate.driver.mounted.caCertFile(none) + Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting + executors. This path must be accessible from the driver pod. + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). +
    spark.kubernetes.authenticate.driver.mounted.clientKeyFile(none) + Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting + executors. This path must be accessible from the driver pod. + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). +
    spark.kubernetes.authenticate.driver.mounted.clientCertFile(none) + Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when + requesting executors. This path must be accessible from the driver pod. + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). +
    spark.kubernetes.authenticate.driver.mounted.oauthTokenFile(none) + Path to the file containing the OAuth token to use when authenticating against the Kubernetes API server from the driver pod when + requesting executors. This path must be accessible from the driver pod. + Note that unlike the other authentication options, this file must contain the exact string value of the token to use for the authentication. +
    spark.kubernetes.authenticate.driver.serviceAccountNamedefault + Service account that is used when running the driver pod. The driver pod uses this service account when requesting + executor pods from the API server. Note that this cannot be specified alongside a CA cert file, client key file, + client cert file, and/or OAuth token. +
    spark.kubernetes.driver.label.[LabelName](none) + Add the label specified by LabelName to the driver pod. + For example, spark.kubernetes.driver.label.something=true. + Note that Spark also adds its own labels to the driver pod + for bookkeeping purposes. +
    spark.kubernetes.driver.annotation.[AnnotationName](none) + Add the annotation specified by AnnotationName to the driver pod. + For example, spark.kubernetes.driver.annotation.something=true. +
    spark.kubernetes.executor.label.[LabelName](none) + Add the label specified by LabelName to the executor pods. + For example, spark.kubernetes.executor.label.something=true. + Note that Spark also adds its own labels to the executor pods + for bookkeeping purposes. +
    spark.kubernetes.executor.annotation.[AnnotationName](none) + Add the annotation specified by AnnotationName to the executor pods. + For example, spark.kubernetes.executor.annotation.something=true. +
    spark.kubernetes.driver.pod.name(none) + Name of the driver pod. If not set, the driver pod name is set to "spark.app.name" suffixed by the current timestamp + to avoid name conflicts. +
    spark.kubernetes.executor.podNamePrefix(none) + Prefix for naming the executor pods. + If not set, the executor pod name is set to driver pod name suffixed by an integer. +
    spark.kubernetes.executor.lostCheck.maxAttempts10 + Number of times that the driver will try to ascertain the loss reason for a specific executor. + The loss reason is used to ascertain whether the executor failure is due to a framework or an application error + which in turn decides whether the executor is removed and replaced, or placed into a failed state for debugging. +
    spark.kubernetes.submission.waitAppCompletiontrue + In cluster mode, whether to wait for the application to finish before exiting the launcher process. When changed to + false, the launcher has a "fire-and-forget" behavior when launching the Spark job. +
    spark.kubernetes.report.interval1s + Interval between reports of the current Spark job status in cluster mode. +
    spark.kubernetes.driver.limit.cores(none) + Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. +
    spark.kubernetes.executor.limit.cores(none) + Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. +
    spark.kubernetes.node.selector.[labelKey](none) + Adds to the node selector of the driver pod and executor pods, with key labelKey and the value as the + configuration's value. For example, setting spark.kubernetes.node.selector.identifier to myIdentifier + will result in the driver pod and executors having a node selector with key identifier and value + myIdentifier. Multiple node selector keys can be added by setting multiple configurations with this prefix. +
    spark.kubernetes.driverEnv.[EnvironmentVariableName](none) + Add the environment variable specified by EnvironmentVariableName to + the Driver process. The user can specify multiple of these to set multiple environment variables. +
    spark.kubernetes.mountDependencies.jarsDownloadDir/var/spark-data/spark-jars + Location to download jars to in the driver and executors. + This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. +
    spark.kubernetes.mountDependencies.filesDownloadDir/var/spark-data/spark-files + Location to download files to in the driver and executors. + This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. +
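+
+As a rough illustration of how a few of these properties combine on a single submission (all values below are
+placeholders rather than recommendations, and the required container image properties from the table above are
+omitted for brevity), a cluster-mode launch pinning the namespace, service account, and a custom driver label might
+look like:
+
+{% highlight bash %}
+$ bin/spark-submit \
+    --master k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port> \
+    --deploy-mode cluster \
+    --name spark-pi \
+    --class org.apache.spark.examples.SparkPi \
+    --conf spark.executor.instances=5 \
+    --conf spark.kubernetes.namespace=spark-jobs \
+    --conf spark.kubernetes.authenticate.driver.serviceAccountName=spark \
+    --conf spark.kubernetes.driver.label.team=data-eng \
+    --conf spark.kubernetes.node.selector.disktype=ssd \
+    local:///path/to/examples.jar
+{% endhighlight %}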
    \ No newline at end of file diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 7e2386f33b583..e7edec5990363 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -18,7 +18,9 @@ Spark application's configuration (driver, executors, and the AM when running in There are two deploy modes that can be used to launch Spark applications on YARN. In `cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.html) modes, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn`. +Unlike other cluster managers supported by Spark in which the master's address is specified in the `--master` +parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. +Thus, the `--master` parameter is `yarn`. To launch a Spark application in `cluster` mode: diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 866d6e527549c..0473ab73a5e6c 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -127,6 +127,16 @@ export HADOOP_CONF_DIR=XXX http://path/to/examples.jar \ 1000 +# Run on a Kubernetes cluster in cluster deploy mode +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master k8s://xx.yy.zz.ww:443 \ + --deploy-mode cluster \ + --executor-memory 20G \ + --num-executors 50 \ + http://path/to/examples.jar \ + 1000 + {% endhighlight %} # Master URLs @@ -155,6 +165,12 @@ The master URL passed to Spark can be in one of the following formats: client or cluster mode depending on the value of --deploy-mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. + k8s://HOST:PORT Connect to a Kubernetes cluster in + cluster mode. Client mode is currently unsupported and will be supported in future releases. + The HOST and PORT refer to the [Kubernetes API Server](https://kubernetes.io/docs/reference/generated/kube-apiserver/). + It connects using TLS by default. In order to force it to use an unsecured connection, you can use + k8s://http://HOST:PORT. + diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh new file mode 100755 index 0000000000000..4546e98dc2074 --- /dev/null +++ b/sbin/build-push-docker-images.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# This script builds and pushes docker images when run from a release of Spark +# with Kubernetes support. + +declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ + [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile ) + +function build { + docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + for image in "${!path[@]}"; do + docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + done +} + + +function push { + for image in "${!path[@]}"; do + docker push ${REPO}/$image:${TAG} + done +} + +function usage { + echo "This script must be run from a runnable distribution of Apache Spark." + echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" + echo " ./sbin/build-push-docker-images.sh -r -t push" + echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" +} + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage + exit 0 +fi + +while getopts r:t: option +do + case "${option}" + in + r) REPO=${OPTARG};; + t) TAG=${OPTARG};; + esac +done + +if [ -z "$REPO" ] || [ -z "$TAG" ]; then + usage +else + case "${@: -1}" in + build) build;; + push) push;; + *) usage;; + esac +fi From c0abb1d994bda50d964c555163cdfca5a7e56f64 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 22 Dec 2017 09:25:39 +0800 Subject: [PATCH 481/712] [SPARK-22854][UI] Read Spark version from event logs. The code was ignoring SparkListenerLogStart, which was added somewhat recently to record the Spark version used to generate an event log. Author: Marcelo Vanzin Closes #20049 from vanzin/SPARK-22854. --- .../scala/org/apache/spark/status/AppStatusListener.scala | 7 ++++++- .../org/apache/spark/status/AppStatusListenerSuite.scala | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 4db797e1d24c6..5253297137323 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -48,7 +48,7 @@ private[spark] class AppStatusListener( import config._ - private val sparkVersion = SPARK_VERSION + private var sparkVersion = SPARK_VERSION private var appInfo: v1.ApplicationInfo = null private var appSummary = new AppSummary(0, 0) private var coresPerTask: Int = 1 @@ -90,6 +90,11 @@ private[spark] class AppStatusListener( } } + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(version) => sparkVersion = version + case _ => + } + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { assert(event.appId.isDefined, "Application without IDs are not supported.") diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 9cf4f7efb24a8..c0b3a79fe981e 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -103,6 +103,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { test("scheduler events") { val listener = new AppStatusListener(store, conf, true) + listener.onOtherEvent(SparkListenerLogStart("TestSparkVersion")) + // Start the application. 
time += 1 listener.onApplicationStart(SparkListenerApplicationStart( @@ -125,6 +127,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(attempt.endTime.getTime() === -1L) assert(attempt.sparkUser === "user") assert(!attempt.completed) + assert(attempt.appSparkVersion === "TestSparkVersion") } // Start a couple of executors. From c6f01cadede490bde987d067becef14442f1e4a1 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 22 Dec 2017 10:13:26 +0800 Subject: [PATCH 482/712] [SPARK-22750][SQL] Reuse mutable states when possible ## What changes were proposed in this pull request? The PR introduces a new method `addImmutableStateIfNotExists ` to `CodeGenerator` to allow reusing and sharing the same global variable between different Expressions. This helps reducing the number of global variables needed, which is important to limit the impact on the constant pool. ## How was this patch tested? added UTs Author: Marco Gaido Author: Marco Gaido Closes #19940 from mgaido91/SPARK-22750. --- .../MonotonicallyIncreasingID.scala | 3 +- .../expressions/SparkPartitionID.scala | 3 +- .../expressions/codegen/CodeGenerator.scala | 40 +++++++++++++++++++ .../expressions/datetimeExpressions.scala | 12 ++++-- .../expressions/objects/objects.scala | 24 +++++++---- .../expressions/CodeGenerationSuite.scala | 12 ++++++ 6 files changed, 80 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 784eaf8195194..11fb579dfa88c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -66,7 +66,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") - val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG, "partitionMask") + val partitionMaskTerm = "partitionMask" + ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 736ca37c6d54a..a160b9b275290 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -43,7 +43,8 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId") + val idTerm = "partitionId" + ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9adf632ddcde8..d6eccadcfb63e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -207,6 +207,14 @@ class CodegenContext { } + /** + * A map containing the mutable states which have been defined so far using + * `addImmutableStateIfNotExists`. Each entry contains the name of the mutable state as key and + * its Java type and init code as value. + */ + private val immutableStates: mutable.Map[String, (String, String)] = + mutable.Map.empty[String, (String, String)] + /** * Add a mutable state as a field to the generated class. c.f. the comments above. * @@ -265,6 +273,38 @@ class CodegenContext { } } + /** + * Add an immutable state as a field to the generated class only if it does not exist yet a field + * with that name. This helps reducing the number of the generated class' fields, since the same + * variable can be reused by many functions. + * + * Even though the added variables are not declared as final, they should never be reassigned in + * the generated code to prevent errors and unexpected behaviors. + * + * Internally, this method calls `addMutableState`. + * + * @param javaType Java type of the field. + * @param variableName Name of the field. + * @param initFunc Function includes statement(s) to put into the init() method to initialize + * this field. The argument is the name of the mutable state variable. + */ + def addImmutableStateIfNotExists( + javaType: String, + variableName: String, + initFunc: String => String = _ => ""): Unit = { + val existingImmutableState = immutableStates.get(variableName) + if (existingImmutableState.isEmpty) { + addMutableState(javaType, variableName, initFunc, useFreshName = false, forceInline = true) + immutableStates(variableName) = (javaType, initFunc(variableName)) + } else { + val (prevJavaType, prevInitCode) = existingImmutableState.get + assert(prevJavaType == javaType, s"$variableName has already been defined with type " + + s"$prevJavaType and now it is tried to define again with type $javaType.") + assert(prevInitCode == initFunc(variableName), s"$variableName has already been defined " + + s"with different initialization statements.") + } + } + /** * Add buffer variable which stores data coming from an [[InternalRow]]. 
This methods guarantees * that the variable is safely stored, which is important for (potentially) byte array backed diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 59c3e3d9947a3..7a674ea7f4d76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -443,7 +443,8 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = ctx.addMutableState(cal, "cal", + val c = "calDayOfWeek" + ctx.addImmutableStateIfNotExists(cal, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); @@ -484,8 +485,9 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName + val c = "calWeekOfYear" val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = ctx.addMutableState(cal, "cal", v => + ctx.addImmutableStateIfNotExists(cal, c, v => s""" |$v = $cal.getInstance($dtu.getTimeZone("UTC")); |$v.setFirstDayOfWeek($cal.MONDAY); @@ -1017,7 +1019,8 @@ case class FromUTCTimestamp(left: Expression, right: Expression) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val tzTerm = ctx.addMutableState(tzClass, "tz", v => s"""$v = $dtu.getTimeZone("$tz");""") - val utcTerm = ctx.addMutableState(tzClass, "utc", + val utcTerm = "tzUTC" + ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" @@ -1193,7 +1196,8 @@ case class ToUTCTimestamp(left: Expression, right: Expression) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val tzTerm = ctx.addMutableState(tzClass, "tz", v => s"""$v = $dtu.getTimeZone("$tz");""") - val utcTerm = ctx.addMutableState(tzClass, "utc", + val utcTerm = "tzUTC" + ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a59aad5be8715..4af813456b790 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1148,17 +1148,21 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. 
- val (serializerClass, serializerInstanceClass) = { + val (serializer, serializerClass, serializerInstanceClass) = { if (kryo) { - (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + ("kryoSerializer", + classOf[KryoSerializer].getName, + classOf[KryoSerializerInstance].getName) } else { - (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + ("javaSerializer", + classOf[JavaSerializer].getName, + classOf[JavaSerializerInstance].getName) } } // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode", v => + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => s""" |if ($env == null) { | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); @@ -1193,17 +1197,21 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. - val (serializerClass, serializerInstanceClass) = { + val (serializer, serializerClass, serializerInstanceClass) = { if (kryo) { - (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + ("kryoSerializer", + classOf[KryoSerializer].getName, + classOf[KryoSerializerInstance].getName) } else { - (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + ("javaSerializer", + classOf[JavaSerializer].getName, + classOf[JavaSerializerInstance].getName) } } // try conf from env, otherwise create a new one val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode", v => + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => s""" |if ($env == null) { | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index b1a44528e64d7..676ba3956ddc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -424,4 +424,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10) assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10) } + + test("SPARK-22750: addImmutableStateIfNotExists") { + val ctx = new CodegenContext + val mutableState1 = "field1" + val mutableState2 = "field2" + ctx.addImmutableStateIfNotExists("int", mutableState1) + ctx.addImmutableStateIfNotExists("int", mutableState1) + ctx.addImmutableStateIfNotExists("String", mutableState2) + ctx.addImmutableStateIfNotExists("int", mutableState1) + ctx.addImmutableStateIfNotExists("String", mutableState2) + assert(ctx.inlinedMutableStates.length == 2) + } } From a36b78b0e420b909bde0cec4349cdc2103853b91 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 21 Dec 2017 20:20:04 -0600 Subject: [PATCH 483/712] [SPARK-22450][CORE][MLLIB][FOLLOWUP] safely register class for mllib - LabeledPoint/VectorWithNorm/TreePoint ## 
What changes were proposed in this pull request? register following classes in Kryo: `org.apache.spark.mllib.regression.LabeledPoint` `org.apache.spark.mllib.clustering.VectorWithNorm` `org.apache.spark.ml.feature.LabeledPoint` `org.apache.spark.ml.tree.impl.TreePoint` `org.apache.spark.ml.tree.impl.BaggedPoint` seems also need to be registered, but I don't know how to do it in this safe way. WeichenXu123 cloud-fan ## How was this patch tested? added tests Author: Zheng RuiFeng Closes #19950 from zhengruifeng/labeled_kryo. --- .../spark/serializer/KryoSerializer.scala | 27 +++++++------ ...InstanceSuit.scala => InstanceSuite.scala} | 24 ++++++------ .../spark/ml/feature/LabeledPointSuite.scala | 39 +++++++++++++++++++ .../spark/ml/tree/impl/TreePointSuite.scala | 35 +++++++++++++++++ .../spark/mllib/clustering/KMeansSuite.scala | 18 ++++++++- .../mllib/regression/LabeledPointSuite.scala | 18 ++++++++- 6 files changed, 135 insertions(+), 26 deletions(-) rename mllib/src/test/scala/org/apache/spark/ml/feature/{InstanceSuit.scala => InstanceSuite.scala} (79%) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 2259d1a2d555d..538ae05e4eea1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -181,20 +181,25 @@ class KryoSerializer(conf: SparkConf) // We can't load those class directly in order to avoid unnecessary jar dependencies. // We load them safely, ignore it if the class not found. 
- Seq("org.apache.spark.mllib.linalg.Vector", - "org.apache.spark.mllib.linalg.DenseVector", - "org.apache.spark.mllib.linalg.SparseVector", - "org.apache.spark.mllib.linalg.Matrix", - "org.apache.spark.mllib.linalg.DenseMatrix", - "org.apache.spark.mllib.linalg.SparseMatrix", - "org.apache.spark.ml.linalg.Vector", + Seq( + "org.apache.spark.ml.feature.Instance", + "org.apache.spark.ml.feature.LabeledPoint", + "org.apache.spark.ml.feature.OffsetInstance", + "org.apache.spark.ml.linalg.DenseMatrix", "org.apache.spark.ml.linalg.DenseVector", - "org.apache.spark.ml.linalg.SparseVector", "org.apache.spark.ml.linalg.Matrix", - "org.apache.spark.ml.linalg.DenseMatrix", "org.apache.spark.ml.linalg.SparseMatrix", - "org.apache.spark.ml.feature.Instance", - "org.apache.spark.ml.feature.OffsetInstance" + "org.apache.spark.ml.linalg.SparseVector", + "org.apache.spark.ml.linalg.Vector", + "org.apache.spark.ml.tree.impl.TreePoint", + "org.apache.spark.mllib.clustering.VectorWithNorm", + "org.apache.spark.mllib.linalg.DenseMatrix", + "org.apache.spark.mllib.linalg.DenseVector", + "org.apache.spark.mllib.linalg.Matrix", + "org.apache.spark.mllib.linalg.SparseMatrix", + "org.apache.spark.mllib.linalg.SparseVector", + "org.apache.spark.mllib.linalg.Vector", + "org.apache.spark.mllib.regression.LabeledPoint" ).foreach { name => try { val clazz = Utils.classForName(name) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala similarity index 79% rename from mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala rename to mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala index 88c85a9425e78..cca7399b4b9c5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala @@ -17,31 +17,29 @@ package org.apache.spark.ml.feature -import scala.reflect.ClassTag - import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.serializer.KryoSerializer -class InstanceSuit extends SparkFunSuite{ +class InstanceSuite extends SparkFunSuite{ test("Kryo class register") { val conf = new SparkConf(false) conf.set("spark.kryo.registrationRequired", "true") - val ser = new KryoSerializer(conf) - val serInstance = new KryoSerializer(conf).newInstance() - - def check[T: ClassTag](t: T) { - assert(serInstance.deserialize[T](serInstance.serialize(t)) === t) - } + val ser = new KryoSerializer(conf).newInstance() val instance1 = Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)) val instance2 = Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse) + Seq(instance1, instance2).foreach { i => + val i2 = ser.deserialize[Instance](ser.serialize(i)) + assert(i === i2) + } + val oInstance1 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)) val oInstance2 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse) - check(instance1) - check(instance2) - check(oInstance1) - check(oInstance2) + Seq(oInstance1, oInstance2).foreach { o => + val o2 = ser.deserialize[OffsetInstance](ser.serialize(o)) + assert(o === o2) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala new file mode 100644 index 0000000000000..05c7a58ee5ffd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.serializer.KryoSerializer + +class LabeledPointSuite extends SparkFunSuite { + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val labeled1 = LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0))) + val labeled2 = LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0))) + + Seq(labeled1, labeled2).foreach { l => + val l2 = ser.deserialize[LabeledPoint](ser.serialize(l)) + assert(l === l2) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala new file mode 100644 index 0000000000000..f41abe48f2c58 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoSerializer + +class TreePointSuite extends SparkFunSuite { + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val point = new TreePoint(1.0, Array(1, 2, 3)) + val point2 = ser.deserialize[TreePoint](ser.serialize(point)) + assert(point.label === point2.label) + assert(point.binnedFeatures === point2.binnedFeatures) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 48bd41dc3e3bf..00d7e2f2d3864 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -311,6 +312,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1)) } + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val vec1 = new VectorWithNorm(Vectors.dense(Array(1.0, 2.0))) + val vec2 = new VectorWithNorm(Vectors.sparse(10, Array(5, 8), Array(1.0, 2.0))) + + Seq(vec1, vec2).foreach { v => + val v2 = ser.deserialize[VectorWithNorm](ser.serialize(v)) + assert(v2.norm === v.norm) + assert(v2.vector === v.vector) + } + } } object KMeansSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index 252a068dcd72f..c1449ece740d4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.regression -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.serializer.KryoSerializer class LabeledPointSuite extends SparkFunSuite { @@ -53,4 +54,19 @@ class LabeledPointSuite extends SparkFunSuite { assert(p1 === LabeledPoint.fromML(p2)) } } + + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val labeled1 = LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0))) + val labeled2 = LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0))) + + Seq(labeled1, labeled2).foreach { l => + val l2 = ser.deserialize[LabeledPoint](ser.serialize(l)) + assert(l === l2) + } + } } From 
22e1849bcfb3ef988f4f9a5c2783bfc7ec001694 Mon Sep 17 00:00:00 2001 From: Anirudh Ramanathan Date: Thu, 21 Dec 2017 21:03:10 -0800 Subject: [PATCH 484/712] [SPARK-22866][K8S] Fix path issue in Kubernetes dockerfile ## What changes were proposed in this pull request? The path was recently changed in https://github.com/apache/spark/pull/19946, but the dockerfile was not updated. This is a trivial 1 line fix. ## How was this patch tested? `./sbin/build-push-docker-images.sh -r spark-repo -t latest build` cc/ vanzin mridulm rxin jiangxb1987 liyinan926 Author: Anirudh Ramanathan Author: foxish Closes #20051 from foxish/patch-1. --- .../kubernetes/docker/src/main/dockerfiles/driver/Dockerfile | 2 +- .../docker/src/main/dockerfiles/executor/Dockerfile | 2 +- .../docker/src/main/dockerfiles/spark-base/Dockerfile | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index d16349559466d..9b682f8673c69 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f dockerfiles/spark-base/Dockerfile . +# docker build -t spark-driver:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . COPY examples /opt/spark/examples diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 0e38169b8efdc..168cd4cb6c57a 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f dockerfiles/spark-base/Dockerfile . +# docker build -t spark-executor:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . COPY examples /opt/spark/examples diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 20316c9c5098a..222e777db3a82 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -22,7 +22,7 @@ FROM openjdk:8-alpine # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-base:latest -f dockerfiles/spark-base/Dockerfile . +# docker build -t spark-base:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . 
RUN set -ex && \ apk upgrade --no-cache && \ @@ -38,7 +38,7 @@ COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY dockerfiles/spark-base/entrypoint.sh /opt/ +COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark From 8df1da396f64bb7fe76d73cd01498fdf3b8ed964 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 21 Dec 2017 21:38:16 -0800 Subject: [PATCH 485/712] [SPARK-22862] Docs on lazy elimination of columns missing from an encoder This behavior has confused some users, so lets clarify it. Author: Michael Armbrust Closes #20048 from marmbrus/datasetAsDocs. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ef00562672a7e..209b800fdc6f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -401,6 +401,10 @@ class Dataset[T] private[sql]( * If the schema of the Dataset does not match the desired `U` type, you can use `select` * along with `alias` or `as` to rearrange or rename as required. * + * Note that `as[]` only changes the view of the data that is passed into typed operations, + * such as `map()`, and does not eagerly project away any columns that are not present in + * the specified class. + * * @group basic * @since 1.6.0 */ From 13190a4f60c081a68812df6df1d8262779cd6fcb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 22 Dec 2017 20:09:51 +0900 Subject: [PATCH 486/712] [SPARK-22874][PYSPARK][SQL] Modify checking pandas version to use LooseVersion. ## What changes were proposed in this pull request? Currently we check pandas version by capturing if `ImportError` for the specific imports is raised or not but we can compare `LooseVersion` of the version strings as the same as we're checking pyarrow version. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20054 from ueshin/issues/SPARK-22874. --- python/pyspark/sql/dataframe.py | 4 ++-- python/pyspark/sql/session.py | 15 +++++++-------- python/pyspark/sql/tests.py | 7 ++++--- python/pyspark/sql/types.py | 33 +++++++++++++-------------------- python/pyspark/sql/udf.py | 4 ++-- python/pyspark/sql/utils.py | 11 ++++++++++- 6 files changed, 38 insertions(+), 36 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 440684d3edfa6..95eca76fa9888 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1906,9 +1906,9 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_localize_timestamps - from pyspark.sql.utils import _require_minimum_pyarrow_version + from pyspark.sql.utils import require_minimum_pyarrow_version import pyarrow - _require_minimum_pyarrow_version() + require_minimum_pyarrow_version() tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 86db16eca7889..6e5eec48e8aca 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -493,15 +493,14 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): data types will be used to coerce the data in Pandas to Arrow conversion. 
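Editor's note: going back to the `Dataset.as` documentation added above (SPARK-22862), here is a small self-contained illustration of the behaviour being documented; the case class, column names, and data are illustrative, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

// Illustration of the `as[]` note: converting to a typed view does not drop columns
// that are missing from the target class; an explicit select or a typed operation
// such as map() is what actually projects them away.
case class NameAge(name: String, age: Long)

object AsViewExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("as-view").getOrCreate()
    import spark.implicits._

    val df = Seq(("alice", 30L, "extra"), ("bob", 25L, "extra"))
      .toDF("name", "age", "note")

    val ds = df.as[NameAge]
    // The untyped view still carries all three columns; nothing was projected away yet.
    assert(ds.columns.toSeq == Seq("name", "age", "note"))
    // A typed operation such as map() only sees the fields of NameAge.
    ds.map(p => s"${p.name}:${p.age}").show()

    spark.stop()
  }
}
```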
""" from pyspark.serializers import ArrowSerializer, _create_batch - from pyspark.sql.types import from_arrow_schema, to_arrow_type, \ - _old_pandas_exception_message, TimestampType - from pyspark.sql.utils import _require_minimum_pyarrow_version - try: - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from pyspark.sql.utils import require_minimum_pandas_version, \ + require_minimum_pyarrow_version + + require_minimum_pandas_version() + require_minimum_pyarrow_version() - _require_minimum_pyarrow_version() + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6fdfda1cc831b..b977160af566d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -53,7 +53,8 @@ try: import pandas try: - import pandas.api + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() _have_pandas = True except: _have_old_pandas = True @@ -2600,7 +2601,7 @@ def test_to_pandas(self): @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") def test_to_pandas_old(self): with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self._to_pandas() @unittest.skipIf(not _have_pandas, "Pandas not installed") @@ -2643,7 +2644,7 @@ def test_create_dataframe_from_old_pandas(self): pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], "d": [pd.Timestamp.now().date()]}) with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self.spark.createDataFrame(pdf) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 46d9a417414b5..063264a89379c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1678,13 +1678,6 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _old_pandas_exception_message(e): - """ Create an error message for importing old Pandas. - """ - msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process" - return "%s\n%s" % (_exception_message(e), msg) - - def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1693,10 +1686,10 @@ def _check_dataframe_localize_timestamps(pdf, timezone): :param timezone: the timezone to convert. if None then use local timezone :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive """ - try: - from pandas.api.types import is_datetime64tz_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype tz = timezone or 'tzlocal()' for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? @@ -1714,10 +1707,10 @@ def _check_series_convert_timestamps_internal(s, timezone): :param timezone: the timezone to convert. 
if None then use local timezone :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone """ - try: - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): tz = timezone or 'tzlocal()' @@ -1737,11 +1730,11 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): :param to_timezone: the timezone to convert to. if None then use local timezone :return pandas.Series where if it is a timestamp, has been converted to tz-naive """ - try: - import pandas as pd - from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + import pandas as pd + from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype from_tz = from_timezone or 'tzlocal()' to_tz = to_timezone or 'tzlocal()' # TODO: handle nested timestamps, such as ArrayType(TimestampType())? diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 50c87ba1ac882..123138117fdc3 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -37,9 +37,9 @@ def _create_udf(f, returnType, evalType): if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: import inspect - from pyspark.sql.utils import _require_minimum_pyarrow_version + from pyspark.sql.utils import require_minimum_pyarrow_version - _require_minimum_pyarrow_version() + require_minimum_pyarrow_version() argspec = inspect.getargspec(f) if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index cc7dabb64b3ec..fb7d42a35d8f4 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -112,7 +112,16 @@ def toJArray(gateway, jtype, arr): return jarr -def _require_minimum_pyarrow_version(): +def require_minimum_pandas_version(): + """ Raise ImportError if minimum version of Pandas is not installed + """ + from distutils.version import LooseVersion + import pandas + if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): + raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process") + + +def require_minimum_pyarrow_version(): """ Raise ImportError if minimum version of pyarrow is not installed """ from distutils.version import LooseVersion From d23dc5b8ef6c6aee0a31a304eefeb6ddb1c26c0f Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Fri, 22 Dec 2017 14:05:57 -0800 Subject: [PATCH 487/712] [SPARK-22346][ML] VectorSizeHint Transformer for using VectorAssembler in StructuredSteaming ## What changes were proposed in this pull request? A new VectorSizeHint transformer was added. This transformer is meant to be used as a pipeline stage ahead of VectorAssembler, on vector columns, so that VectorAssembler can join vectors in a streaming context where the size of the input vectors is otherwise not known. ## How was this patch tested? Unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. 
Author: Bago Amirbekian Closes #19746 from MrBago/vector-size-hint. --- .../spark/ml/feature/VectorSizeHint.scala | 195 ++++++++++++++++++ .../ml/feature/VectorSizeHintSuite.scala | 189 +++++++++++++++++ 2 files changed, 384 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala new file mode 100644 index 0000000000000..1fe3cfc74c76d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.sql.{Column, DataFrame, Dataset} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * A feature transformer that adds size information to the metadata of a vector column. + * VectorAssembler needs size information for its input columns and cannot be used on streaming + * dataframes without this metadata. + * + */ +@Experimental +@Since("2.3.0") +class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String) + extends Transformer with HasInputCol with HasHandleInvalid with DefaultParamsWritable { + + @Since("2.3.0") + def this() = this(Identifiable.randomUID("vectSizeHint")) + + /** + * The size of Vectors in `inputCol`. + * @group param + */ + @Since("2.3.0") + val size: IntParam = new IntParam( + this, + "size", + "Size of vectors in column.", + {s: Int => s >= 0}) + + /** group getParam */ + @Since("2.3.0") + def getSize: Int = getOrDefault(size) + + /** @group setParam */ + @Since("2.3.0") + def setSize(value: Int): this.type = set(size, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** + * Param for how to handle invalid entries. Invalid vectors include nulls and vectors with the + * wrong size. 
The options are `skip` (filter out rows with invalid vectors), `error` (throw an + * error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default. + * + * Note: Users should take care when setting this param to `optimistic`. The use of the + * `optimistic` option will prevent the transformer from validating the sizes of vectors in + * `inputCol`. A mismatch between the metadata of a column and its contents could result in + * unexpected behaviour or errors when using that column. + * + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String]( + this, + "handleInvalid", + "How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with " + + "the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` " + + "(throw an error) and `optimistic` (do not check the vector size, and keep all rows). " + + "`error` by default.", + ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids)) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, VectorSizeHint.ERROR_INVALID) + + @Since("2.3.0") + override def transform(dataset: Dataset[_]): DataFrame = { + val localInputCol = getInputCol + val localSize = getSize + val localHandleInvalid = getHandleInvalid + + val group = AttributeGroup.fromStructField(dataset.schema(localInputCol)) + val newGroup = validateSchemaAndSize(dataset.schema, group) + if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) { + dataset.toDF() + } else { + val newCol: Column = localHandleInvalid match { + case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol) + case VectorSizeHint.ERROR_INVALID => + val checkVectorSizeUDF = udf { vector: Vector => + if (vector == null) { + throw new SparkException(s"Got null vector in VectorSizeHint, set `handleInvalid` " + + s"to 'skip' to filter invalid rows.") + } + if (vector.size != localSize) { + throw new SparkException(s"VectorSizeHint Expecting a vector of size $localSize but" + + s" got ${vector.size}") + } + vector + }.asNondeterministic() + checkVectorSizeUDF(col(localInputCol)) + case VectorSizeHint.SKIP_INVALID => + val checkVectorSizeUDF = udf { vector: Vector => + if (vector != null && vector.size == localSize) { + vector + } else { + null + } + } + checkVectorSizeUDF(col(localInputCol)) + } + + val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata())) + if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) { + res.na.drop(Array(localInputCol)) + } else { + res + } + } + } + + /** + * Checks that schema can be updated with new size and returns a new attribute group with + * updated size. + */ + private def validateSchemaAndSize(schema: StructType, group: AttributeGroup): AttributeGroup = { + // This will throw a NoSuchElementException if params are not set. + val localSize = getSize + val localInputCol = getInputCol + + val inputColType = schema(getInputCol).dataType + require( + inputColType.isInstanceOf[VectorUDT], + s"Input column, $getInputCol must be of Vector type, got $inputColType" + ) + group.size match { + case `localSize` => group + case -1 => new AttributeGroup(localInputCol, localSize) + case _ => + val msg = s"Trying to set size of vectors in `$localInputCol` to $localSize but size " + + s"already set to ${group.size}." 
+ throw new IllegalArgumentException(msg) + } + } + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val fieldIndex = schema.fieldIndex(getInputCol) + val fields = schema.fields.clone() + val inputField = fields(fieldIndex) + val group = AttributeGroup.fromStructField(inputField) + val newGroup = validateSchemaAndSize(schema, group) + fields(fieldIndex) = inputField.copy(metadata = newGroup.toMetadata()) + StructType(fields) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): this.type = defaultCopy(extra) +} + +/** :: Experimental :: */ +@Experimental +@Since("2.3.0") +object VectorSizeHint extends DefaultParamsReadable[VectorSizeHint] { + + private[feature] val OPTIMISTIC_INVALID = "optimistic" + private[feature] val ERROR_INVALID = "error" + private[feature] val SKIP_INVALID = "skip" + private[feature] val supportedHandleInvalids: Array[String] = + Array(OPTIMISTIC_INVALID, ERROR_INVALID, SKIP_INVALID) + + @Since("2.3.0") + override def load(path: String): VectorSizeHint = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala new file mode 100644 index 0000000000000..f6c9a76599fae --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
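Editor's note: before the test suite below, a minimal batch-mode usage sketch of the transformer just defined, showing the size metadata it writes and how `VectorAssembler` picks it up; the column names, values, and object name are illustrative.

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Declare the size of a vector column up front so that VectorAssembler downstream
// does not need to inspect the data to learn it (the motivation is streaming input,
// where the data cannot be scanned ahead of time).
object VectorSizeHintExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("size-hint").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (Vectors.dense(1.0, 2.0, 3.0), 0.5),
      (Vectors.dense(4.0, 5.0, 6.0), 1.5)
    ).toDF("features", "weight")

    val sizeHint = new VectorSizeHint()
      .setInputCol("features")
      .setSize(3)
      .setHandleInvalid("error")   // "skip" drops bad rows, "optimistic" skips the check

    val hinted = sizeHint.transform(df)
    // The size is now recorded in the column metadata, where VectorAssembler can read it.
    assert(AttributeGroup.fromStructField(hinted.schema("features")).size == 3)

    val assembled = new VectorAssembler()
      .setInputCols(Array("features", "weight"))
      .setOutputCol("assembled")
      .transform(hinted)
    assembled.show(truncate = false)

    spark.stop()
  }
}
```

In a streaming job the same two stages would simply be placed in a `Pipeline` ahead of the assembler, which is the scenario exercised by `VectorSizeHintStreamingSuite` below.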
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.StreamTest + +class VectorSizeHintSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("Test Param Validators") { + intercept[IllegalArgumentException] (new VectorSizeHint().setHandleInvalid("invalidValue")) + intercept[IllegalArgumentException] (new VectorSizeHint().setSize(-3)) + } + + test("Required params must be set before transform.") { + val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue") + + val noSizeTransformer = new VectorSizeHint().setInputCol("vector") + intercept[NoSuchElementException] (noSizeTransformer.transform(data)) + intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema)) + + val noInputColTransformer = new VectorSizeHint().setSize(2) + intercept[NoSuchElementException] (noInputColTransformer.transform(data)) + intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema)) + } + + test("Adding size to column of vectors.") { + + val size = 3 + val vectorColName = "vector" + val denseVector = Vectors.dense(1, 2, 3) + val sparseVector = Vectors.sparse(size, Array(), Array()) + + val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply) + val dataFrame = data.toDF(vectorColName) + assert( + AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1, + s"This test requires that column '$vectorColName' not have size metadata.") + + for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) { + val transformer = new VectorSizeHint() + .setInputCol(vectorColName) + .setSize(size) + .setHandleInvalid(handleInvalid) + val withSize = transformer.transform(dataFrame) + assert( + AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size, + "Transformer did not add expected size data.") + val numRows = withSize.collect().length + assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + } + } + + test("Size hint preserves attributes.") { + + val size = 3 + val vectorColName = "vector" + val data = Seq((1, 2, 3), (2, 3, 3)) + val dataFrame = data.toDF("x", "y", "z") + + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z")) + .setOutputCol(vectorColName) + val dataFrameWithMetadata = assembler.transform(dataFrame) + val group = AttributeGroup.fromStructField(dataFrameWithMetadata.schema(vectorColName)) + + for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) { + val transformer = new VectorSizeHint() + .setInputCol(vectorColName) + .setSize(size) + .setHandleInvalid(handleInvalid) + val withSize = transformer.transform(dataFrameWithMetadata) + + val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName)) + assert(newGroup.size === size, "Column has incorrect size metadata.") + assert( + newGroup.attributes.get === group.attributes.get, + "VectorSizeHint did not preserve attributes.") + withSize.collect + } + } + + test("Size mismatch between current and target size raises an error.") { + val size = 4 + val vectorColName = "vector" + val data = Seq((1, 2, 3), (2, 3, 3)) + val 
dataFrame = data.toDF("x", "y", "z") + + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z")) + .setOutputCol(vectorColName) + val dataFrameWithMetadata = assembler.transform(dataFrame) + + for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) { + val transformer = new VectorSizeHint() + .setInputCol(vectorColName) + .setSize(size) + .setHandleInvalid(handleInvalid) + intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata)) + } + } + + test("Handle invalid does the right thing.") { + + val vector = Vectors.dense(1, 2, 3) + val short = Vectors.dense(2) + val dataWithNull = Seq(vector, null).map(Tuple1.apply).toDF("vector") + val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector") + + val sizeHint = new VectorSizeHint() + .setInputCol("vector") + .setHandleInvalid("error") + .setSize(3) + + intercept[SparkException](sizeHint.transform(dataWithNull).collect()) + intercept[SparkException](sizeHint.transform(dataWithShort).collect()) + + sizeHint.setHandleInvalid("skip") + assert(sizeHint.transform(dataWithNull).count() === 1) + assert(sizeHint.transform(dataWithShort).count() === 1) + + sizeHint.setHandleInvalid("optimistic") + assert(sizeHint.transform(dataWithNull).count() === 2) + assert(sizeHint.transform(dataWithShort).count() === 2) + } + + test("read/write") { + val sizeHint = new VectorSizeHint() + .setInputCol("myInputCol") + .setSize(11) + .setHandleInvalid("skip") + testDefaultReadWrite(sizeHint) + } +} + +class VectorSizeHintStreamingSuite extends StreamTest { + + import testImplicits._ + + test("Test assemble vectors with size hint in streaming.") { + val a = Vectors.dense(0, 1, 2) + val b = Vectors.sparse(4, Array(0, 3), Array(3, 6)) + + val stream = MemoryStream[(Vector, Vector)] + val streamingDF = stream.toDS.toDF("a", "b") + val sizeHintA = new VectorSizeHint() + .setSize(3) + .setInputCol("a") + val sizeHintB = new VectorSizeHint() + .setSize(4) + .setInputCol("b") + val vectorAssembler = new VectorAssembler() + .setInputCols(Array("a", "b")) + .setOutputCol("assembled") + val pipeline = new Pipeline().setStages(Array(sizeHintA, sizeHintB, vectorAssembler)) + val output = pipeline.fit(streamingDF).transform(streamingDF).select("assembled") + + val expected = Vectors.dense(0, 1, 2, 3, 0, 0, 6) + + testStream (output) ( + AddData(stream, (a, b), (a, b)), + CheckAnswer(Tuple1(expected), Tuple1(expected)) + ) + } +} From 8941a4abcada873c26af924e129173dc33d66d71 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 22 Dec 2017 23:05:03 -0800 Subject: [PATCH 488/712] [SPARK-22789] Map-only continuous processing execution ## What changes were proposed in this pull request? Basic continuous execution, supporting map/flatMap/filter, with commits and advancement through RPC. ## How was this patch tested? new unit-ish tests (exercising execution end to end) Author: Jose Torres Closes #19984 from jose-torres/continuous-impl. 
--- project/MimaExcludes.scala | 5 + .../UnsupportedOperationChecker.scala | 25 +- .../apache/spark/sql/internal/SQLConf.scala | 21 ++ .../sources/v2/reader/ContinuousReader.java | 6 + .../sources/v2/reader/MicroBatchReader.java | 6 + .../apache/spark/sql/streaming/Trigger.java | 54 +++ .../spark/sql/execution/SparkStrategies.scala | 7 + .../datasources/v2/DataSourceV2ScanExec.scala | 20 +- .../datasources/v2/WriteToDataSourceV2.scala | 60 ++- .../streaming/BaseStreamingSource.java | 8 - .../execution/streaming/HDFSMetadataLog.scala | 14 + .../streaming/MicroBatchExecution.scala | 44 ++- .../sql/execution/streaming/OffsetSeq.scala | 2 +- .../streaming/ProgressReporter.scala | 10 +- .../streaming/RateSourceProvider.scala | 9 +- .../streaming/RateStreamOffset.scala | 5 +- .../spark/sql/execution/streaming/Sink.scala | 2 +- .../sql/execution/streaming/Source.scala | 2 +- .../execution/streaming/StreamExecution.scala | 20 +- .../execution/streaming/StreamProgress.scala | 19 +- .../streaming/StreamingRelation.scala | 47 +++ .../ContinuousDataSourceRDDIter.scala | 217 +++++++++++ .../continuous/ContinuousExecution.scala | 349 ++++++++++++++++++ .../ContinuousRateStreamSource.scala | 11 +- .../continuous/ContinuousTrigger.scala | 70 ++++ .../continuous/EpochCoordinator.scala | 191 ++++++++++ .../sources/RateStreamSourceV2.scala | 19 +- .../streaming/sources/memoryV2.scala | 13 + .../sql/streaming/DataStreamReader.scala | 38 +- .../sql/streaming/DataStreamWriter.scala | 19 +- .../sql/streaming/StreamingQueryManager.scala | 45 ++- .../org/apache/spark/sql/QueryTest.scala | 56 ++- .../streaming/RateSourceV2Suite.scala | 30 +- .../spark/sql/streaming/StreamSuite.scala | 17 +- .../spark/sql/streaming/StreamTest.scala | 55 ++- .../continuous/ContinuousSuite.scala | 316 ++++++++++++++++ 36 files changed, 1682 insertions(+), 150 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9902fedb65d59..81584af6813ea 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // SPARK-22789: Map-only continuous processing execution + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$9"), + // SPARK-22372: Make cluster submission use SparkApplication. 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getSecretKeyFromUserCredentials"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.isYarnMode"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 04502d04d9509..b55043c270644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, MonotonicallyIncreasingID} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ @@ -339,6 +339,29 @@ object UnsupportedOperationChecker { } } + def checkForContinuous(plan: LogicalPlan, outputMode: OutputMode): Unit = { + checkForStreaming(plan, outputMode) + + plan.foreachUp { implicit subPlan => + subPlan match { + case (_: Project | _: Filter | _: MapElements | _: MapPartitions | + _: DeserializeToObject | _: SerializeFromObject) => + case node if node.nodeName == "StreamingRelationV2" => + case node => + throwError(s"Continuous processing does not support ${node.nodeName} operations.") + } + + subPlan.expressions.foreach { e => + if (e.collectLeaves().exists { + case (_: CurrentTimestamp | _: CurrentDate) => true + case _ => false + }) { + throwError(s"Continuous processing does not support current time operations.") + } + } + } + } + private def throwErrorIf( condition: Boolean, msg: String)(implicit operator: LogicalPlan): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bdc8d92e84079..84fe4bb711a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1044,6 +1044,22 @@ object SQLConf { "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = + buildConf("spark.sql.streaming.continuous.executorQueueSize") + .internal() + .doc("The size (measured in number of rows) of the queue used in continuous execution to" + + " buffer the results of a ContinuousDataReader.") + .intConf + .createWithDefault(1024) + + val CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS = + buildConf("spark.sql.streaming.continuous.executorPollIntervalMs") + .internal() + .doc("The interval at which continuous execution readers will poll to check whether" + + " the epoch has advanced on the driver.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(100) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1357,6 +1373,11 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = 
getConf(REPLACE_EXCEPT_WITH_FILTER) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) + + def continuousStreamingExecutorPollIntervalMs: Long = + getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java index 1baf82c2df762..34141d6cd85fd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java @@ -65,4 +65,10 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade default boolean needsReconfiguration() { return false; } + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java index 438e3f55b7bcf..bd15c07d87f6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java @@ -61,4 +61,10 @@ public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSourc * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader */ Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index d31790a285687..33ae9a9e87668 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -22,6 +22,7 @@ import scala.concurrent.duration.Duration; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; /** @@ -95,4 +96,57 @@ public static Trigger ProcessingTime(String interval) { public static Trigger Once() { return OneTimeTrigger$.MODULE$; } + + /** + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + * + * @since 2.3.0 + */ + public static Trigger Continuous(long intervalMs) { + return ContinuousTrigger.apply(intervalMs); + } + + /** + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. 
+ * + * {{{ + * import java.util.concurrent.TimeUnit + * df.writeStream.trigger(Trigger.Continuous(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.3.0 + */ + public static Trigger Continuous(long interval, TimeUnit timeUnit) { + return ContinuousTrigger.create(interval, timeUnit); + } + + /** + * (Scala-friendly) + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + * + * {{{ + * import scala.concurrent.duration._ + * df.writeStream.trigger(Trigger.Continuous(10.seconds)) + * }}} + * @since 2.3.0 + */ + public static Trigger Continuous(Duration interval) { + return ContinuousTrigger.apply(interval); + } + + /** + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + * + * {{{ + * df.writeStream.trigger(Trigger.Continuous("10 seconds")) + * }}} + * @since 2.3.0 + */ + public static Trigger Continuous(String interval) { + return ContinuousTrigger.apply(interval); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9e713cd7bbe2b..8c6c324d456c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -31,8 +31,10 @@ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.types.StructType /** * Converts a logical plan into zero or more SparkPlans. 
This API is exposed for experimenting @@ -374,6 +376,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { StreamingRelationExec(s.sourceName, s.output) :: Nil case s: StreamingExecutionRelation => StreamingRelationExec(s.toString, s.output) :: Nil + case s: StreamingRelationV2 => + StreamingRelationExec(s.sourceName, s.output) :: Nil case _ => Nil } } @@ -404,6 +408,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil + case MemoryPlanV2(sink, output) => + val encoder = RowEncoder(StructType.fromAttributes(output)) + LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil case logical.Distinct(child) => throw new IllegalStateException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3f243dc44e043..e4fca1b10dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.StructType @@ -52,10 +54,20 @@ case class DataSourceV2ScanExec( }.asJava } - val inputRDD = new DataSourceRDD(sparkContext, readTasks) - .asInstanceOf[RDD[InternalRow]] + val inputRDD = reader match { + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetReaderPartitions(readTasks.size())) + + new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + + case _ => + new DataSourceRDD(sparkContext, readTasks) + } + val numOutputRows = longMetric("numOutputRows") - inputRDD.map { r => + inputRDD.asInstanceOf[RDD[InternalRow]].map { r => numOutputRows += 1 r } @@ -73,7 +85,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) } } -class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) +class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) extends DataReader[UnsafeRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index b72d15ed15aed..1862da8892cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.{SparkEnv, SparkException, TaskContext} import 
org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -58,10 +60,22 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) s"The input RDD has ${messages.length} partitions.") try { + val runTask = writer match { + case w: ContinuousWriter => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.runContinuous(writeTask, context, iter) + case _ => + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter) + } + sparkContext.runJob( rdd, - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter), + runTask, rdd.partitions.indices, (index, message: WriterCommitMessage) => messages(index) = message ) @@ -70,6 +84,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) writer.commit(messages) logInfo(s"Data source writer $writer committed.") } catch { + case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => + // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -109,6 +125,44 @@ object DataWritingSparkTask extends Logging { logError(s"Writer for partition ${context.partitionId()} aborted.") }) } + + def runContinuous( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): WriterCommitMessage = { + val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), + SparkEnv.get) + val currentMsg: WriterCommitMessage = null + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + do { + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + iter.foreach(dataWriter.write) + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. 
+ } + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Writer for partition ${context.partitionId()} is aborting.") + dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } while (!context.isInterrupted()) + + currentMsg + } } class InternalRowDataWriterFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java index 3a02cbfe7afe3..c44b8af2552f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming; -import org.apache.spark.sql.sources.v2.reader.Offset; - /** * The shared interface between V1 streaming sources and V2 streaming readers. * @@ -26,12 +24,6 @@ * directly, and will be removed in future versions. */ public interface BaseStreamingSource { - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); - /** Stop this source and free any resources it has allocated. */ void stop(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 43cf0ef1da8ca..6e8154d58d4c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -266,6 +266,20 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } + /** + * Removes all log entries later than thresholdBatchId (exclusive). 
+ */ + def purgeAfter(thresholdBatchId: Long): Unit = { + val batchIds = fileManager.list(metadataPath, batchFilesFilter) + .map(f => pathToBatchId(f.getPath)) + + for (batchId <- batchIds if batchId > thresholdBatchId) { + val path = batchIdToPath(batchId) + fileManager.delete(path) + logTrace(s"Removed metadata log file: $path") + } + } + private def createFileManager(): FileManager = { val hadoopConf = sparkSession.sessionState.newHadoopConf() try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 4a3de8bae4bc9..20f9810faa5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.MicroBatchReadSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -41,6 +42,8 @@ class MicroBatchExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { + @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -53,6 +56,7 @@ class MicroBatchExecution( s"but the current thread was ${Thread.currentThread}") var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() + val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSource, _, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -64,6 +68,17 @@ class MicroBatchExecution( // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) + case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource) + if !v2DataSource.isInstanceOf[MicroBatchReadSupport] => + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val source = v1DataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. + StreamingExecutionRelation(source, output)(sparkSession) + }) } sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } uniqueSources = sources.distinct @@ -170,12 +185,14 @@ class MicroBatchExecution( * Make a call to getBatch using the offsets from previous batch. * because certain sources (e.g., KafkaSource) assume on restart the last * batch will be executed before getOffset is called again. 
*/ - availableOffsets.foreach { ao: (Source, Offset) => - val (source, end) = ao - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } + availableOffsets.foreach { + case (source: Source, end: Offset) => + if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { + val start = committedOffsets.get(source) + source.getBatch(start, end) + } + case nonV1Tuple => + throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple") } currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets @@ -219,11 +236,12 @@ class MicroBatchExecution( val hasNewData = { awaitProgressLock.lock() try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } + val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) + } }.toMap availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) @@ -298,7 +316,7 @@ class MicroBatchExecution( val prevBatchOff = offsetLog.get(currentBatchId - 1) if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { - case (src, off) => src.commit(off) + case (src: Source, off) => src.commit(off) } } else { throw new IllegalStateException(s"batch $currentBatchId doesn't exist") @@ -331,7 +349,7 @@ class MicroBatchExecution( // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { - case (source, available) + case (source: Source, available) if committedOffsets.get(source).map(_ != available).getOrElse(true) => val current = committedOffsets.get(source) val batch = source.getBatch(current, available) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 4e0a468b962a2..a1b63a6de3823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -38,7 +38,7 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * This method is typically used to associate a serialized offset with actual sources (which * cannot be serialized). 
*/ - def toStreamProgress(sources: Seq[Source]): StreamProgress = { + def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { assert(sources.size == offsets.size) new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index b1c3a8ab235ab..1c9043613cb69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Clock trait ProgressReporter extends Logging { case class ExecutionStats( - inputRows: Map[Source, Long], + inputRows: Map[BaseStreamingSource, Long], stateOperators: Seq[StateOperatorProgress], eventTimeStats: Map[String, String]) @@ -53,11 +53,11 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[Source, DataFrame] + protected def newData: Map[BaseStreamingSource, DataFrame] protected def availableOffsets: StreamProgress protected def committedOffsets: StreamProgress - protected def sources: Seq[Source] - protected def sink: Sink + protected def sources: Seq[BaseStreamingSource] + protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata protected def currentBatchId: Long protected def sparkSession: SparkSession @@ -230,7 +230,7 @@ trait ProgressReporter extends Logging { } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[Source, Long] = + val numInputRows: Map[BaseStreamingSource, Long] = if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 41761324cf6ac..3f85fa913f28c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -52,7 +52,7 @@ import org.apache.spark.util.{ManualClock, SystemClock} * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. 
*/ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with MicroBatchReadSupport with ContinuousReadSupport{ + with DataSourceV2 with ContinuousReadSupport { override def sourceSchema( sqlContext: SQLContext, @@ -107,13 +107,6 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister ) } - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = { - new RateStreamV2Reader(options) - } - override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 726d8574af52b..65d6d18936167 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -22,8 +22,11 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 -case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, (Long, Long)]) +case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) extends v2.reader.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } + + +case class ValueRunTimeMsPair(value: Long, runTimeMs: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala index d10cd3044ecdf..34bc085d920c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.DataFrame * exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same * batch. */ -trait Sink { +trait Sink extends BaseStreamingSink { /** * Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 311942f6dbd84..dbbd59e06909c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. 
*/ -trait Source { +trait Source extends BaseStreamingSource { /** Returns the schema of the data from this source */ def schema: StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 129995dcf3607..3e76bf7b7ca8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -44,6 +44,7 @@ trait State case object INITIALIZING extends State case object ACTIVE extends State case object TERMINATED extends State +case object RECONFIGURING extends State /** * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. @@ -59,7 +60,7 @@ abstract class StreamExecution( override val name: String, private val checkpointRoot: String, analyzedPlan: LogicalPlan, - val sink: Sink, + val sink: BaseStreamingSink, val trigger: Trigger, val triggerClock: Clock, val outputMode: OutputMode, @@ -147,30 +148,25 @@ abstract class StreamExecution( * Pretty identified string of printing in logs. Format is * If name is set "queryName [id = xyz, runId = abc]" else "[id = xyz, runId = abc]" */ - private val prettyIdString = + protected val prettyIdString = Option(name).map(_ + " ").getOrElse("") + s"[id = $id, runId = $runId]" - /** - * All stream sources present in the query plan. This will be set when generating logical plan. - */ - @volatile protected var sources: Seq[Source] = Seq.empty - /** * A list of unique sources in the query plan. This will be set when generating logical plan. */ - @volatile protected var uniqueSources: Seq[Source] = Seq.empty + @volatile protected var uniqueSources: Seq[BaseStreamingSource] = Seq.empty /** Defines the internal state of execution */ - private val state = new AtomicReference[State](INITIALIZING) + protected val state = new AtomicReference[State](INITIALIZING) @volatile var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[Source, DataFrame] = _ + protected var newData: Map[BaseStreamingSource, DataFrame] = _ @volatile - private var streamDeathCause: StreamingQueryException = null + protected var streamDeathCause: StreamingQueryException = null /* Get the call site in the caller thread; will pass this into the micro batch thread */ private val callSite = Utils.getCallSite() @@ -389,7 +385,7 @@ abstract class StreamExecution( } /** Stops all streaming sources safely. */ - private def stopSources(): Unit = { + protected def stopSources(): Unit = { uniqueSources.foreach { source => try { source.stop() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index a3f3662e6f4c9..8531070b1bc49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -23,25 +23,28 @@ import scala.collection.{immutable, GenTraversableOnce} * A helper class that looks like a Map[Source, Offset]. 
 */
 class StreamProgress(
-    val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset])
-  extends scala.collection.immutable.Map[Source, Offset] {
+    val baseMap: immutable.Map[BaseStreamingSource, Offset] =
+      new immutable.HashMap[BaseStreamingSource, Offset])
+  extends scala.collection.immutable.Map[BaseStreamingSource, Offset] {
 
-  def toOffsetSeq(source: Seq[Source], metadata: OffsetSeqMetadata): OffsetSeq = {
+  def toOffsetSeq(source: Seq[BaseStreamingSource], metadata: OffsetSeqMetadata): OffsetSeq = {
     OffsetSeq(source.map(get), Some(metadata))
   }
 
   override def toString: String = baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}")
 
-  override def +[B1 >: Offset](kv: (Source, B1)): Map[Source, B1] = baseMap + kv
+  override def +[B1 >: Offset](kv: (BaseStreamingSource, B1)): Map[BaseStreamingSource, B1] = {
+    baseMap + kv
+  }
 
-  override def get(key: Source): Option[Offset] = baseMap.get(key)
+  override def get(key: BaseStreamingSource): Option[Offset] = baseMap.get(key)
 
-  override def iterator: Iterator[(Source, Offset)] = baseMap.iterator
+  override def iterator: Iterator[(BaseStreamingSource, Offset)] = baseMap.iterator
 
-  override def -(key: Source): Map[Source, Offset] = baseMap - key
+  override def -(key: BaseStreamingSource): Map[BaseStreamingSource, Offset] = baseMap - key
 
-  def ++(updates: GenTraversableOnce[(Source, Offset)]): StreamProgress = {
+  def ++(updates: GenTraversableOnce[(BaseStreamingSource, Offset)]): StreamProgress = {
     new StreamProgress(baseMap ++ updates)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index 6b82c78ea653d..0ca2e7854d94b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.LeafExecNode
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2}
 
 object StreamingRelation {
   def apply(dataSource: DataSource): StreamingRelation = {
@@ -75,6 +76,52 @@ case class StreamingExecutionRelation(
   )
 }
 
+// We have to pack in the V1 data source as a shim, for the case when a source implements
+// continuous processing (which is always V2) but only has V1 microbatch support. We don't
+// know at read time whether the query is continuous or not, so we need to be able to
+// swap a V1 relation back in.
+/**
+ * Used to link a [[DataSourceV2]] into a streaming
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
+ * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]],
+ * and should be converted before passing to [[StreamExecution]].
+ */ +case class StreamingRelationV2( + dataSource: DataSourceV2, + sourceName: String, + extraOptions: Map[String, String], + output: Seq[Attribute], + v1DataSource: DataSource)(session: SparkSession) + extends LeafNode { + override def isStreaming: Boolean = true + override def toString: String = sourceName + + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) + ) +} + +/** + * Used to link a [[DataSourceV2]] into a continuous processing execution. + */ +case class ContinuousExecutionRelation( + source: ContinuousReadSupport, + extraOptions: Map[String, String], + output: Seq[Attribute])(session: SparkSession) + extends LeafNode { + + override def isStreaming: Boolean = true + override def toString: String = source.toString + + // There's no sensible value here. On the execution path, this relation will be + // swapped out with microbatches. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) + ) +} + /** * A dummy physical plan for [[StreamingRelation]] to support * [[org.apache.spark.sql.Dataset.explain]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala new file mode 100644 index 0000000000000..89fb2ace20917 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +import scala.collection.JavaConverters._ + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.streaming.ProcessingTime +import org.apache.spark.util.{SystemClock, ThreadUtils} + +class ContinuousDataSourceRDD( + sc: SparkContext, + sqlContext: SQLContext, + @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize + private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs + + override protected def getPartitions: Array[Partition] = { + readTasks.asScala.zipWithIndex.map { + case (readTask, index) => new DataSourceRDDPartition(index, readTask) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + + val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) + + // This queue contains two types of messages: + // * (null, null) representing an epoch boundary. + // * (row, off) containing a data row and its corresponding PartitionOffset. 
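+    // Epoch markers are enqueued by the epoch poll thread; data rows by the data reader thread.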
+ val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) + + val epochPollFailed = new AtomicBoolean(false) + val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--${runId}--${context.partitionId()}") + val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) + epochPollExecutor.scheduleWithFixedDelay( + epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + // Important sequencing - we must get start offset before the data reader thread begins + val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset + + val dataReaderFailed = new AtomicBoolean(false) + val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + reader.close() + dataReaderThread.interrupt() + epochPollExecutor.shutdown() + }) + + val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get) + new Iterator[UnsafeRow] { + private val POLL_TIMEOUT_MS = 1000 + + private var currentEntry: (UnsafeRow, PartitionOffset) = _ + private var currentOffset: PartitionOffset = startOffset + private var currentEpoch = + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def hasNext(): Boolean = { + while (currentEntry == null) { + if (context.isInterrupted() || context.isCompleted()) { + currentEntry = (null, null) + } + if (dataReaderFailed.get()) { + throw new SparkException("data read failed", dataReaderThread.failureReason) + } + if (epochPollFailed.get()) { + throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + + currentEntry match { + // epoch boundary marker + case (null, null) => + epochEndpoint.send(ReportPartitionOffset( + context.partitionId(), + currentEpoch, + currentOffset)) + currentEpoch += 1 + currentEntry = null + false + // real row + case (_, offset) => + currentOffset = offset + true + } + } + + override def next(): UnsafeRow = { + if (currentEntry == null) throw new NoSuchElementException("No current row was set") + val r = currentEntry._1 + currentEntry = null + r + } + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + } +} + +case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset + +class EpochPollRunnable( + queue: BlockingQueue[(UnsafeRow, PartitionOffset)], + context: TaskContext, + failedFlag: AtomicBoolean) + extends Thread with Logging { + private[continuous] var failureReason: Throwable = _ + + private val epochEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get) + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) + for (i <- currentEpoch to newEpoch - 1) { + queue.put((null, null)) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + failedFlag.set(true) + throw t + } + } +} + +class DataReaderThread( + reader: DataReader[UnsafeRow], + queue: BlockingQueue[(UnsafeRow, PartitionOffset)], + context: TaskContext, + failedFlag: AtomicBoolean) + extends Thread( + s"continuous-reader--${context.partitionId()}--" + + 
s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") { + private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) + try { + while (!context.isInterrupted && !context.isCompleted()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!context.isInterrupted && !context.isCompleted()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put((reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException if context.isInterrupted() => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + + case t: Throwable => + failureReason = t + failedFlag.set(true) + // Don't rethrow the exception in this thread. It's not needed, and the default Spark + // exception handler will kill the executor. + } + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader match { + case r: ContinuousDataReader[UnsafeRow] => r + case wrapped: RowToUnsafeDataReader => + wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala new file mode 100644 index 0000000000000..1c35b06bd4b85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, ContinuousWriteSupport, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.ContinuousWriter +import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{Clock, Utils} + +class ContinuousExecution( + sparkSession: SparkSession, + name: String, + checkpointRoot: String, + analyzedPlan: LogicalPlan, + sink: ContinuousWriteSupport, + trigger: Trigger, + triggerClock: Clock, + outputMode: OutputMode, + extraOptions: Map[String, String], + deleteCheckpointOnStop: Boolean) + extends StreamExecution( + sparkSession, name, checkpointRoot, analyzedPlan, sink, + trigger, triggerClock, outputMode, deleteCheckpointOnStop) { + + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + override protected def sources: Seq[BaseStreamingSource] = continuousSources + + override lazy val logicalPlan: LogicalPlan = { + assert(queryExecutionThread eq Thread.currentThread, + "logicalPlan must be initialized in StreamExecutionThread " + + s"but the current thread was ${Thread.currentThread}") + val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() + analyzedPlan.transform { + case r @ StreamingRelationV2( + source: ContinuousReadSupport, _, extraReaderOptions, output, _) => + toExecutionRelationMap.getOrElseUpdate(r, { + ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) + }) + case StreamingRelationV2(_, sourceName, _, _, _) => + throw new AnalysisException( + s"Data source $sourceName does not support continuous processing.") + } + } + + private val triggerExecutor = trigger match { + case ContinuousTrigger(t) => ProcessingTimeExecutor(ProcessingTime(t), triggerClock) + case _ => throw new IllegalStateException(s"Unsupported type of trigger: $trigger") + } + + override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { + do { + try { + runContinuous(sparkSessionForStream) + } catch { + case _: InterruptedException if state.get().equals(RECONFIGURING) => + // swallow exception and run again + state.set(ACTIVE) + } + } while (state.get() == ACTIVE) + } + + /** + * Populate the start offsets to start the execution at the current offsets stored in the sink + * (i.e. avoid reprocessing data that we have already processed). 
This function must be called + * before any processing occurs and will populate the following fields: + * - currentBatchId + * - committedOffsets + * The basic structure of this method is as follows: + * + * Identify (from the commit log) the latest epoch that has committed + * IF last epoch exists THEN + * Get end offsets for the epoch + * Set those offsets as the current commit progress + * Set the next epoch ID as the last + 1 + * Return the end offsets of the last epoch as start for the next one + * DONE + * ELSE + * Start a new query log + * DONE + */ + private def getStartOffsets(sparkSessionToRunBatches: SparkSession): OffsetSeq = { + // Note that this will need a slight modification for exactly once. If ending offsets were + // reported but not committed for any epochs, we must replay exactly to those offsets. + // For at least once, we can just ignore those reports and risk duplicates. + commitLog.getLatest() match { + case Some((latestEpochId, _)) => + val nextOffsets = offsetLog.get(latestEpochId).getOrElse { + throw new IllegalStateException( + s"Batch $latestEpochId was committed without end epoch offsets!") + } + committedOffsets = nextOffsets.toStreamProgress(sources) + + // Forcibly align commit and offset logs by slicing off any spurious offset logs from + // a previous run. We can't allow commits to an epoch that a previous run reached but + // this run has not. + offsetLog.purgeAfter(latestEpochId) + + currentBatchId = latestEpochId + 1 + logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") + nextOffsets + case None => + // We are starting this stream for the first time. Offsets are all None. + logInfo(s"Starting new streaming query.") + currentBatchId = 0 + OffsetSeq.fill(continuousSources.map(_ => null): _*) + } + } + + /** + * Do a continuous run. + * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. + */ + private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + // A list of attributes that will need to be updated. + val replacements = new ArrayBuffer[(Attribute, Attribute)] + // Translate from continuous relation to the underlying data source. + var nextSourceId = 0 + continuousSources = logicalPlan.collect { + case ContinuousExecutionRelation(dataSource, extraReaderOptions, output) => + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + nextSourceId += 1 + + dataSource.createContinuousReader( + java.util.Optional.empty[StructType](), + metadataPath, + new DataSourceV2Options(extraReaderOptions.asJava)) + } + uniqueSources = continuousSources.distinct + + val offsets = getStartOffsets(sparkSessionForQuery) + + var insertedSourceId = 0 + val withNewSources = logicalPlan transform { + case ContinuousExecutionRelation(_, _, output) => + val reader = continuousSources(insertedSourceId) + insertedSourceId += 1 + val newOutput = reader.readSchema().toAttributes + + assert(output.size == newOutput.size, + s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + + s"${Utils.truncatedString(newOutput, ",")}") + replacements ++= output.zip(newOutput) + + val loggedOffset = offsets.offsets(0) + val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) + reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) + DataSourceV2Relation(newOutput, reader) + } + + // Rewire the plan to use the new attributes that were returned by the source. 
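+    // (the readers produce attributes with fresh expression IDs, so references in the plan must be remapped)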
+ val replacementMap = AttributeMap(replacements) + val triggerLogicalPlan = withNewSources transformAllExpressions { + case a: Attribute if replacementMap.contains(a) => + replacementMap(a).withMetadata(a.metadata) + case (_: CurrentTimestamp | _: CurrentDate) => + throw new IllegalStateException( + "CurrentTimestamp and CurrentDate not yet supported for continuous processing") + } + + val writer = sink.createContinuousWriter( + s"$runId", + triggerLogicalPlan.schema, + outputMode, + new DataSourceV2Options(extraOptions.asJava)) + val withSink = WriteToDataSourceV2(writer.get(), triggerLogicalPlan) + + val reader = withSink.collect { + case DataSourceV2Relation(_, r: ContinuousReader) => r + }.head + + reportTimeTaken("queryPlanning") { + lastExecution = new IncrementalExecution( + sparkSessionForQuery, + withSink, + outputMode, + checkpointFile("state"), + runId, + currentBatchId, + offsetSeqMetadata) + lastExecution.executedPlan // Force the lazy generation of execution plan + } + + sparkSession.sparkContext.setLocalProperty( + ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) + sparkSession.sparkContext.setLocalProperty( + ContinuousExecution.RUN_ID_KEY, runId.toString) + + // Use the parent Spark session for the endpoint since it's where this query ID is registered. + val epochEndpoint = + EpochCoordinatorRef.create( + writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get) + val epochUpdateThread = new Thread(new Runnable { + override def run: Unit = { + try { + triggerExecutor.execute(() => { + startTrigger() + + if (reader.needsReconfiguration()) { + state.set(RECONFIGURING) + stopSources() + if (queryExecutionThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) + queryExecutionThread.interrupt() + // No need to join - this thread is about to end anyway. + } + false + } else if (isActive) { + currentBatchId = epochEndpoint.askSync[Long](IncrementAndGetEpoch) + logInfo(s"New epoch $currentBatchId is starting.") + true + } else { + false + } + }) + } catch { + case _: InterruptedException => + // Cleanly stop the query. + return + } + } + }, s"epoch update thread for $prettyIdString") + + try { + epochUpdateThread.setDaemon(true) + epochUpdateThread.start() + + reportTimeTaken("runContinuous") { + SQLExecution.withNewExecutionId( + sparkSessionForQuery, lastExecution)(lastExecution.toRdd) + } + } finally { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + + epochUpdateThread.interrupt() + epochUpdateThread.join() + } + } + + /** + * Report ending partition offsets for the given reader at the given epoch. + */ + def addOffset( + epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { + assert(continuousSources.length == 1, "only one continuous source supported currently") + + if (partitionOffsets.contains(null)) { + // If any offset is null, that means the corresponding partition hasn't seen any data yet, so + // there's nothing meaningful to add to the offset log. + } + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + synchronized { + if (queryExecutionThread.isAlive) { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + } else { + return + } + } + } + + /** + * Mark the specified epoch as committed. All readers must have reported end offsets for the epoch + * before this is called. 
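+   * Committing records the epoch in the commit log and advances the source's committed offsets.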
+ */ + def commit(epoch: Long): Unit = { + assert(continuousSources.length == 1, "only one continuous source supported currently") + assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + if (queryExecutionThread.isAlive) { + commitLog.add(epoch) + val offset = offsetLog.get(epoch).get.offsets(0).get + committedOffsets ++= Seq(continuousSources(0) -> offset) + } else { + return + } + } + + if (minLogEntriesToMaintain < currentBatchId) { + offsetLog.purge(currentBatchId - minLogEntriesToMaintain) + commitLog.purge(currentBatchId - minLogEntriesToMaintain) + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() + } + } + + /** + * Blocks the current thread until execution has committed at or after the specified epoch. + */ + private[sql] def awaitEpoch(epoch: Long): Unit = { + def notDone = { + val latestCommit = commitLog.getLatest() + latestCommit match { + case Some((latestEpoch, _)) => + latestEpoch < epoch + case None => true + } + } + + while (notDone) { + awaitProgressLock.lock() + try { + awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } + } finally { + awaitProgressLock.unlock() + } + } + } +} + +object ContinuousExecution { + val START_EPOCH_KEY = "__continuous_start_epoch" + val RUN_ID_KEY = "__run_id" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 4c3a1ee201ac1..89a8562b4b59e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -25,7 +25,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ @@ -47,13 +47,14 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { assert(offsets.length == numPartitions) val tuples = offsets.map { - case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => (i, (currVal, nextRead)) + case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => + (i, ValueRunTimeMsPair(currVal, nextRead)) } RateStreamOffset(Map(tuples: _*)) } override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, (Long, Long)]](json)) + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } override def readSchema(): StructType = RateSourceProvider.SCHEMA @@ -85,8 +86,8 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. 
RateStreamReadTask( - start._1, // starting row value - start._2, // starting time in ms + start.value, + start.runTimeMs, i, numPartitions, perPartitionRate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala new file mode 100644 index 0000000000000..90e1766c4d9f1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.Duration + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + */ +@InterfaceStability.Evolving +case class ContinuousTrigger(intervalMs: Long) extends Trigger { + require(intervalMs >= 0, "the interval of trigger should not be negative") +} + +private[sql] object ContinuousTrigger { + def apply(interval: String): ContinuousTrigger = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "interval cannot be null or blank.") + } + val cal = if (interval.startsWith("interval")) { + CalendarInterval.fromString(interval) + } else { + CalendarInterval.fromString("interval " + interval) + } + if (cal == null) { + throw new IllegalArgumentException(s"Invalid interval: $interval") + } + if (cal.months > 0) { + throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") + } + new ContinuousTrigger(cal.microseconds / 1000) + } + + def apply(interval: Duration): ContinuousTrigger = { + ContinuousTrigger(interval.toMillis) + } + + def create(interval: String): ContinuousTrigger = { + apply(interval) + } + + def create(interval: Long, unit: TimeUnit): ContinuousTrigger = { + ContinuousTrigger(unit.toMillis(interval)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala new file mode 100644 index 0000000000000..7f1e8abd79b99 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper +import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.{ContinuousWriter, WriterCommitMessage} +import org.apache.spark.util.RpcUtils + +private[continuous] sealed trait EpochCoordinatorMessage extends Serializable + +// Driver epoch trigger message +/** + * Atomically increment the current epoch and get the new value. + */ +private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage + +// Init messages +/** + * Set the reader and writer partition counts. Tasks may not be started until the coordinator + * has acknowledged these messages. + */ +private[sql] case class SetReaderPartitions(numPartitions: Int) extends EpochCoordinatorMessage +case class SetWriterPartitions(numPartitions: Int) extends EpochCoordinatorMessage + +// Partition task messages +/** + * Get the current epoch. + */ +private[sql] case object GetCurrentEpoch extends EpochCoordinatorMessage +/** + * Commit a partition at the specified epoch with the given message. + */ +private[sql] case class CommitPartitionEpoch( + partitionId: Int, + epoch: Long, + message: WriterCommitMessage) extends EpochCoordinatorMessage +/** + * Report that a partition is ending the specified epoch at the specified offset. + */ +private[sql] case class ReportPartitionOffset( + partitionId: Int, + epoch: Long, + offset: PartitionOffset) extends EpochCoordinatorMessage + + +/** Helper object used to create reference to [[EpochCoordinator]]. */ +private[sql] object EpochCoordinatorRef extends Logging { + private def endpointName(runId: String) = s"EpochCoordinator-$runId" + + /** + * Create a reference to a new [[EpochCoordinator]]. 
+ */ + def create( + writer: ContinuousWriter, + reader: ContinuousReader, + query: ContinuousExecution, + startEpoch: Long, + session: SparkSession, + env: SparkEnv): RpcEndpointRef = synchronized { + val coordinator = new EpochCoordinator( + writer, reader, query, startEpoch, session, env.rpcEnv) + val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator) + logInfo("Registered EpochCoordinator endpoint") + ref + } + + def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv) + logDebug("Retrieved existing EpochCoordinator endpoint") + rpcEndpointRef + } +} + +/** + * Handles three major epoch coordination tasks for continuous processing: + * + * * Maintains a local epoch counter (the "driver epoch"), incremented by IncrementAndGetEpoch + * and pollable from executors by GetCurrentEpoch. Note that this epoch is *not* immediately + * reflected anywhere in ContinuousExecution. + * * Collates ReportPartitionOffset messages, and forwards to ContinuousExecution when all + * readers have ended a given epoch. + * * Collates CommitPartitionEpoch messages, and forwards to ContinuousExecution when all readers + * have both committed and reported an end offset for a given epoch. + */ +private[continuous] class EpochCoordinator( + writer: ContinuousWriter, + reader: ContinuousReader, + query: ContinuousExecution, + startEpoch: Long, + session: SparkSession, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + + private var numReaderPartitions: Int = _ + private var numWriterPartitions: Int = _ + + private var currentDriverEpoch = startEpoch + + // (epoch, partition) -> message + private val partitionCommits = + mutable.Map[(Long, Int), WriterCommitMessage]() + // (epoch, partition) -> offset + private val partitionOffsets = + mutable.Map[(Long, Int), PartitionOffset]() + + private def resolveCommitsAtEpoch(epoch: Long) = { + val thisEpochCommits = + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val nextEpochOffsets = + partitionOffsets.collect { case ((e, _), o) if e == epoch => o } + + if (thisEpochCommits.size == numWriterPartitions && + nextEpochOffsets.size == numReaderPartitions) { + logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. + writer.commit(epoch, thisEpochCommits.toArray) + query.commit(epoch) + + // Cleanup state from before this epoch, now that we know all partitions are forever past it. 
+      for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) {
+        partitionCommits.remove(k)
+      }
+      for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
+        partitionOffsets.remove(k)
+      }
+    }
+  }
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case CommitPartitionEpoch(partitionId, epoch, message) =>
+      logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
+      if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
+        partitionCommits.put((epoch, partitionId), message)
+        resolveCommitsAtEpoch(epoch)
+      }
+
+    case ReportPartitionOffset(partitionId, epoch, offset) =>
+      partitionOffsets.put((epoch, partitionId), offset)
+      val thisEpochOffsets =
+        partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
+      if (thisEpochOffsets.size == numReaderPartitions) {
+        logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets")
+        query.addOffset(epoch, reader, thisEpochOffsets.toSeq)
+        resolveCommitsAtEpoch(epoch)
+      }
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case GetCurrentEpoch =>
+      val result = currentDriverEpoch
+      logDebug(s"Epoch $result")
+      context.reply(result)
+
+    case IncrementAndGetEpoch =>
+      currentDriverEpoch += 1
+      context.reply(currentDriverEpoch)
+
+    case SetReaderPartitions(numPartitions) =>
+      numReaderPartitions = numPartitions
+      context.reply(())
+
+    case SetWriterPartitions(numPartitions) =>
+      numWriterPartitions = numPartitions
+      context.reply(())
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
index 45dc7d75cbc8d..1c66aed8690a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
@@ -27,7 +27,7 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.RateStreamOffset
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
 import org.apache.spark.sql.sources.v2.DataSourceV2Options
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
@@ -71,7 +71,7 @@ class RateStreamV2Reader(options: DataSourceV2Options)
     val currentTime = clock.getTimeMillis()
     RateStreamOffset(
       this.start.partitionToValueAndRunTimeMs.map {
-        case startOffset @ (part, (currentVal, currentReadTime)) =>
+        case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
           // Calculate the number of rows we should advance in this partition (based on the
           // current time), and output a corresponding offset.
val readInterval = currentTime - currentReadTime @@ -79,9 +79,9 @@ class RateStreamV2Reader(options: DataSourceV2Options) if (numNewRows <= 0) { startOffset } else { - (part, - (currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) + (part, ValueRunTimeMsPair( + currentVal + (numNewRows * numPartitions), + currentReadTime + (numNewRows * msPerPartitionBetweenRows))) } } ) @@ -98,15 +98,15 @@ class RateStreamV2Reader(options: DataSourceV2Options) } override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, (Long, Long)]](json)) + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } override def createReadTasks(): java.util.List[ReadTask[Row]] = { val startMap = start.partitionToValueAndRunTimeMs val endMap = end.partitionToValueAndRunTimeMs endMap.keys.toSeq.map { part => - val (endVal, _) = endMap(part) - val (startVal, startTimeMs) = startMap(part) + val ValueRunTimeMsPair(endVal, _) = endMap(part) + val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) val packedRows = mutable.ListBuffer[(Long, Long)]() var outVal = startVal + numPartitions @@ -158,7 +158,8 @@ object RateStreamSourceV2 { // by the increment that will later be applied. The first row output in each // partition will have a value equal to the partition index. (i, - ((i - numPartitions).toLong, + ValueRunTimeMsPair( + (i - numPartitions).toLong, creationTimeMs)) }.toMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 94c5dd63089b1..972248d5e4df8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport} @@ -177,3 +179,14 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) override def abort(): Unit = {} } + + +/** + * Used to query the data that has been written into a [[MemorySink]]. 
+ */ +case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { + private val sizePerRow = output.map(_.dataType.defaultSize).sum + + override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 41aa02c2b5e35..f17935e86f459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -26,8 +26,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Interface used to load a streaming `Dataset` from external storage systems (e.g. file systems, @@ -153,13 +155,33 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo "read files of Hive data source directly.") } - val dataSource = - DataSource( - sparkSession, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap) - Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) + val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() + val options = new DataSourceV2Options(extraOptions.asJava) + // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. + // We can't be sure at this point whether we'll actually want to use V2, since we don't know the + // writer or whether the query is continuous. + val v1DataSource = DataSource( + sparkSession, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap) + ds match { + case s: ContinuousReadSupport => + val tempReader = s.createContinuousReader( + java.util.Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + // Generate the V1 node to catch errors thrown within generation. + StreamingRelation(v1DataSource) + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + s, source, extraOptions.toMap, + tempReader.readSchema().toAttributes, v1DataSource)(sparkSession)) + case _ => + // Code path for data source v1. 
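+          // Sources without ContinuousReadSupport fall back to the plain V1 StreamingRelation plan.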
+ Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0be69b98abc8a..db588ae282f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -26,7 +26,9 @@ import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -240,14 +242,23 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { if (extraOptions.get("queryName").isEmpty) { throw new AnalysisException("queryName must be specified for memory sink") } - val sink = new MemorySink(df.schema, outputMode) - val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) + val (sink, resultDf) = trigger match { + case _: ContinuousTrigger => + val s = new MemorySinkV2() + val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) + (s, r) + case _ => + val s = new MemorySink(df.schema, outputMode) + val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) + (s, r) + } val chkpointLoc = extraOptions.get("checkpointLocation") val recoverFromChkpoint = outputMode == OutputMode.Complete() val query = df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), chkpointLoc, df, + extraOptions.toMap, sink, outputMode, useTempCheckpointLocation = true, @@ -262,6 +273,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, + extraOptions.toMap, sink, outputMode, useTempCheckpointLocation = true, @@ -277,6 +289,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, + extraOptions.toMap, dataSource.createSink(outputMode), outputMode, useTempCheckpointLocation = source == "console", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 555d6e23f9385..e808ffaa96410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -29,8 +29,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.ContinuousWriteSupport import 
org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -188,7 +190,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], df: DataFrame, - sink: Sink, + extraOptions: Map[String, String], + sink: BaseStreamingSink, outputMode: OutputMode, useTempCheckpointLocation: Boolean, recoverFromCheckpointLocation: Boolean, @@ -237,16 +240,32 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - new StreamingQueryWrapper(new MicroBatchExecution( - sparkSession, - userSpecifiedName.orNull, - checkpointLocation, - analyzedPlan, - sink, - trigger, - triggerClock, - outputMode, - deleteCheckpointOnStop)) + sink match { + case v1Sink: Sink => + new StreamingQueryWrapper(new MicroBatchExecution( + sparkSession, + userSpecifiedName.orNull, + checkpointLocation, + analyzedPlan, + v1Sink, + trigger, + triggerClock, + outputMode, + deleteCheckpointOnStop)) + case v2Sink: ContinuousWriteSupport => + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + new StreamingQueryWrapper(new ContinuousExecution( + sparkSession, + userSpecifiedName.orNull, + checkpointLocation, + analyzedPlan, + v2Sink, + trigger, + triggerClock, + outputMode, + extraOptions, + deleteCheckpointOnStop)) + } } /** @@ -269,7 +288,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], df: DataFrame, - sink: Sink, + extraOptions: Map[String, String], + sink: BaseStreamingSink, outputMode: OutputMode, useTempCheckpointLocation: Boolean = false, recoverFromCheckpointLocation: Boolean = true, @@ -279,6 +299,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName, userSpecifiedCheckpointLocation, df, + extraOptions, sink, outputMode, useTempCheckpointLocation, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index fcaca3d75b74f..9fb8be423614b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -297,31 +297,47 @@ object QueryTest { }) } + private def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): String = { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") + + s""" + |== Results == + |${ + sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n") + } + """.stripMargin + } + + def includesRows( + expectedRows: Seq[Row], + sparkAnswer: Seq[Row]): Option[String] = { + if (!prepareAnswer(expectedRows, true).toSet.subsetOf(prepareAnswer(sparkAnswer, true).toSet)) { + return Some(genError(expectedRows, sparkAnswer, true)) + } + None + } + def sameRows( expectedAnswer: Seq[Row], sparkAnswer: Seq[Row], isSorted: Boolean = false): Option[String] = { if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { - 
val getRowType: Option[Row] => String = row => - row.map(row => - if (row.schema == null) { - "struct<>" - } else { - s"${row.schema.catalogString}" - }).getOrElse("struct<>") - - val errorMessage = - s""" - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - getRowType(expectedAnswer.headOption) +: - prepareAnswer(expectedAnswer, isSorted).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - getRowType(sparkAnswer.headOption) +: - prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) + return Some(genError(expectedAnswer, sparkAnswer, isSorted)) } None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 6514c5f0fdfeb..dc833b2ccaa22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -29,16 +29,6 @@ import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Optio import org.apache.spark.sql.streaming.StreamTest class RateSourceV2Suite extends StreamTest { - test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) - assert(reader.isInstanceOf[RateStreamV2Reader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - test("microbatch - numPartitions propagated") { val reader = new RateStreamV2Reader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) @@ -49,8 +39,8 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - set offset") { val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) - val startOffset = RateStreamOffset(Map((0, (0, 1000)))) - val endOffset = RateStreamOffset(Map((0, (0, 2000)))) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) assert(reader.getStartOffset() == startOffset) assert(reader.getEndOffset() == endOffset) @@ -63,15 +53,15 @@ class RateSourceV2Suite extends StreamTest { reader.setOffsetRange(Optional.empty(), Optional.empty()) reader.getStartOffset() match { case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0)._2 == reader.creationTimeMs) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) case _ => throw new IllegalStateException("unexpected offset type") } reader.getEndOffset() match { case r: RateStreamOffset => // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted // longer than 100ms. It should never be early. 
- assert(r.partitionToValueAndRunTimeMs(0)._1 >= 9) - assert(r.partitionToValueAndRunTimeMs(0)._2 >= reader.creationTimeMs + 100) + assert(r.partitionToValueAndRunTimeMs(0).value >= 9) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -80,8 +70,8 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - predetermined batch size") { val reader = new RateStreamV2Reader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, (0, 1000)))) - val endOffset = RateStreamOffset(Map((0, (20, 2000)))) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) val tasks = reader.createReadTasks() assert(tasks.size == 1) @@ -93,8 +83,8 @@ class RateSourceV2Suite extends StreamTest { new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, (currentVal, currentReadTime)) => - (part, (currentVal + 33, currentReadTime + 1000)) + case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) }.toMap) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -135,7 +125,7 @@ class RateSourceV2Suite extends StreamTest { val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) - ._2 + .runTimeMs val r = t.createDataReader().asInstanceOf[RateStreamDataReader] for (rowIndex <- 0 to 9) { r.next() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 755490308b5b9..c65e5d3dd75c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -77,10 +77,23 @@ class StreamSuite extends StreamTest { } test("StreamingRelation.computeStats") { + withTempDir { dir => + val df = spark.readStream.format("csv").schema(StructType(Seq())).load(dir.getCanonicalPath) + val streamingRelation = df.logicalPlan collect { + case s: StreamingRelation => s + } + assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert( + streamingRelation.head.computeStats.sizeInBytes == + spark.sessionState.conf.defaultSizeInBytes) + } + } + + test("StreamingRelationV2.computeStats") { val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { - case s: StreamingRelation => s + case s: StreamingRelationV2 => s } - assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert(streamingRelation.nonEmpty, "cannot find StreamingExecutionRelation") assert( streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 71a474ef63e84..fb9ebc81dd750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -33,11 +33,14 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext @@ -168,6 +171,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } + case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + } + case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${checkFunction.toString()}" @@ -237,6 +246,25 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be AssertOnQuery(query => { func(query); true }) } + object AwaitEpoch { + def apply(epoch: Long): AssertOnQuery = + Execute { + case s: ContinuousExecution => s.awaitEpoch(epoch) + case _ => throw new IllegalStateException("microbatch cannot await epoch") + } + } + + object IncrementEpoch { + def apply(): AssertOnQuery = + Execute { + case s: ContinuousExecution => + val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get) + .askSync[Long](IncrementAndGetEpoch) + s.awaitEpoch(newEpoch - 1) + case _ => throw new IllegalStateException("microbatch cannot increment epoch") + } + } + /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. 
@@ -246,7 +274,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be */ def testStream( _stream: Dataset[_], - outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { + outputMode: OutputMode = OutputMode.Append, + useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -259,7 +288,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() @volatile @@ -308,7 +337,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be "" } - def testState = + def testState = { + val sinkDebugString = sink match { + case s: MemorySink => s.toDebugString + case s: MemorySinkV2 => s.toDebugString + } s""" |== Progress == |$testActions @@ -321,12 +354,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""} | |== Sink == - |${sink.toDebugString} + |$sinkDebugString | | |== Plan == |${if (currentStream != null) currentStream.lastExecution else ""} """.stripMargin + } def verify(condition: => Boolean, message: String): Unit = { if (!condition) { @@ -383,7 +417,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - try if (lastOnly) sink.latestBatchData else sink.allData catch { + val (latestBatchData, allData) = sink match { + case s: MemorySink => (s.latestBatchData, s.allData) + case s: MemorySinkV2 => (s.latestBatchData, s.allData) + } + try if (lastOnly) latestBatchData else allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } @@ -423,6 +461,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be None, Some(metadataRoot), stream, + Map(), sink, outputMode, trigger = trigger, @@ -594,6 +633,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } + case CheckAnswerRowsContains(expectedAnswer, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } + case CheckAnswerRowsByFunc(checkFunction, lastOnly) => val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) sparkAnswer.foreach { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala new file mode 100644 index 0000000000000..eda0d8ad48313 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} + +import scala.reflect.ClassTag +import scala.util.control.ControlThrowable + +import com.google.common.util.concurrent.UncheckedExecutionException +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class ContinuousSuiteBase extends StreamTest { + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + protected def waitForRateSourceTriggers(query: StreamExecution, numTriggers: Int): Unit = { + query match { + case s: ContinuousExecution => + assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") + val reader = s.lastExecution.executedPlan.collectFirst { + case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r + }.get + + val deltaMs = numTriggers * 1000 + 300 + while (System.currentTimeMillis < reader.creationTime + deltaMs) { + Thread.sleep(reader.creationTime + deltaMs - System.currentTimeMillis) + } + } + } + + // A continuous trigger that will only fire the initial time for the duration of a test. + // This allows clean testing with manual epoch advancement. 
+ protected val longContinuousTrigger = Trigger.Continuous("1 hour") +} + +class ContinuousSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("basic rate source") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + StopStream, + StartStream(longContinuousTrigger), + AwaitEpoch(2), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StopStream) + } + + test("map") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .map(r => r.getLong(0) * 2) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + } + + test("flatMap") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + } + + test("filter") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .where('value > 5) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + } + + test("deduplicate") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .dropDuplicates() + + val except = intercept[AnalysisException] { + testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + } + + assert(except.message.contains( + "Continuous processing does not support Deduplicate operations.")) + } + + test("timestamp") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select(current_timestamp()) + + val except = intercept[AnalysisException] { + testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + } + + assert(except.message.contains( + "Continuous processing does not support current time operations.")) + } + + test("repeatedly restart") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + StopStream, + StartStream(longContinuousTrigger), + StopStream, + StartStream(longContinuousTrigger), + StopStream, + 
StartStream(longContinuousTrigger), + AwaitEpoch(2), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StopStream) + } + + test("query without test harness") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "2") + .option("rowsPerSecond", "2") + .load() + .select('value) + val query = df.writeStream + .format("memory") + .queryName("noharness") + .trigger(Trigger.Continuous(100)) + .start() + val continuousExecution = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.asInstanceOf[ContinuousExecution] + continuousExecution.awaitEpoch(0) + waitForRateSourceTriggers(continuousExecution, 2) + query.stop() + + val results = spark.read.table("noharness").collect() + assert(Set(0, 1, 2, 3).map(Row(_)).subsetOf(results.toSet)) + } +} + +class ContinuousStressSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("only one epoch") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 201)), + IncrementEpoch(), + Execute { query => + val data = query.sink.asInstanceOf[MemorySinkV2].allData + val vals = data.map(_.getLong(0)).toSet + assert(scala.Range(0, 25000).forall { i => + vals.contains(i) + }) + }) + } + + test("automatic epoch advancement") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 201)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) + } + + test("restarts") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(10), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(20), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(21), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(22), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(25), + StopStream, + StartStream(Trigger.Continuous(2012)), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(50), + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) + } +} From 86db9b2d7d3c3c8e6c27b05cb36318dba9138e78 Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Sat, 23 Dec 2017 08:13:34 -0600 Subject: [PATCH 489/712] [SPARK-22833][IMPROVEMENT] in SparkHive Scala Examples ## What changes were proposed in this pull request? SparkHive Scala Examples Improvement made: * Writing DataFrame / DataSet to Hive Managed , Hive External table using different storage format. * Implementation of Partition, Reparition, Coalesce with appropriate example. ## How was this patch tested? * Patch has been tested manually and by running ./dev/run-tests. Author: chetkhatri Closes #20018 from chetkhatri/scala-sparkhive-examples. 
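A condensed sketch of the write pattern the updated example walks through is shown below; `spark`, the `records` table, and the warehouse path mirror the example code, while the comments on file counts are approximate guidance rather than guarantees:

```scala
import org.apache.spark.sql.SaveMode

// Assumes a SparkSession `spark` with Hive support enabled and an existing `records` table.
import spark.implicits._

val hiveTableDF = spark.sql("SELECT * FROM records")

// Hive managed table stored as Parquet.
hiveTableDF.write.mode(SaveMode.Overwrite).saveAsTable("database_name.records")

// Parquet files under an external table location, partitioned by key.
val externalLocation = "/user/hive/warehouse/database_name.db/records"
hiveTableDF.write.mode(SaveMode.Overwrite)
  .partitionBy("key").parquet(externalLocation)

// repartition($"key") shuffles all rows of a key into one task, so each key's
// partition directory typically ends up with a single file.
hiveTableDF.repartition($"key").write.mode(SaveMode.Overwrite)
  .partitionBy("key").parquet(externalLocation)

// coalesce(10) caps the number of write tasks, limiting the total number of output files.
hiveTableDF.coalesce(10).write.mode(SaveMode.Overwrite)
  .partitionBy("key").parquet(externalLocation)
```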
--- .../examples/sql/hive/SparkHiveExample.scala | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index e5f75d53edc86..51df5dd8e3600 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -19,8 +19,7 @@ package org.apache.spark.examples.sql.hive // $example on:spark_hive$ import java.io.File -import org.apache.spark.sql.Row -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Row, SaveMode, SparkSession} // $example off:spark_hive$ object SparkHiveExample { @@ -102,8 +101,41 @@ object SparkHiveExample { // | 4| val_4| 4| val_4| // | 5| val_5| 5| val_5| // ... - // $example off:spark_hive$ + // Create Hive managed table with Parquet + sql("CREATE TABLE records(key int, value string) STORED AS PARQUET") + // Save DataFrame to Hive managed table as Parquet format + val hiveTableDF = sql("SELECT * FROM records") + hiveTableDF.write.mode(SaveMode.Overwrite).saveAsTable("database_name.records") + // Create External Hive table with Parquet + sql("CREATE EXTERNAL TABLE records(key int, value string) " + + "STORED AS PARQUET LOCATION '/user/hive/warehouse/'") + // to make Hive Parquet format compatible with Spark Parquet format + spark.sqlContext.setConf("spark.sql.parquet.writeLegacyFormat", "true") + + // Multiple Parquet files could be created accordingly to volume of data under directory given. + val hiveExternalTableLocation = "/user/hive/warehouse/database_name.db/records" + + // Save DataFrame to Hive External table as compatible Parquet format + hiveTableDF.write.mode(SaveMode.Overwrite).parquet(hiveExternalTableLocation) + + // Turn on flag for Dynamic Partitioning + spark.sqlContext.setConf("hive.exec.dynamic.partition", "true") + spark.sqlContext.setConf("hive.exec.dynamic.partition.mode", "nonstrict") + + // You can create partitions in Hive table, so downstream queries run much faster. + hiveTableDF.write.mode(SaveMode.Overwrite).partitionBy("key") + .parquet(hiveExternalTableLocation) + + // Reduce number of files for each partition by repartition + hiveTableDF.repartition($"key").write.mode(SaveMode.Overwrite) + .partitionBy("key").parquet(hiveExternalTableLocation) + + // Control the number of files in each partition by coalesce + hiveTableDF.coalesce(10).write.mode(SaveMode.Overwrite) + .partitionBy("key").parquet(hiveExternalTableLocation) + // $example off:spark_hive$ + spark.stop() } } From ea2642eb0ebcf02e8fba727a82c140c9f2284725 Mon Sep 17 00:00:00 2001 From: CNRui <13266776177@163.com> Date: Sat, 23 Dec 2017 08:18:08 -0600 Subject: [PATCH 490/712] [SPARK-20694][EXAMPLES] Update SQLDataSourceExample.scala ## What changes were proposed in this pull request? Create table using the right DataFrame. peopleDF->usersDF peopleDF: +----+-------+ | age| name| +----+-------+ usersDF: +------+--------------+----------------+ | name|favorite_color|favorite_numbers| +------+--------------+----------------+ ## How was this patch tested? Manually tested. Author: CNRui <13266776177@163.com> Closes #20052 from CNRui/patch-2. 
--- .../apache/spark/examples/sql/SQLDataSourceExample.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index f9477969a4bb5..7d83aacb11548 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -67,15 +67,15 @@ object SQLDataSourceExample { usersDF.write.partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet") // $example off:write_partitioning$ // $example on:write_partition_and_bucket$ - peopleDF + usersDF .write .partitionBy("favorite_color") .bucketBy(42, "name") - .saveAsTable("people_partitioned_bucketed") + .saveAsTable("users_partitioned_bucketed") // $example off:write_partition_and_bucket$ spark.sql("DROP TABLE IF EXISTS people_bucketed") - spark.sql("DROP TABLE IF EXISTS people_partitioned_bucketed") + spark.sql("DROP TABLE IF EXISTS users_partitioned_bucketed") } private def runBasicParquetExample(spark: SparkSession): Unit = { From f6084a88f0fe69111df8a016bc81c9884d3d3402 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 24 Dec 2017 01:16:12 +0900 Subject: [PATCH 491/712] [HOTFIX] Fix Scala style checks ## What changes were proposed in this pull request? This PR fixes a style that broke the build. ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20065 from HyukjinKwon/minor-style. --- .../org/apache/spark/examples/sql/hive/SparkHiveExample.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index 51df5dd8e3600..b193bd595127c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -135,7 +135,7 @@ object SparkHiveExample { hiveTableDF.coalesce(10).write.mode(SaveMode.Overwrite) .partitionBy("key").parquet(hiveExternalTableLocation) // $example off:spark_hive$ - + spark.stop() } } From aeb45df668a97a2d48cfd4079ed62601390979ba Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 24 Dec 2017 01:18:11 +0900 Subject: [PATCH 492/712] [SPARK-22844][R] Adds date_trunc in R API ## What changes were proposed in this pull request? This PR adds `date_trunc` in R API as below: ```r > df <- createDataFrame(list(list(a = as.POSIXlt("2012-12-13 12:34:00")))) > head(select(df, date_trunc("hour", df$a))) date_trunc(hour, a) 1 2012-12-13 12:00:00 ``` ## How was this patch tested? Unit tests added in `R/pkg/tests/fulltests/test_sparkSQL.R`. Author: hyukjinkwon Closes #20031 from HyukjinKwon/r-datetrunc. 
--- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 34 +++++++++++++++++++++++---- R/pkg/R/generics.R | 5 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 3 +++ 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 57838f52eac3f..dce64e1e607c8 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -230,6 +230,7 @@ exportMethods("%<=>%", "date_add", "date_format", "date_sub", + "date_trunc", "datediff", "dayofmonth", "dayofweek", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 237ef061e8071..3a96f94e269f4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -40,10 +40,17 @@ NULL #' #' @param x Column to compute on. In \code{window}, it must be a time Column of #' \code{TimestampType}. -#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse -#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string -#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for -#' truncate by year, or "month", "mon", "mm" for truncate by month. +#' @param format The format for the given dates or timestamps in Column \code{x}. See the +#' format used in the following methods: +#' \itemize{ +#' \item \code{to_date} and \code{to_timestamp}: it is the string to use to parse +#' Column \code{x} to DateType or TimestampType. +#' \item \code{trunc}: it is the string to use to specify the truncation method. +#' For example, "year", "yyyy", "yy" for truncate by year, or "month", "mon", +#' "mm" for truncate by month. +#' \item \code{date_trunc}: it is similar with \code{trunc}'s but additionally +#' supports "day", "dd", "second", "minute", "hour", "week" and "quarter". +#' } #' @param ... additional argument(s). #' @name column_datetime_functions #' @rdname column_datetime_functions @@ -3478,3 +3485,22 @@ setMethod("trunc", x@jc, as.character(format)) column(jc) }) + +#' @details +#' \code{date_trunc}: Returns timestamp truncated to the unit specified by the format. 
+#' +#' @rdname column_datetime_functions +#' @aliases date_trunc date_trunc,character,Column-method +#' @export +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, date_trunc("hour", df$time), date_trunc("minute", df$time), +#' date_trunc("week", df$time), date_trunc("quarter", df$time)))} +#' @note date_trunc since 2.3.0 +setMethod("date_trunc", + signature(format = "character", x = "Column"), + function(format, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_trunc", format, x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8fcf269087c7d..5ddaa669f9205 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1043,6 +1043,11 @@ setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) +#' @rdname column_datetime_functions +#' @export +#' @name NULL +setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) + #' @rdname column_datetime_functions #' @export #' @name NULL diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index d87f5d2705732..6cc0188dae95f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1418,6 +1418,8 @@ test_that("column functions", { c22 <- not(c) c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") + c24 <- date_trunc("hour", c) + date_trunc("minute", c) + date_trunc("week", c) + + date_trunc("quarter", c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1729,6 +1731,7 @@ test_that("date functions on a DataFrame", { expect_gt(collect(select(df2, unix_timestamp()))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + expect_equal(collect(select(df2, month(date_trunc("yyyy", df2$b))))[, 1], c(1, 1)) l3 <- list(list(a = 1000), list(a = -1000)) df3 <- createDataFrame(l3) From 1219d7a4343e837749c56492d382b3c814b97271 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 23 Dec 2017 10:27:14 -0800 Subject: [PATCH 493/712] [SPARK-22889][SPARKR] Set overwrite=T when install SparkR in tests ## What changes were proposed in this pull request? Since all CRAN checks go through the same machine, if there is an older partial download or partial install of Spark left behind the tests fail. This PR overwrites the install files when running tests. This shouldn't affect Jenkins as `SPARK_HOME` is set when running Jenkins tests. ## How was this patch tested? Test manually by running `R CMD check --as-cran` Author: Shivaram Venkataraman Closes #20060 from shivaram/sparkr-overwrite-cran. --- R/pkg/tests/run-all.R | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 63812ba70bb50..94d75188fb948 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -27,7 +27,10 @@ if (.Platform$OS.type == "windows") { # Setup global test environment # Install Spark first to set SPARK_HOME -install.spark() + +# NOTE(shivaram): We set overwrite to handle any old tar.gz files or directories left behind on +# CRAN machines. For Jenkins we should already have SPARK_HOME set. 
+# CRAN machines. For Jenkins we should already have SPARK_HOME set. +install.spark(overwrite = TRUE) sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") From 0bf1a74a773c79e66a67055298a36af477b9e21a Mon Sep 17 00:00:00 2001 From: sujithjay Date: Sun, 24 Dec 2017 11:14:30 -0800 Subject: [PATCH 494/712] [SPARK-22465][CORE] Add a safety-check to RDD defaultPartitioner ## What changes were proposed in this pull request? In choosing a Partitioner to use for a cogroup-like operation between a number of RDDs, the default behaviour was that if any of the RDDs already had a partitioner, we chose the one amongst them with the maximum number of partitions. This behaviour, in some cases, could hit the 2G limit (SPARK-6235). To illustrate one such scenario, consider two RDDs: rDD1: with smaller data and a smaller number of partitions, along with a Partitioner. rDD2: with much larger data and a larger number of partitions, without a Partitioner. The cogroup of these two RDDs could hit the 2G limit, as a larger amount of data is shuffled into a smaller number of partitions. This PR introduces a safety-check wherein the Partitioner is chosen only if either of the following conditions is met: 1. if the number of partitions of the RDD associated with the Partitioner is greater than or equal to the max number of upstream partitions; or 2. if the number of partitions of the RDD associated with the Partitioner is less than and within a single order of magnitude of the max number of upstream partitions. ## How was this patch tested? Unit tests in PartitioningSuite and PairRDDFunctionsSuite Author: sujithjay Closes #20002 from sujithjay/SPARK-22465. --- .../scala/org/apache/spark/Partitioner.scala | 32 +++++++++++++++++-- .../org/apache/spark/PartitioningSuite.scala | 27 ++++++++++++++++ .../spark/rdd/PairRDDFunctionsSuite.scala | 22 +++++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index debbd8d7c26c9..437bbaae1968b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -21,6 +21,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.math.log10 import scala.reflect.ClassTag import scala.util.hashing.byteswap32 @@ -42,7 +43,9 @@ object Partitioner { /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * - * If any of the RDDs already has a partitioner, choose that one. + * If any of the RDDs already has a partitioner, and the number of partitions of the + * partitioner is either greater than or is less than and within a single order of + * magnitude of the max number of upstream partitions, choose that one. * * Otherwise, we use a default HashPartitioner. 
For the number of partitions, if * spark.default.parallelism is set, then we'll use the value from SparkContext @@ -57,8 +60,15 @@ object Partitioner { def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val rdds = (Seq(rdd) ++ others) val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0)) - if (hasPartitioner.nonEmpty) { - hasPartitioner.maxBy(_.partitions.length).partitioner.get + + val hasMaxPartitioner: Option[RDD[_]] = if (hasPartitioner.nonEmpty) { + Some(hasPartitioner.maxBy(_.partitions.length)) + } else { + None + } + + if (isEligiblePartitioner(hasMaxPartitioner, rdds)) { + hasMaxPartitioner.get.partitioner.get } else { if (rdd.context.conf.contains("spark.default.parallelism")) { new HashPartitioner(rdd.context.defaultParallelism) @@ -67,6 +77,22 @@ object Partitioner { } } } + + /** + * Returns true if the number of partitions of the RDD is either greater + * than or is less than and within a single order of magnitude of the + * max number of upstream partitions; + * otherwise, returns false + */ + private def isEligiblePartitioner( + hasMaxPartitioner: Option[RDD[_]], + rdds: Seq[RDD[_]]): Boolean = { + if (hasMaxPartitioner.isEmpty) { + return false + } + val maxPartitions = rdds.map(_.partitions.length).max + log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1 + } } /** diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index dfe4c25670ce0..155ca17db726b 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -259,6 +259,33 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva val partitioner = new RangePartitioner(22, rdd) assert(partitioner.numPartitions === 3) } + + test("defaultPartitioner") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) + val rdd2 = sc + .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(10)) + val rdd3 = sc + .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + .partitionBy(new HashPartitioner(100)) + val rdd4 = sc + .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(9)) + val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) + + val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) + val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) + val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) + val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) + val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5) + + assert(partitioner1.numPartitions == rdd1.getNumPartitions) + assert(partitioner2.numPartitions == rdd3.getNumPartitions) + assert(partitioner3.numPartitions == rdd3.getNumPartitions) + assert(partitioner4.numPartitions == rdd3.getNumPartitions) + assert(partitioner5.numPartitions == rdd4.getNumPartitions) + + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 65d35264dc108..a39e0469272fe 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -310,6 +310,28 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(joined.size > 0) } + // See SPARK-22465 + test("cogroup between multiple RDD " + + "with 
an order of magnitude difference in number of partitions") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000) + val rdd2 = sc + .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd1.getNumPartitions) + } + + // See SPARK-22465 + test("cogroup between multiple RDD" + + " with number of partitions similar in order of magnitude") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc + .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) From fba03133d1369ce8cfebd6ae68b987a95f1b7ea1 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Sun, 24 Dec 2017 22:57:53 -0800 Subject: [PATCH 495/712] [SPARK-22707][ML] Optimize CrossValidator memory occupation by models in fitting ## What changes were proposed in this pull request? Via some tests I found that CrossValidator still has a memory issue: it will still occupy `O(n*sizeof(model))` memory for holding models when fitting, while, if well optimized, it should be `O(parallelism*sizeof(model))`. This is because modelFutures will hold the reference to the model object after the future is complete (we can use `future.value.get.get` to fetch it), and the `Future.sequence` and the `modelFutures` array hold references to each model future. So all model objects are kept referenced, and it will still occupy `O(n*sizeof(model))` memory. I fix this by merging the `modelFuture` and `foldMetricFuture` together, and using an `atomicInteger` to count completed fitting tasks and, when all are done, trigger `trainingDataset.unpersist`. I previously commented on this issue in the old PR [SPARK-19357] https://github.com/apache/spark/pull/16774#pullrequestreview-53674264; unfortunately, at that time I did not realize that the issue still existed, but now I have confirmed it and created this PR to fix it. ## Discussion I give 3 approaches which we can compare; after discussion I realized none of them is ideal, so we have to make a trade-off. **After discussion with jkbradley, approach 3 was chosen.** ### Approach 1 ~~The approach proposed by MrBago at~~ https://github.com/apache/spark/pull/19904#discussion_r156751569 ~~This approach resolves the model object reference issue and allows the model objects to be GCed in time. **BUT, in some cases, it still does not resolve the O(N) model memory occupation issue**. Let me use an extreme case to describe it:~~ ~~Suppose we set `parallelism = 1`, and there are 100 paramMaps, so we have 100 fitting & evaluation tasks. In this approach, because of `parallelism = 1`, the code has to wait for all 100 fitting tasks to complete **(at this time the memory occupation by models already reaches 100 * sizeof(model))**, and then it will unpersist the training dataset and then do the 100 evaluation tasks.~~ ### Approach 2 ~~This approach is my PR's old version of the code~~ https://github.com/apache/spark/pull/19904/commits/2cc7c28f385009570536690d686f2843485942b2 ~~This approach can make sure that, in any case, the peak memory occupation by models is `O(numParallelism * sizeof(model))`, but it has an issue that, in some extreme cases, the "unpersist training dataset" will be delayed until most of the evaluation tasks complete. 
Suppose the case `parallelism = 1` with 100 fitting & evaluation tasks: each fitting & evaluation task has to be executed one by one, so only after the first 99 fitting & evaluation tasks and the 100th fitting task complete will the "unpersist training dataset" be triggered.~~ ### Approach 3 After I compared approach 1 and approach 2, I realized that, in the case where parallelism is low but there are many fitting & evaluation tasks, we cannot achieve both of the following two goals: - Make the peak memory occupation by models (driver-side) O(parallelism * sizeof(model)) - Unpersist the training dataset before most of the evaluation tasks start. So I vote for a simpler approach: move the unpersisting of the training dataset to the end (does this really matter?). Because goal 1 is more important, we must make sure the peak memory occupation by models (driver-side) is O(parallelism * sizeof(model)); otherwise it brings a high risk of OOM. Like the following code: ``` val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] //...other minor codes val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) trainingDataset.unpersist() // <------- unpersist at the end validationDataset.unpersist() ``` ## How was this patch tested? N/A Author: WeichenXu Closes #19904 from WeichenXu123/fix_cross_validator_memory_issue. --- .../apache/spark/ml/tuning/CrossValidator.scala | 17 +++++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 1682ca91bf832..0130b3e255f0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -147,24 +147,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] - if (collectSubModelsParam) { subModels.get(splitIndex)(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -174,6 +162,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) // Wait for metrics to be calculated before unpersisting validation dataset val foldMetrics = 
foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + trainingDataset.unpersist() validationDataset.unpersist() foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits From 33ae2437ba4e634b510b0f96e914ad1ef4ccafd8 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 25 Dec 2017 01:14:09 -0800 Subject: [PATCH 496/712] [SPARK-22893][SQL] Unified the data type mismatch message ## What changes were proposed in this pull request? We should use `dataType.simpleString` to unified the data type mismatch message: Before: ``` spark-sql> select cast(1 as binary); Error in query: cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7; ``` After: ``` park-sql> select cast(1 as binary); Error in query: cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7; ``` ## How was this patch tested? Exist test. Author: Yuming Wang Closes #20064 from wangyum/SPARK-22893. --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../aggregate/ApproximatePercentile.scala | 4 +- .../expressions/conditionalExpressions.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 10 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../expressions/windowExpressions.scala | 8 +- .../native/binaryComparison.sql.out | 48 +-- .../typeCoercion/native/inConversion.sql.out | 280 +++++++++--------- .../native/windowFrameCoercion.sql.out | 8 +- .../sql-tests/results/window.sql.out | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../spark/sql/GeneratorFunctionSuite.scala | 4 +- 12 files changed, 189 insertions(+), 186 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5279d41278967..274d8813f16db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -181,7 +181,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType} to $dataType") + s"cannot cast ${child.dataType.simpleString} to ${dataType.simpleString}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 7facb9dad9a76..149ac265e6ed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -132,7 +132,7 @@ case class ApproximatePercentile( case TimestampType => value.asInstanceOf[Long].toDouble case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType]) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type $other") + throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") } buffer.add(doubleValue) } @@ -157,7 +157,7 @@ case class ApproximatePercentile( case DoubleType => doubleResult case _: DecimalType => doubleResult.map(Decimal(_)) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type $other") + throw new 
UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") } if (result.length == 0) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 142dfb02be0a8..b444c3a7be92a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -40,7 +40,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def checkInputDataTypes(): TypeCheckResult = { if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( - s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + "type of predicate expression in If should be boolean, " + + s"not ${predicate.dataType.simpleString}") } else if (!trueValue.dataType.sameType(falseValue.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 69af7a250a5ac..4f4d49166e88c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -155,8 +155,8 @@ case class Stack(children: Seq[Expression]) extends Generator { val j = (i - 1) % numFields if (children(i).dataType != elementSchema.fields(j).dataType) { return TypeCheckResult.TypeCheckFailure( - s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " + - s"Argument $i (${children(i).dataType})") + s"Argument ${j + 1} (${elementSchema.fields(j).dataType.simpleString}) != " + + s"Argument $i (${children(i).dataType.simpleString})") } } TypeCheckResult.TypeCheckSuccess @@ -249,7 +249,8 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( - s"input to function explode should be array or map type, not ${child.dataType}") + "input to function explode should be array or map type, " + + s"not ${child.dataType.simpleString}") } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -378,7 +379,8 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should be array of struct type, not ${child.dataType}") + s"input to function $prettyName should be array of struct type, " + + s"not ${child.dataType.simpleString}") } override def elementSchema: StructType = child.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f4ee3d10f3f43..b469f5cb7586a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -195,7 +195,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } case 
_ => TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType} != ${mismatchOpt.get.dataType}") + s"${value.dataType.simpleString} != ${mismatchOpt.get.dataType.simpleString}") } } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 220cc4f885d7d..dd13d9a3bba51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -70,9 +70,9 @@ case class WindowSpecDefinition( case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && !isValidFrameType(f.valueBoundary.head.dataType) => TypeCheckFailure( - s"The data type '${orderSpec.head.dataType}' used in the order specification does " + - s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " + - "range frame.") + s"The data type '${orderSpec.head.dataType.simpleString}' used in the order " + + "specification does not match the data type " + + s"'${f.valueBoundary.head.dataType.simpleString}' which is used in the range frame.") case _ => TypeCheckSuccess } } @@ -251,7 +251,7 @@ case class SpecifiedWindowFrame( TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") case e: Expression if !frameType.inputType.acceptsType(e.dataType) => TypeCheckFailure( - s"The data type of the $location bound '${e.dataType}' does not match " + + s"The data type of the $location bound '${e.dataType.simpleString}' does not match " + s"the expected data type '${frameType.inputType.simpleString}'.") case _ => TypeCheckSuccess } diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out index fe7bde040707c..2914d6015ea88 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/binaryComparison.sql.out @@ -16,7 +16,7 @@ SELECT cast(1 as binary) = '1' FROM t struct<> -- !query 1 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 2 @@ -25,7 +25,7 @@ SELECT cast(1 as binary) > '2' FROM t struct<> -- !query 2 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 3 @@ -34,7 +34,7 @@ SELECT cast(1 as binary) >= '2' FROM t struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 4 @@ -43,7 +43,7 @@ SELECT cast(1 as binary) < '2' FROM t struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot 
cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 5 @@ -52,7 +52,7 @@ SELECT cast(1 as binary) <= '2' FROM t struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 6 @@ -61,7 +61,7 @@ SELECT cast(1 as binary) <> '2' FROM t struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 7 @@ -70,7 +70,7 @@ SELECT cast(1 as binary) = cast(null as string) FROM t struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 8 @@ -79,7 +79,7 @@ SELECT cast(1 as binary) > cast(null as string) FROM t struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 9 @@ -88,7 +88,7 @@ SELECT cast(1 as binary) >= cast(null as string) FROM t struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 10 @@ -97,7 +97,7 @@ SELECT cast(1 as binary) < cast(null as string) FROM t struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 11 @@ -106,7 +106,7 @@ SELECT cast(1 as binary) <= cast(null as string) FROM t struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 12 @@ -115,7 +115,7 @@ SELECT cast(1 as binary) <> cast(null as string) FROM t struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 7 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 7 -- !query 13 @@ -124,7 +124,7 @@ SELECT '1' = cast(1 as binary) FROM t struct<> -- !query 13 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 14 @@ -133,7 +133,7 @@ SELECT '2' > cast(1 as binary) FROM t struct<> -- !query 14 output 
org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 15 @@ -142,7 +142,7 @@ SELECT '2' >= cast(1 as binary) FROM t struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 16 @@ -151,7 +151,7 @@ SELECT '2' < cast(1 as binary) FROM t struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 13 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 13 -- !query 17 @@ -160,7 +160,7 @@ SELECT '2' <= cast(1 as binary) FROM t struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 18 @@ -169,7 +169,7 @@ SELECT '2' <> cast(1 as binary) FROM t struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 14 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 14 -- !query 19 @@ -178,7 +178,7 @@ SELECT cast(null as string) = cast(1 as binary) FROM t struct<> -- !query 19 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 20 @@ -187,7 +187,7 @@ SELECT cast(null as string) > cast(1 as binary) FROM t struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 21 @@ -196,7 +196,7 @@ SELECT cast(null as string) >= cast(1 as binary) FROM t struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31 -- !query 22 @@ -205,7 +205,7 @@ SELECT cast(null as string) < cast(1 as binary) FROM t struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 30 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 30 -- !query 23 @@ -214,7 +214,7 @@ SELECT cast(null as string) <= cast(1 as binary) FROM t struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int 
to binary; line 1 pos 31 -- !query 24 @@ -223,7 +223,7 @@ SELECT cast(null as string) <> cast(1 as binary) FROM t struct<> -- !query 24 output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast IntegerType to BinaryType; line 1 pos 31 +cannot resolve 'CAST(1 AS BINARY)' due to data type mismatch: cannot cast int to binary; line 1 pos 31 -- !query 25 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out index bf8ddee89b798..875ccc1341ec4 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out @@ -80,7 +80,7 @@ SELECT cast(1 as tinyint) in (cast('1' as binary)) FROM t struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: tinyint != binary; line 1 pos 26 -- !query 10 @@ -89,7 +89,7 @@ SELECT cast(1 as tinyint) in (cast(1 as boolean)) FROM t struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: tinyint != boolean; line 1 pos 26 -- !query 11 @@ -98,7 +98,7 @@ SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: tinyint != timestamp; line 1 pos 26 -- !query 12 @@ -107,7 +107,7 @@ SELECT cast(1 as tinyint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: tinyint != date; line 1 pos 26 -- !query 13 @@ -180,7 +180,7 @@ SELECT cast(1 as smallint) in (cast('1' as binary)) FROM t struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: smallint != binary; line 1 pos 27 -- !query 22 @@ -189,7 +189,7 @@ SELECT cast(1 as smallint) in (cast(1 as boolean)) FROM t struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 
'(CAST(1 AS SMALLINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: smallint != boolean; line 1 pos 27 -- !query 23 @@ -198,7 +198,7 @@ SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: smallint != timestamp; line 1 pos 27 -- !query 24 @@ -207,7 +207,7 @@ SELECT cast(1 as smallint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 24 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: smallint != date; line 1 pos 27 -- !query 25 @@ -280,7 +280,7 @@ SELECT cast(1 as int) in (cast('1' as binary)) FROM t struct<> -- !query 33 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BinaryType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: int != binary; line 1 pos 22 -- !query 34 @@ -289,7 +289,7 @@ SELECT cast(1 as int) in (cast(1 as boolean)) FROM t struct<> -- !query 34 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: int != boolean; line 1 pos 22 -- !query 35 @@ -298,7 +298,7 @@ SELECT cast(1 as int) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 35 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != TimestampType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: int != timestamp; line 1 pos 22 -- !query 36 @@ -307,7 +307,7 @@ SELECT cast(1 as int) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 36 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: int != date; line 1 pos 22 -- !query 37 @@ -380,7 +380,7 @@ SELECT cast(1 as bigint) in (cast('1' as binary)) FROM t struct<> -- !query 45 
output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: bigint != binary; line 1 pos 25 -- !query 46 @@ -389,7 +389,7 @@ SELECT cast(1 as bigint) in (cast(1 as boolean)) FROM t struct<> -- !query 46 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: bigint != boolean; line 1 pos 25 -- !query 47 @@ -398,7 +398,7 @@ SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 47 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: bigint != timestamp; line 1 pos 25 -- !query 48 @@ -407,7 +407,7 @@ SELECT cast(1 as bigint) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 48 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: bigint != date; line 1 pos 25 -- !query 49 @@ -480,7 +480,7 @@ SELECT cast(1 as float) in (cast('1' as binary)) FROM t struct<> -- !query 57 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: float != binary; line 1 pos 24 -- !query 58 @@ -489,7 +489,7 @@ SELECT cast(1 as float) in (cast(1 as boolean)) FROM t struct<> -- !query 58 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType != BooleanType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: float != boolean; line 1 pos 24 -- !query 59 @@ -498,7 +498,7 @@ SELECT cast(1 as float) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 59 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: float != timestamp; line 1 pos 24 -- !query 60 @@ -507,7 +507,7 @@ SELECT cast(1 as float) in (cast('2017-12-11 09:30:00' as 
date)) FROM t struct<> -- !query 60 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: float != date; line 1 pos 24 -- !query 61 @@ -580,7 +580,7 @@ SELECT cast(1 as double) in (cast('1' as binary)) FROM t struct<> -- !query 69 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: double != binary; line 1 pos 25 -- !query 70 @@ -589,7 +589,7 @@ SELECT cast(1 as double) in (cast(1 as boolean)) FROM t struct<> -- !query 70 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: double != boolean; line 1 pos 25 -- !query 71 @@ -598,7 +598,7 @@ SELECT cast(1 as double) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 71 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: double != timestamp; line 1 pos 25 -- !query 72 @@ -607,7 +607,7 @@ SELECT cast(1 as double) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 72 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: double != date; line 1 pos 25 -- !query 73 @@ -680,7 +680,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('1' as binary)) FROM t struct<> -- !query 81 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != binary; line 1 pos 33 -- !query 82 @@ -689,7 +689,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as boolean)) FROM t struct<> -- !query 82 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != boolean; line 1 pos 33 -- !query 83 @@ 
-698,7 +698,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00.0' as timestamp)) struct<> -- !query 83 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != TimestampType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != timestamp; line 1 pos 33 -- !query 84 @@ -707,7 +707,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 84 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != date; line 1 pos 33 -- !query 85 @@ -780,7 +780,7 @@ SELECT cast(1 as string) in (cast('1' as binary)) FROM t struct<> -- !query 93 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: string != binary; line 1 pos 25 -- !query 94 @@ -789,7 +789,7 @@ SELECT cast(1 as string) in (cast(1 as boolean)) FROM t struct<> -- !query 94 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: string != boolean; line 1 pos 25 -- !query 95 @@ -814,7 +814,7 @@ SELECT cast('1' as binary) in (cast(1 as tinyint)) FROM t struct<> -- !query 97 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: binary != tinyint; line 1 pos 27 -- !query 98 @@ -823,7 +823,7 @@ SELECT cast('1' as binary) in (cast(1 as smallint)) FROM t struct<> -- !query 98 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: binary != smallint; line 1 pos 27 -- !query 99 @@ -832,7 +832,7 @@ SELECT cast('1' as binary) in (cast(1 as int)) FROM t struct<> -- !query 99 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: binary != int; 
line 1 pos 27 -- !query 100 @@ -841,7 +841,7 @@ SELECT cast('1' as binary) in (cast(1 as bigint)) FROM t struct<> -- !query 100 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != LongType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: binary != bigint; line 1 pos 27 -- !query 101 @@ -850,7 +850,7 @@ SELECT cast('1' as binary) in (cast(1 as float)) FROM t struct<> -- !query 101 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: binary != float; line 1 pos 27 -- !query 102 @@ -859,7 +859,7 @@ SELECT cast('1' as binary) in (cast(1 as double)) FROM t struct<> -- !query 102 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: binary != double; line 1 pos 27 -- !query 103 @@ -868,7 +868,7 @@ SELECT cast('1' as binary) in (cast(1 as decimal(10, 0))) FROM t struct<> -- !query 103 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: binary != decimal(10,0); line 1 pos 27 -- !query 104 @@ -877,7 +877,7 @@ SELECT cast('1' as binary) in (cast(1 as string)) FROM t struct<> -- !query 104 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: binary != string; line 1 pos 27 -- !query 105 @@ -894,7 +894,7 @@ SELECT cast('1' as binary) in (cast(1 as boolean)) FROM t struct<> -- !query 106 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: binary != boolean; line 1 pos 27 -- !query 107 @@ -903,7 +903,7 @@ SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM struct<> -- !query 107 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: binary != timestamp; line 1 pos 27 -- !query 
108 @@ -912,7 +912,7 @@ SELECT cast('1' as binary) in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 108 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: binary != date; line 1 pos 27 -- !query 109 @@ -921,7 +921,7 @@ SELECT true in (cast(1 as tinyint)) FROM t struct<> -- !query 109 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: boolean != tinyint; line 1 pos 12 -- !query 110 @@ -930,7 +930,7 @@ SELECT true in (cast(1 as smallint)) FROM t struct<> -- !query 110 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: boolean != smallint; line 1 pos 12 -- !query 111 @@ -939,7 +939,7 @@ SELECT true in (cast(1 as int)) FROM t struct<> -- !query 111 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: boolean != int; line 1 pos 12 -- !query 112 @@ -948,7 +948,7 @@ SELECT true in (cast(1 as bigint)) FROM t struct<> -- !query 112 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: boolean != bigint; line 1 pos 12 -- !query 113 @@ -957,7 +957,7 @@ SELECT true in (cast(1 as float)) FROM t struct<> -- !query 113 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: boolean != float; line 1 pos 12 -- !query 114 @@ -966,7 +966,7 @@ SELECT true in (cast(1 as double)) FROM t struct<> -- !query 114 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: boolean != double; line 1 pos 12 -- !query 115 @@ -975,7 +975,7 @@ SELECT true in (cast(1 as decimal(10, 0))) FROM t struct<> -- !query 115 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BooleanType != DecimalType(10,0); line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS 
DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: boolean != decimal(10,0); line 1 pos 12 -- !query 116 @@ -984,7 +984,7 @@ SELECT true in (cast(1 as string)) FROM t struct<> -- !query 116 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 12 +cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: boolean != string; line 1 pos 12 -- !query 117 @@ -993,7 +993,7 @@ SELECT true in (cast('1' as binary)) FROM t struct<> -- !query 117 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 12 +cannot resolve '(true IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: boolean != binary; line 1 pos 12 -- !query 118 @@ -1010,7 +1010,7 @@ SELECT true in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t struct<> -- !query 119 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 12 +cannot resolve '(true IN (CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: boolean != timestamp; line 1 pos 12 -- !query 120 @@ -1019,7 +1019,7 @@ SELECT true in (cast('2017-12-11 09:30:00' as date)) FROM t struct<> -- !query 120 output org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 12 +cannot resolve '(true IN (CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: boolean != date; line 1 pos 12 -- !query 121 @@ -1028,7 +1028,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as tinyint)) FROM t struct<> -- !query 121 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != tinyint; line 1 pos 50 -- !query 122 @@ -1037,7 +1037,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as smallint)) FROM struct<> -- !query 122 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != smallint; line 1 pos 50 -- !query 123 @@ -1046,7 +1046,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as int)) FROM t struct<> -- !query 123 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 
+cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: timestamp != int; line 1 pos 50 -- !query 124 @@ -1055,7 +1055,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as bigint)) FROM t struct<> -- !query 124 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != bigint; line 1 pos 50 -- !query 125 @@ -1064,7 +1064,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as float)) FROM t struct<> -- !query 125 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: timestamp != float; line 1 pos 50 -- !query 126 @@ -1073,7 +1073,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as double)) FROM t struct<> -- !query 126 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: timestamp != double; line 1 pos 50 -- !query 127 @@ -1082,7 +1082,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as decimal(10, 0))) struct<> -- !query 127 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: timestamp != decimal(10,0); line 1 pos 50 -- !query 128 @@ -1099,7 +1099,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2' as binary)) FROM struct<> -- !query 129 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: timestamp != binary; line 1 pos 50 -- !query 130 @@ -1108,7 +1108,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as boolean)) FROM t struct<> -- !query 130 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BooleanType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type 
but were: timestamp != boolean; line 1 pos 50 -- !query 131 @@ -1133,7 +1133,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as tinyint)) FROM t struct<> -- !query 133 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: date != tinyint; line 1 pos 43 -- !query 134 @@ -1142,7 +1142,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as smallint)) FROM t struct<> -- !query 134 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: date != smallint; line 1 pos 43 -- !query 135 @@ -1151,7 +1151,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as int)) FROM t struct<> -- !query 135 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS INT)))' due to data type mismatch: Arguments must be same type but were: date != int; line 1 pos 43 -- !query 136 @@ -1160,7 +1160,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as bigint)) FROM t struct<> -- !query 136 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: date != bigint; line 1 pos 43 -- !query 137 @@ -1169,7 +1169,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as float)) FROM t struct<> -- !query 137 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: date != float; line 1 pos 43 -- !query 138 @@ -1178,7 +1178,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as double)) FROM t struct<> -- !query 138 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: date != double; line 1 pos 43 -- !query 139 @@ -1187,7 +1187,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as decimal(10, 0))) FROM t struct<> -- !query 139 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must 
be same type but were: DateType != DecimalType(10,0); line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: date != decimal(10,0); line 1 pos 43 -- !query 140 @@ -1204,7 +1204,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2' as binary)) FROM t struct<> -- !query 141 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: date != binary; line 1 pos 43 -- !query 142 @@ -1213,7 +1213,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as boolean)) FROM t struct<> -- !query 142 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: date != boolean; line 1 pos 43 -- !query 143 @@ -1302,7 +1302,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('1' as binary)) FROM t struct<> -- !query 153 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ByteType != BinaryType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: tinyint != binary; line 1 pos 26 -- !query 154 @@ -1311,7 +1311,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as boolean)) FROM t struct<> -- !query 154 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ByteType != BooleanType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: tinyint != boolean; line 1 pos 26 -- !query 155 @@ -1320,7 +1320,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00.0' a struct<> -- !query 155 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ByteType != TimestampType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: tinyint != timestamp; line 1 pos 26 -- !query 156 @@ -1329,7 +1329,7 @@ SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast('2017-12-11 09:30:00' as struct<> -- !query 156 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ByteType != DateType; line 1 pos 26 +cannot resolve '(CAST(1 AS TINYINT) IN (CAST(1 AS TINYINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type 
mismatch: Arguments must be same type but were: tinyint != date; line 1 pos 26 -- !query 157 @@ -1402,7 +1402,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('1' as binary)) FROM t struct<> -- !query 165 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: ShortType != BinaryType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: smallint != binary; line 1 pos 27 -- !query 166 @@ -1411,7 +1411,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as boolean)) FROM t struct<> -- !query 166 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: ShortType != BooleanType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: smallint != boolean; line 1 pos 27 -- !query 167 @@ -1420,7 +1420,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00.0' struct<> -- !query 167 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: ShortType != TimestampType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: smallint != timestamp; line 1 pos 27 -- !query 168 @@ -1429,7 +1429,7 @@ SELECT cast(1 as smallint) in (cast(1 as smallint), cast('2017-12-11 09:30:00' a struct<> -- !query 168 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: ShortType != DateType; line 1 pos 27 +cannot resolve '(CAST(1 AS SMALLINT) IN (CAST(1 AS SMALLINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: smallint != date; line 1 pos 27 -- !query 169 @@ -1502,7 +1502,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('1' as binary)) FROM t struct<> -- !query 177 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BinaryType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: int != binary; line 1 pos 22 -- !query 178 @@ -1511,7 +1511,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast(1 as boolean)) FROM t struct<> -- !query 178 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: IntegerType != BooleanType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: int != boolean; line 1 pos 22 -- !query 179 @@ -1520,7 +1520,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 
09:30:00.0' as timest struct<> -- !query 179 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: IntegerType != TimestampType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: int != timestamp; line 1 pos 22 -- !query 180 @@ -1529,7 +1529,7 @@ SELECT cast(1 as int) in (cast(1 as int), cast('2017-12-11 09:30:00' as date)) F struct<> -- !query 180 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: IntegerType != DateType; line 1 pos 22 +cannot resolve '(CAST(1 AS INT) IN (CAST(1 AS INT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: int != date; line 1 pos 22 -- !query 181 @@ -1602,7 +1602,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('1' as binary)) FROM t struct<> -- !query 189 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: LongType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: bigint != binary; line 1 pos 25 -- !query 190 @@ -1611,7 +1611,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as boolean)) FROM t struct<> -- !query 190 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: LongType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: bigint != boolean; line 1 pos 25 -- !query 191 @@ -1620,7 +1620,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00.0' as struct<> -- !query 191 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: LongType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: bigint != timestamp; line 1 pos 25 -- !query 192 @@ -1629,7 +1629,7 @@ SELECT cast(1 as bigint) in (cast(1 as bigint), cast('2017-12-11 09:30:00' as da struct<> -- !query 192 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: LongType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS BIGINT) IN (CAST(1 AS BIGINT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: bigint != date; line 1 pos 25 -- !query 193 @@ -1702,7 +1702,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('1' as binary)) FROM t struct<> -- !query 201 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 
AS FLOAT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: FloatType != BinaryType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: float != binary; line 1 pos 24 -- !query 202 @@ -1711,7 +1711,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast(1 as boolean)) FROM t struct<> -- !query 202 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: FloatType != BooleanType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: float != boolean; line 1 pos 24 -- !query 203 @@ -1720,7 +1720,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00.0' as ti struct<> -- !query 203 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: FloatType != TimestampType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: float != timestamp; line 1 pos 24 -- !query 204 @@ -1729,7 +1729,7 @@ SELECT cast(1 as float) in (cast(1 as float), cast('2017-12-11 09:30:00' as date struct<> -- !query 204 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: FloatType != DateType; line 1 pos 24 +cannot resolve '(CAST(1 AS FLOAT) IN (CAST(1 AS FLOAT), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: float != date; line 1 pos 24 -- !query 205 @@ -1802,7 +1802,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('1' as binary)) FROM t struct<> -- !query 213 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: double != binary; line 1 pos 25 -- !query 214 @@ -1811,7 +1811,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast(1 as boolean)) FROM t struct<> -- !query 214 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DoubleType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: double != boolean; line 1 pos 25 -- !query 215 @@ -1820,7 +1820,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00.0' as struct<> -- !query 215 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DoubleType != TimestampType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN 
(CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: double != timestamp; line 1 pos 25 -- !query 216 @@ -1829,7 +1829,7 @@ SELECT cast(1 as double) in (cast(1 as double), cast('2017-12-11 09:30:00' as da struct<> -- !query 216 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DoubleType != DateType; line 1 pos 25 +cannot resolve '(CAST(1 AS DOUBLE) IN (CAST(1 AS DOUBLE), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: double != date; line 1 pos 25 -- !query 217 @@ -1902,7 +1902,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('1' as bina struct<> -- !query 225 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BinaryType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != binary; line 1 pos 33 -- !query 226 @@ -1911,7 +1911,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as boolea struct<> -- !query 226 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != BooleanType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != boolean; line 1 pos 33 -- !query 227 @@ -1920,7 +1920,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 struct<> -- !query 227 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != TimestampType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != timestamp; line 1 pos 33 -- !query 228 @@ -1929,7 +1929,7 @@ SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast('2017-12-11 struct<> -- !query 228 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: DecimalType(10,0) != DateType; line 1 pos 33 +cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST(1 AS DECIMAL(10,0)), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: decimal(10,0) != date; line 1 pos 33 -- !query 229 @@ -2002,7 +2002,7 @@ SELECT cast(1 as string) in (cast(1 as string), cast('1' as binary)) FROM t struct<> -- !query 237 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: StringType != 
BinaryType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: string != binary; line 1 pos 25 -- !query 238 @@ -2011,7 +2011,7 @@ SELECT cast(1 as string) in (cast(1 as string), cast(1 as boolean)) FROM t struct<> -- !query 238 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: StringType != BooleanType; line 1 pos 25 +cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS STRING), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: string != boolean; line 1 pos 25 -- !query 239 @@ -2036,7 +2036,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as tinyint)) FROM t struct<> -- !query 241 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ByteType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: binary != tinyint; line 1 pos 27 -- !query 242 @@ -2045,7 +2045,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as smallint)) FROM t struct<> -- !query 242 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != ShortType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: binary != smallint; line 1 pos 27 -- !query 243 @@ -2054,7 +2054,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as int)) FROM t struct<> -- !query 243 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != IntegerType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: binary != int; line 1 pos 27 -- !query 244 @@ -2063,7 +2063,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as bigint)) FROM t struct<> -- !query 244 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != LongType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: binary != bigint; line 1 pos 27 -- !query 245 @@ -2072,7 +2072,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as float)) FROM t struct<> -- !query 245 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BinaryType != FloatType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: binary != float; line 1 pos 27 -- !query 246 @@ -2081,7 +2081,7 @@ SELECT 
cast('1' as binary) in (cast('1' as binary), cast(1 as double)) FROM t struct<> -- !query 246 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DoubleType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: binary != double; line 1 pos 27 -- !query 247 @@ -2090,7 +2090,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as decimal(10, 0))) F struct<> -- !query 247 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: BinaryType != DecimalType(10,0); line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: binary != decimal(10,0); line 1 pos 27 -- !query 248 @@ -2099,7 +2099,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as string)) FROM t struct<> -- !query 248 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BinaryType != StringType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: binary != string; line 1 pos 27 -- !query 249 @@ -2116,7 +2116,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast(1 as boolean)) FROM t struct<> -- !query 250 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: BinaryType != BooleanType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: binary != boolean; line 1 pos 27 -- !query 251 @@ -2125,7 +2125,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00.0' struct<> -- !query 251 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BinaryType != TimestampType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: binary != timestamp; line 1 pos 27 -- !query 252 @@ -2134,7 +2134,7 @@ SELECT cast('1' as binary) in (cast('1' as binary), cast('2017-12-11 09:30:00' a struct<> -- !query 252 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BinaryType != DateType; line 1 pos 27 +cannot resolve '(CAST('1' AS BINARY) IN (CAST('1' AS BINARY), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: binary != date; line 1 pos 27 -- !query 253 @@ -2143,7 +2143,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as tinyint)) FROM t struct<> -- !query 253 output 
org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ByteType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: boolean != tinyint; line 1 pos 28 -- !query 254 @@ -2152,7 +2152,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as smallint)) FROM struct<> -- !query 254 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != ShortType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: boolean != smallint; line 1 pos 28 -- !query 255 @@ -2161,7 +2161,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as int)) FROM t struct<> -- !query 255 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != IntegerType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: boolean != int; line 1 pos 28 -- !query 256 @@ -2170,7 +2170,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as bigint)) FROM t struct<> -- !query 256 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != LongType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: boolean != bigint; line 1 pos 28 -- !query 257 @@ -2179,7 +2179,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as float)) FROM t struct<> -- !query 257 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: BooleanType != FloatType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: boolean != float; line 1 pos 28 -- !query 258 @@ -2188,7 +2188,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as double)) FROM t struct<> -- !query 258 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DoubleType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: boolean != double; line 1 pos 28 -- !query 259 @@ -2197,7 +2197,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as decimal(10, 0))) struct<> -- !query 259 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: 
BooleanType != DecimalType(10,0); line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: boolean != decimal(10,0); line 1 pos 28 -- !query 260 @@ -2206,7 +2206,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast(1 as string)) FROM t struct<> -- !query 260 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: BooleanType != StringType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: boolean != string; line 1 pos 28 -- !query 261 @@ -2215,7 +2215,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('1' as binary)) FROM struct<> -- !query 261 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: BooleanType != BinaryType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: boolean != binary; line 1 pos 28 -- !query 262 @@ -2232,7 +2232,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00. struct<> -- !query 263 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: BooleanType != TimestampType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00.0' AS TIMESTAMP)))' due to data type mismatch: Arguments must be same type but were: boolean != timestamp; line 1 pos 28 -- !query 264 @@ -2241,7 +2241,7 @@ SELECT cast('1' as boolean) in (cast('1' as boolean), cast('2017-12-11 09:30:00' struct<> -- !query 264 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: BooleanType != DateType; line 1 pos 28 +cannot resolve '(CAST('1' AS BOOLEAN) IN (CAST('1' AS BOOLEAN), CAST('2017-12-11 09:30:00' AS DATE)))' due to data type mismatch: Arguments must be same type but were: boolean != date; line 1 pos 28 -- !query 265 @@ -2250,7 +2250,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 265 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ByteType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != tinyint; line 1 pos 50 -- !query 266 @@ -2259,7 +2259,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. 
struct<> -- !query 266 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != ShortType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != smallint; line 1 pos 50 -- !query 267 @@ -2268,7 +2268,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 267 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != IntegerType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: timestamp != int; line 1 pos 50 -- !query 268 @@ -2277,7 +2277,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 268 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != LongType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: timestamp != bigint; line 1 pos 50 -- !query 269 @@ -2286,7 +2286,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 269 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: TimestampType != FloatType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: timestamp != float; line 1 pos 50 -- !query 270 @@ -2295,7 +2295,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 270 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: TimestampType != DoubleType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: timestamp != double; line 1 pos 50 -- !query 271 @@ -2304,7 +2304,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. 
struct<> -- !query 271 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: TimestampType != DecimalType(10,0); line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: timestamp != decimal(10,0); line 1 pos 50 -- !query 272 @@ -2321,7 +2321,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 273 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BinaryType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: timestamp != binary; line 1 pos 50 -- !query 274 @@ -2330,7 +2330,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00. struct<> -- !query 274 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: TimestampType != BooleanType; line 1 pos 50 +cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 09:30:00.0' AS TIMESTAMP), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: timestamp != boolean; line 1 pos 50 -- !query 275 @@ -2355,7 +2355,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 277 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ByteType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS TINYINT)))' due to data type mismatch: Arguments must be same type but were: date != tinyint; line 1 pos 43 -- !query 278 @@ -2364,7 +2364,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 278 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: DateType != ShortType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS SMALLINT)))' due to data type mismatch: Arguments must be same type but were: date != smallint; line 1 pos 43 -- !query 279 @@ -2373,7 +2373,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 279 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: DateType != IntegerType; line 1 pos 43 +cannot resolve 
'(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS INT)))' due to data type mismatch: Arguments must be same type but were: date != int; line 1 pos 43 -- !query 280 @@ -2382,7 +2382,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 280 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: DateType != LongType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BIGINT)))' due to data type mismatch: Arguments must be same type but were: date != bigint; line 1 pos 43 -- !query 281 @@ -2391,7 +2391,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 281 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: DateType != FloatType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS FLOAT)))' due to data type mismatch: Arguments must be same type but were: date != float; line 1 pos 43 -- !query 282 @@ -2400,7 +2400,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 282 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: DateType != DoubleType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DOUBLE)))' due to data type mismatch: Arguments must be same type but were: date != double; line 1 pos 43 -- !query 283 @@ -2409,7 +2409,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 283 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: DateType != DecimalType(10,0); line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: Arguments must be same type but were: date != decimal(10,0); line 1 pos 43 -- !query 284 @@ -2426,7 +2426,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 285 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: DateType != BinaryType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: date != binary; line 1 pos 43 -- !query 286 @@ -2435,7 +2435,7 @@ SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as dat struct<> -- !query 286 output org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN 
(CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: DateType != BooleanType; line 1 pos 43 +cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30:00' AS DATE), CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: date != boolean; line 1 pos 43 -- !query 287 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out index 5dd257ba6a0bb..01d83938031fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out @@ -168,7 +168,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWE struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'StringType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 -- !query 21 @@ -177,7 +177,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BET struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY 1 ORDER BY CAST('1' AS BINARY) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'BinaryType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 +cannot resolve '(PARTITION BY 1 ORDER BY CAST('1' AS BINARY) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'binary' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 21 -- !query 22 @@ -186,7 +186,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETW struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'BooleanType' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or calendarinterval)'.; line 1 pos 21 -- !query 23 @@ -195,7 +195,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('2017-12-11 09:30:00.0' as ti struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY 1 ORDER BY CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 21 +cannot resolve '(PARTITION BY 1 ORDER BY CAST('2017-12-11 09:30:00.0' AS TIMESTAMP) DESC NULLS LAST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to 
data type mismatch: The data type 'timestamp' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 21 -- !query 24 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index a52e198eb9a8f..133458ae9303b 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -61,7 +61,7 @@ ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType' does not match the expected data type 'int'.; line 1 pos 41 +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'bigint' does not match the expected data type 'int'.; line 1 pos 41 -- !query 4 @@ -221,7 +221,7 @@ RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_timestamp() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'TimestampType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 33 +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_timestamp() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'timestamp' used in the order specification does not match the data type 'int' which is used in the range frame.; line 1 pos 33 -- !query 15 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bd1e7adefc7a9..d535896723bd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -660,7 +660,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.as[KryoData] }.message - assert(e.contains("cannot cast IntegerType to BinaryType")) + assert(e.contains("cannot cast int to binary")) } test("Java encoder") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 6b98209fd49b8..109fcf90a3ec9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -65,7 +65,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val m3 = intercept[AnalysisException] { df.selectExpr("stack(2, 1, '2.2')") }.getMessage - assert(m3.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (StringType)")) + assert(m3.contains("data type mismatch: Argument 1 (int) != Argument 2 (string)")) // stack on column data val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c") @@ -80,7 +80,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val m5 = intercept[AnalysisException] { df3.selectExpr("stack(2, a, b)") }.getMessage - assert(m5.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (DoubleType)")) + 
assert(m5.contains("data type mismatch: Argument 1 (int) != Argument 2 (double)"))
  }

From 12d20dd75b1620da362dbb5345bed58e47ddacb9 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 25 Dec 2017 20:29:10 +0900
Subject: [PATCH 497/712] [SPARK-22874][PYSPARK][SQL][FOLLOW-UP] Modify error messages to show
 actual versions.

## What changes were proposed in this pull request?

This is a follow-up PR of #20054, modifying the error messages for both pandas and pyarrow to show the actual installed versions.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN

Closes #20074 from ueshin/issues/SPARK-22874_fup1.
---
 python/pyspark/sql/utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index fb7d42a35d8f4..08c34c6dccc5e 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -118,7 +118,8 @@ def require_minimum_pandas_version():
     from distutils.version import LooseVersion
     import pandas
     if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
-        raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process")
+        raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
+                          "however, your version was %s." % pandas.__version__)


 def require_minimum_pyarrow_version():
@@ -127,4 +128,5 @@ def require_minimum_pyarrow_version():
     from distutils.version import LooseVersion
     import pyarrow
     if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
-        raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process")
+        raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
+                          "however, your version was %s." % pyarrow.__version__)

From be03d3ad793dfe8f3ec33074d4ae95f5adb86ee4 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Mon, 25 Dec 2017 16:17:39 -0800
Subject: [PATCH 498/712] [SPARK-22893][SQL][HOTFIX] Fix an error message of VersionsSuite

## What changes were proposed in this pull request?

https://github.com/apache/spark/pull/20064 breaks the Jenkins tests because it missed updating one error message for Hive 0.12 and Hive 0.13. This PR fixes that.

- https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/3924/
- https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/3977/
- https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7/4226/
- https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4260/

## How was this patch tested?

Passes the Jenkins tests without failure.

Author: Dongjoon Hyun

Closes #20079 from dongjoon-hyun/SPARK-2893.
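For background on why the expected messages in this patch and in the golden files above change from names like `DecimalType(2,1)` or `BinaryType` to `decimal(2,1)` or `binary`: the new wording uses the SQL-facing names that Spark's type objects expose through `DataType.simpleString` (and `catalogString`), rather than the internal Scala object names. A minimal spark-shell sketch of that mapping, assuming a 2.3-era build (illustrative only, not part of the patch):

```scala
import org.apache.spark.sql.types._

// The old messages printed the type objects themselves (e.g. "LongType", "DecimalType(2,1)");
// the updated messages use the SQL names returned by simpleString.
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
    DecimalType(2, 1), StringType, BinaryType, BooleanType, TimestampType, DateType)
  .foreach { t =>
    // e.g. "LongType             -> bigint", "DecimalType(2,1)     -> decimal(2,1)"
    println(f"${t.toString}%-20s -> ${t.simpleString}")
  }
```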
--- .../scala/org/apache/spark/sql/hive/client/VersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9d15dabc8d3f5..94473a08dd317 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -773,7 +773,7 @@ class VersionsSuite extends SparkFunSuite with Logging { """.stripMargin ) - val errorMsg = "data type mismatch: cannot cast DecimalType(2,1) to BinaryType" + val errorMsg = "data type mismatch: cannot cast decimal(2,1) to binary" if (isPartitioned) { val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" From 0e6833006d28df426eb132bb8fc82917b8e2aedd Mon Sep 17 00:00:00 2001 From: Yash Sharma Date: Tue, 26 Dec 2017 09:50:39 +0200 Subject: [PATCH 499/712] [SPARK-20168][DSTREAM] Add changes to use kinesis fetches from specific timestamp ## What changes were proposed in this pull request? Kinesis client can resume from a specified timestamp while creating a stream. We should have option to pass a timestamp in config to allow kinesis to resume from the given timestamp. The patch introduces a new `KinesisInitialPositionInStream` that takes the `InitialPositionInStream` with the `timestamp` information that can be used to resume kinesis fetches from the provided timestamp. ## How was this patch tested? Unit Tests cc : budde brkyvz Author: Yash Sharma Closes #18029 from yssharma/ysharma/kcl_resume. --- .../kinesis/KinesisInitialPositions.java | 91 +++++++++++++++++++ .../streaming/KinesisWordCountASL.scala | 5 +- .../kinesis/KinesisInputDStream.scala | 31 +++++-- .../streaming/kinesis/KinesisReceiver.scala | 45 +++++---- .../streaming/kinesis/KinesisUtils.scala | 15 ++- .../JavaKinesisInputDStreamBuilderSuite.java | 47 ++++++++-- .../KinesisInputDStreamBuilderSuite.scala | 68 +++++++++++++- .../kinesis/KinesisStreamSuite.scala | 11 ++- 8 files changed, 264 insertions(+), 49 deletions(-) create mode 100644 external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java new file mode 100644 index 0000000000000..206e1e4699030 --- /dev/null +++ b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.kinesis; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +import java.io.Serializable; +import java.util.Date; + +/** + * A java wrapper for exposing [[InitialPositionInStream]] + * to the corresponding Kinesis readers. + */ +interface KinesisInitialPosition { + InitialPositionInStream getPosition(); +} + +public class KinesisInitialPositions { + public static class Latest implements KinesisInitialPosition, Serializable { + public Latest() {} + + @Override + public InitialPositionInStream getPosition() { + return InitialPositionInStream.LATEST; + } + } + + public static class TrimHorizon implements KinesisInitialPosition, Serializable { + public TrimHorizon() {} + + @Override + public InitialPositionInStream getPosition() { + return InitialPositionInStream.TRIM_HORIZON; + } + } + + public static class AtTimestamp implements KinesisInitialPosition, Serializable { + private Date timestamp; + + public AtTimestamp(Date timestamp) { + this.timestamp = timestamp; + } + + @Override + public InitialPositionInStream getPosition() { + return InitialPositionInStream.AT_TIMESTAMP; + } + + public Date getTimestamp() { + return timestamp; + } + } + + + /** + * Returns instance of [[KinesisInitialPosition]] based on the passed [[InitialPositionInStream]]. + * This method is used in KinesisUtils for translating the InitialPositionInStream + * to InitialPosition. This function would be removed when we deprecate the KinesisUtils. + * + * @return [[InitialPosition]] + */ + public static KinesisInitialPosition fromKinesisInitialPosition( + InitialPositionInStream initialPositionInStream) throws UnsupportedOperationException { + if (initialPositionInStream == InitialPositionInStream.LATEST) { + return new Latest(); + } else if (initialPositionInStream == InitialPositionInStream.TRIM_HORIZON) { + return new TrimHorizon(); + } else { + // InitialPositionInStream.AT_TIMESTAMP is not supported. + // Use InitialPosition.atTimestamp(timestamp) instead. + throw new UnsupportedOperationException( + "Only InitialPositionInStream.LATEST and InitialPositionInStream.TRIM_HORIZON " + + "supported in initialPositionInStream(). 
Please use the initialPosition() from " + + "builder API in KinesisInputDStream for using InitialPositionInStream.AT_TIMESTAMP"); + } + } +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index cde2c4b04c0c7..fcb790e3ea1f9 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -24,7 +24,6 @@ import scala.util.Random import com.amazonaws.auth.DefaultAWSCredentialsProviderChain import com.amazonaws.services.kinesis.AmazonKinesisClient -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.PutRecordRequest import org.apache.log4j.{Level, Logger} @@ -33,9 +32,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.streaming.dstream.DStream.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.kinesis.KinesisInputDStream - /** * Consumes messages from a Amazon Kinesis streams and does wordcount. * @@ -139,7 +138,7 @@ object KinesisWordCountASL extends Logging { .streamName(streamName) .endpointUrl(endpointUrl) .regionName(regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointAppName(appName) .checkpointInterval(kinesisCheckpointInterval) .storageLevel(StorageLevel.MEMORY_AND_DISK_2) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index f61e398e7bdd1..1ffec01df9f00 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -28,6 +28,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} import org.apache.spark.streaming.api.java.JavaStreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo @@ -36,7 +37,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( val streamName: String, val endpointUrl: String, val regionName: String, - val initialPositionInStream: InitialPositionInStream, + val initialPosition: KinesisInitialPosition, val checkpointAppName: String, val checkpointInterval: Duration, val _storageLevel: StorageLevel, @@ -77,7 +78,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( } override def getReceiver(): Receiver[T] = { - new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, + new KinesisReceiver(streamName, endpointUrl, regionName, initialPosition, checkpointAppName, checkpointInterval, _storageLevel, messageHandler, kinesisCreds, dynamoDBCreds, cloudWatchCreds) } @@ -100,7 +101,7 @@ object KinesisInputDStream { // Params with defaults private var endpointUrl: Option[String] = None private 
var regionName: Option[String] = None - private var initialPositionInStream: Option[InitialPositionInStream] = None + private var initialPosition: Option[KinesisInitialPosition] = None private var checkpointInterval: Option[Duration] = None private var storageLevel: Option[StorageLevel] = None private var kinesisCredsProvider: Option[SparkAWSCredentials] = None @@ -180,16 +181,32 @@ object KinesisInputDStream { this } + /** + * Sets the initial position data is read from in the Kinesis stream. Defaults to + * [[KinesisInitialPositions.Latest]] if no custom value is specified. + * + * @param initialPosition [[KinesisInitialPosition]] value specifying where Spark Streaming + * will start reading records in the Kinesis stream from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def initialPosition(initialPosition: KinesisInitialPosition): Builder = { + this.initialPosition = Option(initialPosition) + this + } + /** * Sets the initial position data is read from in the Kinesis stream. Defaults to * [[InitialPositionInStream.LATEST]] if no custom value is specified. + * This function would be removed when we deprecate the KinesisUtils. * * @param initialPosition InitialPositionInStream value specifying where Spark Streaming * will start reading records in the Kinesis stream from * @return Reference to this [[KinesisInputDStream.Builder]] */ + @deprecated("use initialPosition(initialPosition: KinesisInitialPosition)", "2.3.0") def initialPositionInStream(initialPosition: InitialPositionInStream): Builder = { - initialPositionInStream = Option(initialPosition) + this.initialPosition = Option( + KinesisInitialPositions.fromKinesisInitialPosition(initialPosition)) this } @@ -266,7 +283,7 @@ object KinesisInputDStream { getRequiredParam(streamName, "streamName"), endpointUrl.getOrElse(DEFAULT_KINESIS_ENDPOINT_URL), regionName.getOrElse(DEFAULT_KINESIS_REGION_NAME), - initialPositionInStream.getOrElse(DEFAULT_INITIAL_POSITION_IN_STREAM), + initialPosition.getOrElse(DEFAULT_INITIAL_POSITION), getRequiredParam(checkpointAppName, "checkpointAppName"), checkpointInterval.getOrElse(ssc.graph.batchDuration), storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), @@ -293,7 +310,6 @@ object KinesisInputDStream { * Creates a [[KinesisInputDStream.Builder]] for constructing [[KinesisInputDStream]] instances. 
* * @since 2.2.0 - * * @return [[KinesisInputDStream.Builder]] instance */ def builder: Builder = new Builder @@ -309,7 +325,6 @@ object KinesisInputDStream { private[kinesis] val DEFAULT_KINESIS_ENDPOINT_URL: String = "https://kinesis.us-east-1.amazonaws.com" private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1" - private[kinesis] val DEFAULT_INITIAL_POSITION_IN_STREAM: InitialPositionInStream = - InitialPositionInStream.LATEST + private[kinesis] val DEFAULT_INITIAL_POSITION: KinesisInitialPosition = new Latest() private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1026d0fcb59bd..fa0de6298a5f1 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -24,12 +24,13 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record import org.apache.spark.internal.Logging import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.AtTimestamp import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils @@ -56,12 +57,13 @@ import org.apache.spark.util.Utils * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Region name used by the Kinesis Client Library for * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). + * @param initialPosition Instance of [[KinesisInitialPosition]] + * In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * ([[KinesisInitialPositions.TrimHorizon]]) or + * the tip of the stream ([[KinesisInitialPositions.Latest]]). * @param checkpointAppName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, * the KCL will throw errors. 
This usually requires deleting the backing @@ -83,7 +85,7 @@ private[kinesis] class KinesisReceiver[T]( val streamName: String, endpointUrl: String, regionName: String, - initialPositionInStream: InitialPositionInStream, + initialPosition: KinesisInitialPosition, checkpointAppName: String, checkpointInterval: Duration, storageLevel: StorageLevel, @@ -148,18 +150,29 @@ private[kinesis] class KinesisReceiver[T]( kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) val kinesisProvider = kinesisCreds.provider - val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( - checkpointAppName, - streamName, - kinesisProvider, - dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), - cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), - workerId) + + val kinesisClientLibConfiguration = { + val baseClientLibConfiguration = new KinesisClientLibConfiguration( + checkpointAppName, + streamName, + kinesisProvider, + dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), + cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), + workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream) + .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) + // Update the Kinesis client lib config with timestamp + // if InitialPositionInStream.AT_TIMESTAMP is passed + initialPosition match { + case ts: AtTimestamp => + baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) + case _ => baseClientLibConfiguration + } + } + /* * RecordProcessorFactory creates impls of IRecordProcessor. * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 1298463bfba1e..2500460bd330b 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -73,7 +73,8 @@ object KinesisUtils { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, cleanedHandler, DefaultCredentials, None, None) } } @@ -129,7 +130,8 @@ object KinesisUtils { awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, cleanedHandler, kinesisCredsProvider, None, None) } } @@ -198,7 +200,8 @@ object KinesisUtils { awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey)) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, cleanedHandler, 
kinesisCredsProvider, None, None) } } @@ -243,7 +246,8 @@ object KinesisUtils { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None) } } @@ -293,7 +297,8 @@ object KinesisUtils { awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), + kinesisAppName, checkpointInterval, storageLevel, KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None) } } diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java index be6d549b1e429..03becd73d1a06 100644 --- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java @@ -17,14 +17,13 @@ package org.apache.spark.streaming.kinesis; -import org.junit.Test; - import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; - +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.TrimHorizon; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.Seconds; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.Seconds; +import org.junit.Test; public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext { /** @@ -35,7 +34,41 @@ public void testJavaKinesisDStreamBuilder() { String streamName = "a-very-nice-stream-name"; String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; String region = "us-west-2"; - InitialPositionInStream initialPosition = InitialPositionInStream.TRIM_HORIZON; + KinesisInitialPosition initialPosition = new TrimHorizon(); + String appName = "a-very-nice-kinesis-app"; + Duration checkpointInterval = Seconds.apply(30); + StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); + + KinesisInputDStream kinesisDStream = KinesisInputDStream.builder() + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPosition(initialPosition) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build(); + assert(kinesisDStream.streamName() == streamName); + assert(kinesisDStream.endpointUrl() == endpointUrl); + assert(kinesisDStream.regionName() == region); + assert(kinesisDStream.initialPosition().getPosition() == initialPosition.getPosition()); + assert(kinesisDStream.checkpointAppName() == appName); + assert(kinesisDStream.checkpointInterval() == checkpointInterval); + assert(kinesisDStream._storageLevel() == storageLevel); + ssc.stop(); + } + + /** + * Test to ensure that the old API for InitialPositionInStream + * is supported 
in KinesisDStream.Builder. + * This test would be removed when we deprecate the KinesisUtils. + */ + @Test + public void testJavaKinesisDStreamBuilderOldApi() { + String streamName = "a-very-nice-stream-name"; + String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; + String region = "us-west-2"; String appName = "a-very-nice-kinesis-app"; Duration checkpointInterval = Seconds.apply(30); StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); @@ -45,7 +78,7 @@ public void testJavaKinesisDStreamBuilder() { .streamName(streamName) .endpointUrl(endpointUrl) .regionName(region) - .initialPositionInStream(initialPosition) + .initialPositionInStream(InitialPositionInStream.LATEST) .checkpointAppName(appName) .checkpointInterval(checkpointInterval) .storageLevel(storageLevel) @@ -53,7 +86,7 @@ public void testJavaKinesisDStreamBuilder() { assert(kinesisDStream.streamName() == streamName); assert(kinesisDStream.endpointUrl() == endpointUrl); assert(kinesisDStream.regionName() == region); - assert(kinesisDStream.initialPositionInStream() == initialPosition); + assert(kinesisDStream.initialPosition().getPosition() == InitialPositionInStream.LATEST); assert(kinesisDStream.checkpointAppName() == appName); assert(kinesisDStream.checkpointInterval() == checkpointInterval); assert(kinesisDStream._storageLevel() == storageLevel); diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala index afa1a7f8ca663..e0e26847aa0ec 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.streaming.kinesis +import java.util.Calendar + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import org.scalatest.BeforeAndAfterEach import org.scalatest.mockito.MockitoSugar import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{Duration, Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.{AtTimestamp, TrimHorizon} class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterEach with MockitoSugar { @@ -69,7 +72,7 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE val dstream = builder.build() assert(dstream.endpointUrl == DEFAULT_KINESIS_ENDPOINT_URL) assert(dstream.regionName == DEFAULT_KINESIS_REGION_NAME) - assert(dstream.initialPositionInStream == DEFAULT_INITIAL_POSITION_IN_STREAM) + assert(dstream.initialPosition == DEFAULT_INITIAL_POSITION) assert(dstream.checkpointInterval == batchDuration) assert(dstream._storageLevel == DEFAULT_STORAGE_LEVEL) assert(dstream.kinesisCreds == DefaultCredentials) @@ -80,7 +83,7 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE test("should propagate custom non-auth values to KinesisInputDStream") { val customEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" val customRegion = "us-west-2" - val customInitialPosition = InitialPositionInStream.TRIM_HORIZON + val customInitialPosition = new TrimHorizon() val customAppName = "a-very-nice-kinesis-app" val customCheckpointInterval = Seconds(30) val 
customStorageLevel = StorageLevel.MEMORY_ONLY @@ -91,7 +94,7 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE val dstream = builder .endpointUrl(customEndpointUrl) .regionName(customRegion) - .initialPositionInStream(customInitialPosition) + .initialPosition(customInitialPosition) .checkpointAppName(customAppName) .checkpointInterval(customCheckpointInterval) .storageLevel(customStorageLevel) @@ -101,12 +104,67 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE .build() assert(dstream.endpointUrl == customEndpointUrl) assert(dstream.regionName == customRegion) - assert(dstream.initialPositionInStream == customInitialPosition) + assert(dstream.initialPosition == customInitialPosition) assert(dstream.checkpointAppName == customAppName) assert(dstream.checkpointInterval == customCheckpointInterval) assert(dstream._storageLevel == customStorageLevel) assert(dstream.kinesisCreds == customKinesisCreds) assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds)) assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds)) + + // Testing with AtTimestamp + val cal = Calendar.getInstance() + cal.add(Calendar.DATE, -1) + val timestamp = cal.getTime() + val initialPositionAtTimestamp = new AtTimestamp(timestamp) + + val dstreamAtTimestamp = builder + .endpointUrl(customEndpointUrl) + .regionName(customRegion) + .initialPosition(initialPositionAtTimestamp) + .checkpointAppName(customAppName) + .checkpointInterval(customCheckpointInterval) + .storageLevel(customStorageLevel) + .kinesisCredentials(customKinesisCreds) + .dynamoDBCredentials(customDynamoDBCreds) + .cloudWatchCredentials(customCloudWatchCreds) + .build() + assert(dstreamAtTimestamp.endpointUrl == customEndpointUrl) + assert(dstreamAtTimestamp.regionName == customRegion) + assert(dstreamAtTimestamp.initialPosition.getPosition + == initialPositionAtTimestamp.getPosition) + assert( + dstreamAtTimestamp.initialPosition.asInstanceOf[AtTimestamp].getTimestamp.equals(timestamp)) + assert(dstreamAtTimestamp.checkpointAppName == customAppName) + assert(dstreamAtTimestamp.checkpointInterval == customCheckpointInterval) + assert(dstreamAtTimestamp._storageLevel == customStorageLevel) + assert(dstreamAtTimestamp.kinesisCreds == customKinesisCreds) + assert(dstreamAtTimestamp.dynamoDBCreds == Option(customDynamoDBCreds)) + assert(dstreamAtTimestamp.cloudWatchCreds == Option(customCloudWatchCreds)) + } + + test("old Api should throw UnsupportedOperationExceptionexception with AT_TIMESTAMP") { + val streamName: String = "a-very-nice-stream-name" + val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com" + val region: String = "us-west-2" + val appName: String = "a-very-nice-kinesis-app" + val checkpointInterval: Duration = Seconds.apply(30) + val storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + + // This should not build. + // InitialPositionInStream.AT_TIMESTAMP is not supported in old Api. + // The builder Api in KinesisInputDStream should be used. 
+ intercept[UnsupportedOperationException] { + val kinesisDStream: KinesisInputDStream[Array[Byte]] = KinesisInputDStream.builder + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPositionInStream(InitialPositionInStream.AT_TIMESTAMP) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build + } } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 7e5bda923f63e..a7a68eba910bf 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.kinesis.KinesisReadConfigurations._ import org.apache.spark.streaming.kinesis.KinesisTestUtils._ import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult @@ -178,7 +179,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(testUtils.streamName) .endpointUrl(testUtils.endpointUrl) .regionName(testUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() @@ -209,7 +210,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(testUtils.streamName) .endpointUrl(testUtils.endpointUrl) .regionName(testUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .buildWithMessageHandler(addFive(_)) @@ -245,7 +246,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName("dummyStream") .endpointUrl(dummyEndpointUrl) .regionName(dummyRegionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() @@ -293,7 +294,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(localTestUtils.streamName) .endpointUrl(localTestUtils.endpointUrl) .regionName(localTestUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() @@ -369,7 +370,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun .streamName(testUtils.streamName) .endpointUrl(testUtils.endpointUrl) .regionName(testUtils.regionName) - .initialPositionInStream(InitialPositionInStream.LATEST) + .initialPosition(new Latest()) .checkpointInterval(Seconds(10)) .storageLevel(StorageLevel.MEMORY_ONLY) .build() From eb386be1ed383323da6e757f63f3b8a7ced38cc4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 26 Dec 2017 21:37:25 +0900 Subject: [PATCH 500/712] [SPARK-21552][SQL] Add DecimalType support to ArrowWriter. ## What changes were proposed in this pull request? 
Decimal type is not yet supported in `ArrowWriter`. This is adding the decimal type support. ## How was this patch tested? Added a test to `ArrowConvertersSuite`. Author: Takuya UESHIN Closes #18754 from ueshin/issues/SPARK-21552. --- python/pyspark/sql/tests.py | 61 +++++++++++------ python/pyspark/sql/types.py | 2 +- .../sql/execution/arrow/ArrowWriter.scala | 21 ++++++ .../arrow/ArrowConvertersSuite.scala | 67 ++++++++++++++++++- .../execution/arrow/ArrowWriterSuite.scala | 2 + 5 files changed, 131 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b977160af566d..b811a0f8a31d3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3142,6 +3142,7 @@ class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): from datetime import datetime + from decimal import Decimal ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -3158,11 +3159,15 @@ def setUpClass(cls): StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), StructField("5_double_t", DoubleType(), True), - StructField("6_date_t", DateType(), True), - StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + StructField("6_decimal_t", DecimalType(38, 18), True), + StructField("7_date_t", DateType(), True), + StructField("8_timestamp_t", TimestampType(), True)]) + cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), + datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), + datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), + datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3190,10 +3195,11 @@ def create_pandas_data_frame(self): return pd.DataFrame(data=data_dict) def test_unsupported_datatype(self): - schema = StructType([StructField("decimal", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.toPandas() def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + @@ -3293,7 +3299,7 @@ def test_createDataFrame_respect_session_timezone(self): self.assertNotEqual(result_ny, result_la) # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v + result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) @@ -3317,11 +3323,11 @@ def test_createDataFrame_with_incorrect_schema(self): def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) + 
self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) + self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) def test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -3344,7 +3350,7 @@ def test_createDataFrame_does_not_modify_input(self): # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1) + pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted pdf.ix[1, '2_int_t'] = None pdf_copy = pdf.copy(deep=True) @@ -3514,6 +3520,7 @@ def test_vectorized_udf_basic(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, StringType()) @@ -3521,10 +3528,12 @@ def test_vectorized_udf_basic(self): long_f = pandas_udf(f, LongType()) float_f = pandas_udf(f, FloatType()) double_f = pandas_udf(f, DoubleType()) + decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_boolean(self): @@ -3590,6 +3599,16 @@ def test_vectorized_udf_null_double(self): res = df.select(double_f(col('double'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_decimal(self): + from decimal import Decimal + from pyspark.sql.functions import pandas_udf, col + data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] + schema = StructType().add("decimal", DecimalType(38, 18)) + df = self.spark.createDataFrame(data, schema) + decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18)) + res = df.select(decimal_f(col('decimal'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_string(self): from pyspark.sql.functions import pandas_udf, col data = [("foo",), (None,), ("bar",), ("bar",)] @@ -3607,6 +3626,7 @@ def test_vectorized_udf_datatype_string(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, 'string') @@ -3614,10 +3634,12 @@ def test_vectorized_udf_datatype_string(self): long_f = pandas_udf(f, 'long') float_f = pandas_udf(f, 'float') double_f = pandas_udf(f, 'double') + decimal_f = pandas_udf(f, 'decimal(38, 18)') bool_f = pandas_udf(f, 'boolean') res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_complex(self): @@ -3713,12 +3735,12 @@ def test_vectorized_udf_varargs(self): def test_vectorized_udf_unsupported_types(self): from pyspark.sql.functions 
import pandas_udf, col - schema = StructType([StructField("dt", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, DecimalType()) + f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('dt'))).collect() + df.select(f(col('map'))).collect() def test_vectorized_udf_null_date(self): from pyspark.sql.functions import pandas_udf, col @@ -4012,7 +4034,8 @@ def test_wrong_args(self): def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( - [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) + [StructField("id", LongType(), True), + StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 063264a89379c..02b245713b6ab 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1617,7 +1617,7 @@ def to_arrow_type(dt): elif type(dt) == DoubleType: arrow_type = pa.float64() elif type(dt) == DecimalType: - arrow_type = pa.decimal(dt.precision, dt.scale) + arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() elif type(dt) == DateType: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 0258056d9de49..22b63513548fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -53,6 +53,8 @@ object ArrowWriter { case (LongType, vector: BigIntVector) => new LongWriter(vector) case (FloatType, vector: Float4Vector) => new FloatWriter(vector) case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => + new DecimalWriter(vector, precision, scale) case (StringType, vector: VarCharVector) => new StringWriter(vector) case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) case (DateType, vector: DateDayVector) => new DateWriter(vector) @@ -214,6 +216,25 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFi } } +private[arrow] class DecimalWriter( + val valueVector: DecimalVector, + precision: Int, + scale: Int) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val decimal = input.getDecimal(ordinal, precision, scale) + if (decimal.changePrecision(precision, scale)) { + valueVector.setSafe(count, decimal.toJavaBigDecimal) + } else { + setNull() + } + } +} + private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index fd5a3df6abc68..261df06100aef 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -304,6 +304,70 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "floating_point-double_precision.json") } + test("decimal conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "nullable" : true, + | "children" : [ ] + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "nullable" : true, + | "children" : [ ] + | } ] + | }, + | "batches" : [ { + | "count" : 7, + | "columns" : [ { + | "name" : "a_d", + | "count" : 7, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ + | "1000000000000000000", + | "2000000000000000000", + | "10000000000000000", + | "200000000000000000000", + | "100000000000000", + | "20000000000000000000000", + | "30000000000000000000" ] + | }, { + | "name" : "b_d", + | "count" : 7, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ], + | "DATA" : [ + | "1100000000000000000", + | "0", + | "0", + | "2200000000000000000", + | "0", + | "3300000000000000000", + | "0" ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_)) + val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)), + Some(Decimal("123456789012345678901234567890"))) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "decimalData.json") + } + test("index conversion") { val data = List[Int](1, 2, 3, 4, 5, 6) val json = @@ -1153,7 +1217,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index a71e30aa3ca96..508c116aae92e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLong(rowId) case FloatType => reader.getFloat(rowId) case DoubleType => reader.getDouble(rowId) + case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) case DateType => reader.getInt(rowId) @@ -66,6 +67,7 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, Seq(1L, 2L, null, 4L)) check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) check(DoubleType, Seq(1.0d, 2.0d, 
null, 4.0d)) + check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4))) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) check(DateType, Seq(0, 1, 2, null, 4)) From ff48b1b338241039a7189e7a3c04333b1256fdb3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 26 Dec 2017 06:39:40 -0800 Subject: [PATCH 501/712] [SPARK-22901][PYTHON] Add deterministic flag to pyspark UDF ## What changes were proposed in this pull request? In SPARK-20586 the flag `deterministic` was added to Scala UDF, but it is not available for python UDF. This flag is useful for cases when the UDF's code can return different result with the same input. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. This can lead to unexpected behavior. This PR adds the deterministic flag, via the `asNondeterministic` method, to let the user mark the function as non-deterministic and therefore avoid the optimizations which might lead to strange behaviors. ## How was this patch tested? Manual tests: ``` >>> from pyspark.sql.functions import * >>> from pyspark.sql.types import * >>> df_br = spark.createDataFrame([{'name': 'hello'}]) >>> import random >>> udf_random_col = udf(lambda: int(100*random.random()), IntegerType()).asNondeterministic() >>> df_br = df_br.withColumn('RAND', udf_random_col()) >>> random.seed(1234) >>> udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) >>> df_br.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).show() +-----+----+-------------+ | name|RAND|RAND_PLUS_TEN| +-----+----+-------------+ |hello| 3| 13| +-----+----+-------------+ ``` Author: Marco Gaido Author: Marco Gaido Closes #19929 from mgaido91/SPARK-22629. --- .../org/apache/spark/api/python/PythonRunner.scala | 7 +++++++ python/pyspark/sql/functions.py | 11 ++++++++--- python/pyspark/sql/tests.py | 9 +++++++++ python/pyspark/sql/udf.py | 13 ++++++++++++- .../org/apache/spark/sql/UDFRegistration.scala | 5 +++-- .../spark/sql/execution/python/PythonUDF.scala | 5 ++++- .../python/UserDefinedPythonFunction.scala | 5 +++-- .../execution/python/BatchEvalPythonExecSuite.scala | 3 ++- 8 files changed, 48 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 93d508c28ebba..1ec0e717fac29 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,6 +39,13 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + + def toString(pythonEvalType: Int): String = pythonEvalType match { + case NON_UDF => "NON_UDF" + case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" + case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" + case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + } } /** diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ddd8df3b15bf6..66ee033bab998 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2093,9 +2093,14 @@ class PandasUDFType(object): def udf(f=None, returnType=StringType()): """Creates a user defined function (UDF). - .. note:: The user-defined functions must be deterministic. 
Due to optimization, - duplicate invocations may be eliminated or the function may even be invoked more times than - it is present in the query. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> from pyspark.sql.types import IntegerType + >>> import random + >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b811a0f8a31d3..3ef1522887325 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -435,6 +435,15 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 123138117fdc3..54b5a8656e1c8 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -92,6 +92,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType + self._deterministic = True @property def returnType(self): @@ -129,7 +130,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType) + self._name, wrapped_func, jdt, self.evalType, self._deterministic) return judf def __call__(self, *cols): @@ -161,5 +162,15 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType + wrapper.asNondeterministic = self.asNondeterministic return wrapper + + def asNondeterministic(self): + """ + Updates UserDefinedFunction to nondeterministic. + + .. 
versionadded:: 2.3 + """ + self._deterministic = False + return self diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3ff476147b8b7..dc2468a721e41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.Try import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} @@ -41,8 +42,6 @@ import org.apache.spark.util.Utils * spark.udf * }}} * - * @note The user-defined functions must be deterministic. - * * @since 1.3.0 */ @InterfaceStability.Stable @@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | pythonIncludes: ${udf.func.pythonIncludes} | pythonExec: ${udf.func.pythonExec} | dataType: ${udf.dataType} + | pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)} + | udfDeterministic: ${udf.udfDeterministic} """.stripMargin) functionRegistry.createOrReplaceTempFunction(name, udf.builder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index ef27fbc2db7d9..d3f743d9eb61e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,9 +29,12 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - evalType: Int) + evalType: Int, + udfDeterministic: Boolean) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override def toString: String = s"$name(${children.mkString(", ")})" override def nullable: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 348e49e473ed3..50dca32cb7861 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -29,10 +29,11 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - pythonEvalType: Int) { + pythonEvalType: Int, + udfDeterministic: Boolean) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, pythonEvalType) + PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 53d3f34567518..9e4a2e8776956 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,5 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - pythonEvalType = PythonEvalType.SQL_BATCHED_UDF) + pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) From 9348e684208465a8f75c893bdeaa30fc42c0cb5f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 26 Dec 2017 09:37:39 -0800 Subject: [PATCH 502/712] [SPARK-22833][EXAMPLE] Improvement SparkHive Scala Examples ## What changes were proposed in this pull request? Some improvements: 1. Point out we are using both Spark SQ native syntax and HQL syntax in the example 2. Avoid using the same table name with temp view, to not confuse users. 3. Create the external hive table with a directory that already has data, which is a more common use case. 4. Remove the usage of `spark.sql.parquet.writeLegacyFormat`. This config was introduced by https://github.com/apache/spark/pull/8566 and has nothing to do with Hive. 5. Remove `repartition` and `coalesce` example. These 2 are not Hive specific, we should put them in a different example file. BTW they can't accurately control the number of output files, `spark.sql.files.maxRecordsPerFile` also controls it. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20081 from cloud-fan/minor. --- .../examples/sql/hive/SparkHiveExample.scala | 75 +++++++++++-------- .../apache/spark/sql/internal/SQLConf.scala | 4 +- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index b193bd595127c..70fb5b27366ec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -102,40 +102,53 @@ object SparkHiveExample { // | 5| val_5| 5| val_5| // ... - // Create Hive managed table with Parquet - sql("CREATE TABLE records(key int, value string) STORED AS PARQUET") - // Save DataFrame to Hive managed table as Parquet format - val hiveTableDF = sql("SELECT * FROM records") - hiveTableDF.write.mode(SaveMode.Overwrite).saveAsTable("database_name.records") - // Create External Hive table with Parquet - sql("CREATE EXTERNAL TABLE records(key int, value string) " + - "STORED AS PARQUET LOCATION '/user/hive/warehouse/'") - // to make Hive Parquet format compatible with Spark Parquet format - spark.sqlContext.setConf("spark.sql.parquet.writeLegacyFormat", "true") - - // Multiple Parquet files could be created accordingly to volume of data under directory given. 
- val hiveExternalTableLocation = "/user/hive/warehouse/database_name.db/records" - - // Save DataFrame to Hive External table as compatible Parquet format - hiveTableDF.write.mode(SaveMode.Overwrite).parquet(hiveExternalTableLocation) - - // Turn on flag for Dynamic Partitioning - spark.sqlContext.setConf("hive.exec.dynamic.partition", "true") - spark.sqlContext.setConf("hive.exec.dynamic.partition.mode", "nonstrict") - - // You can create partitions in Hive table, so downstream queries run much faster. - hiveTableDF.write.mode(SaveMode.Overwrite).partitionBy("key") - .parquet(hiveExternalTableLocation) + // Create a Hive managed Parquet table, with HQL syntax instead of the Spark SQL native syntax + // `USING hive` + sql("CREATE TABLE hive_records(key int, value string) STORED AS PARQUET") + // Save DataFrame to the Hive managed table + val df = spark.table("src") + df.write.mode(SaveMode.Overwrite).saveAsTable("hive_records") + // After insertion, the Hive managed table has data now + sql("SELECT * FROM hive_records").show() + // +---+-------+ + // |key| value| + // +---+-------+ + // |238|val_238| + // | 86| val_86| + // |311|val_311| + // ... - // Reduce number of files for each partition by repartition - hiveTableDF.repartition($"key").write.mode(SaveMode.Overwrite) - .partitionBy("key").parquet(hiveExternalTableLocation) + // Prepare a Parquet data directory + val dataDir = "/tmp/parquet_data" + spark.range(10).write.parquet(dataDir) + // Create a Hive external Parquet table + sql(s"CREATE EXTERNAL TABLE hive_ints(key int) STORED AS PARQUET LOCATION '$dataDir'") + // The Hive external table should already have data + sql("SELECT * FROM hive_ints").show() + // +---+ + // |key| + // +---+ + // | 0| + // | 1| + // | 2| + // ... - // Control the number of files in each partition by coalesce - hiveTableDF.coalesce(10).write.mode(SaveMode.Overwrite) - .partitionBy("key").parquet(hiveExternalTableLocation) - // $example off:spark_hive$ + // Turn on flag for Hive Dynamic Partitioning + spark.sqlContext.setConf("hive.exec.dynamic.partition", "true") + spark.sqlContext.setConf("hive.exec.dynamic.partition.mode", "nonstrict") + // Create a Hive partitioned table using DataFrame API + df.write.partitionBy("key").format("hive").saveAsTable("hive_part_tbl") + // Partitioned column `key` will be moved to the end of the schema. + sql("SELECT * FROM hive_part_tbl").show() + // +-------+---+ + // | value|key| + // +-------+---+ + // |val_238|238| + // | val_86| 86| + // |val_311|311| + // ... 
spark.stop() + // $example off:spark_hive$ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84fe4bb711a4e..f16972e5427e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -336,8 +336,8 @@ object SQLConf { .createWithDefault(true) val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") - .doc("Whether to follow Parquet's format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa.") + .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") .booleanConf .createWithDefault(false) From 91d1b300d467cc91948f71e87b7afe1b9027e60f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 26 Dec 2017 09:40:41 -0800 Subject: [PATCH 503/712] [SPARK-22894][SQL] DateTimeOperations should accept SQL like string type ## What changes were proposed in this pull request? `DateTimeOperations` accept [`StringType`](https://github.com/apache/spark/blob/ae998ec2b5548b7028d741da4813473dde1ad81e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala#L669), but: ``` spark-sql> SELECT '2017-12-24' + interval 2 months 2 seconds; Error in query: cannot resolve '(CAST('2017-12-24' AS DOUBLE) + interval 2 months 2 seconds)' due to data type mismatch: differing types in '(CAST('2017-12-24' AS DOUBLE) + interval 2 months 2 seconds)' (double and calendarinterval).; line 1 pos 7; 'Project [unresolvedalias((cast(2017-12-24 as double) + interval 2 months 2 seconds), None)] +- OneRowRelation spark-sql> ``` After this PR: ``` spark-sql> SELECT '2017-12-24' + interval 2 months 2 seconds; 2018-02-24 00:00:02 Time taken: 0.2 seconds, Fetched 1 row(s) ``` ## How was this patch tested? unit tests Author: Yuming Wang Closes #20067 from wangyum/SPARK-22894. --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 6 ++++-- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 13 ++++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2f306f58b7b80..1c4be547739dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -324,9 +324,11 @@ object TypeCoercion { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), right) => + case a @ BinaryArithmetic(left @ StringType(), right) + if right.dataType != CalendarIntervalType => a.makeCopy(Array(Cast(left, DoubleType), right)) - case a @ BinaryArithmetic(left, right @ StringType()) => + case a @ BinaryArithmetic(left, right @ StringType()) + if left.dataType != CalendarIntervalType => a.makeCopy(Array(left, Cast(right, DoubleType))) // For equality between string and timestamp we cast the string to a timestamp diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade55..1972dec7b86ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext import java.net.{MalformedURLException, URL} -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} @@ -2760,6 +2760,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-22894: DateTimeOperations should accept SQL like string type") { + val date = "2017-12-24" + val str = sql(s"SELECT CAST('$date' as STRING) + interval 2 months 2 seconds") + val dt = sql(s"SELECT CAST('$date' as DATE) + interval 2 months 2 seconds") + val ts = sql(s"SELECT CAST('$date' as TIMESTAMP) + interval 2 months 2 seconds") + + checkAnswer(str, Row("2018-02-24 00:00:02") :: Nil) + checkAnswer(dt, Row(Date.valueOf("2018-02-24")) :: Nil) + checkAnswer(ts, Row(Timestamp.valueOf("2018-02-24 00:00:02")) :: Nil) + } + // Only New OrcFileFormat supports this Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, "parquet").foreach { format => From 6674acd1edc8c657049113df95af3378b5db8806 Mon Sep 17 00:00:00 2001 From: "xu.wenchun" Date: Wed, 27 Dec 2017 10:08:32 +0800 Subject: [PATCH 504/712] [SPARK-22846][SQL] Fix table owner is null when creating table through spark sql or thriftserver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? fix table owner is null when create new table through spark sql ## How was this patch tested? manual test. 1、first create a table 2、then select the table properties from mysql which connected to hive metastore Please review http://spark.apache.org/contributing.html before opening a pull request. Author: xu.wenchun Closes #20034 from BruceXu1991/SPARK-22846. --- .../scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7233944dc96dd..7b7f4e0f10210 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -186,7 +186,7 @@ private[hive] class HiveClientImpl( /** Returns the configuration for the current session. 
*/ def conf: HiveConf = state.getConf - private val userName = state.getAuthenticator.getUserName + private val userName = conf.getUser override def getConf(key: String, defaultValue: String): String = { conf.get(key, defaultValue) From b8bfce51abf28c66ba1fc67b0f25fe1617c81025 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Dec 2017 20:51:26 +0900 Subject: [PATCH 505/712] [SPARK-22324][SQL][PYTHON][FOLLOW-UP] Update setup.py file. ## What changes were proposed in this pull request? This is a follow-up pr of #19884 updating setup.py file to add pyarrow dependency. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20089 from ueshin/issues/SPARK-22324/fup1. --- python/README.md | 2 +- python/setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/README.md b/python/README.md index 84ec88141cb00..3f17fdb98a081 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas). +At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/setup.py b/python/setup.py index 310670e697a83..251d4526d4dd0 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.19.2'] + 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', @@ -210,6 +210,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) From 774715d5c73ab6d410208fa1675cd11166e03165 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 27 Dec 2017 23:53:10 +0800 Subject: [PATCH 506/712] [SPARK-22904][SQL] Add tests for decimal operations and string casts ## What changes were proposed in this pull request? Test coverage for arithmetic operations leading to: 1. Precision loss 2. Overflow Moreover, tests for casting bad string to other input types and for using bad string as operators of some functions. ## How was this patch tested? added tests Author: Marco Gaido Closes #20084 from mgaido91/SPARK-22904. 
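
For quick reference, a few of the behaviours the new golden files pin down (the queries and their NULL results are taken from the test files added below; only the `spark.sql(...)` wrapping and the inlined `'aa'` literal are assumed here):

```
// Sketch only: overflow, precision loss and unparseable strings all yield NULL
// instead of raising an error, matching the expected outputs in the golden files.
spark.sql("select (5e36 + 0.1) + 5e36").show()                // NULL (overflow)
spark.sql("select 123456789123456789.1234567890 * 1.123456789123456789").show()  // NULL (precision loss)
spark.sql("select cast('aa' as int)").show()                  // NULL (bad numeric string)
spark.sql("select cast('aa' as timestamp)").show()            // NULL (bad timestamp string)
```
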
--- .../native/decimalArithmeticOperations.sql | 33 +++ .../native/stringCastAndExpressions.sql | 57 ++++ .../decimalArithmeticOperations.sql.out | 82 ++++++ .../native/stringCastAndExpressions.sql.out | 261 ++++++++++++++++++ 4 files changed, 433 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql new file mode 100644 index 0000000000000..c8e108ac2c45e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -0,0 +1,33 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b; + +-- division, remainder and pmod by 0 return NULL +select a / b from t; +select a % b from t; +select pmod(a, b) from t; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss return NULL +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql new file mode 100644 index 0000000000000..f17adb56dee91 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/stringCastAndExpressions.sql @@ -0,0 +1,57 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 'aa' as a; + +-- casting to data types which are unable to represent the string input returns NULL +select cast(a as byte) from t; +select cast(a as short) from t; +select cast(a as int) from t; +select cast(a as long) from t; +select cast(a as float) from t; +select cast(a as double) from t; +select cast(a as decimal) from t; +select cast(a as boolean) from t; +select cast(a as timestamp) from t; +select cast(a as date) from t; +-- casting to binary works correctly +select cast(a as binary) from t; +-- casting to array, struct or map throws exception +select cast(a as array) from t; +select cast(a as struct) from t; +select cast(a as map) from t; + +-- all timestamp/date expressions return NULL if bad input strings are provided +select to_timestamp(a) from t; +select to_timestamp('2018-01-01', a) from t; +select to_unix_timestamp(a) from t; +select to_unix_timestamp('2018-01-01', a) from t; +select unix_timestamp(a) from t; +select unix_timestamp('2018-01-01', a) from t; +select from_unixtime(a) from t; +select from_unixtime('2018-01-01', a) from t; +select next_day(a, 'MO') from t; +select next_day('2018-01-01', a) from t; +select trunc(a, 'MM') from t; +select trunc('2018-01-01', a) from t; + +-- some functions return NULL if bad input is provided +select unhex('-123'); +select sha2(a, a) from t; +select get_json_object(a, a) from t; +select json_tuple(a, a) from t; +select from_json(a, 'a INT') from t; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out new file mode 100644 index 0000000000000..ce02f6adc456c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -0,0 +1,82 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1.0 as a, 0.0 as b +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select a / b from t +-- !query 1 schema +struct<(CAST(a AS DECIMAL(2,1)) / CAST(b AS DECIMAL(2,1))):decimal(8,6)> +-- !query 1 output +NULL + + +-- !query 2 +select a % b from t +-- !query 2 schema +struct<(CAST(a AS DECIMAL(2,1)) % CAST(b AS DECIMAL(2,1))):decimal(1,1)> +-- !query 2 output +NULL + + +-- !query 3 +select pmod(a, b) from t +-- !query 3 schema +struct +-- !query 3 output +NULL + + +-- !query 4 +select (5e36 + 0.1) + 5e36 +-- !query 4 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 4 output +NULL + + +-- !query 5 +select (-4e36 - 0.1) - 7e36 +-- !query 5 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 5 output +NULL + + +-- !query 6 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 6 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- 
!query 6 output +NULL + + +-- !query 7 +select 1e35 / 0.1 +-- !query 7 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +-- !query 7 output +NULL + + +-- !query 8 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 8 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 8 output +NULL + + +-- !query 9 +select 0.001 / 9876543210987654321098765432109876543.2 +-- !query 9 schema +struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +-- !query 9 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out new file mode 100644 index 0000000000000..8ed2820244412 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -0,0 +1,261 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 32 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 'aa' as a +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select cast(a as byte) from t +-- !query 1 schema +struct +-- !query 1 output +NULL + + +-- !query 2 +select cast(a as short) from t +-- !query 2 schema +struct +-- !query 2 output +NULL + + +-- !query 3 +select cast(a as int) from t +-- !query 3 schema +struct +-- !query 3 output +NULL + + +-- !query 4 +select cast(a as long) from t +-- !query 4 schema +struct +-- !query 4 output +NULL + + +-- !query 5 +select cast(a as float) from t +-- !query 5 schema +struct +-- !query 5 output +NULL + + +-- !query 6 +select cast(a as double) from t +-- !query 6 schema +struct +-- !query 6 output +NULL + + +-- !query 7 +select cast(a as decimal) from t +-- !query 7 schema +struct +-- !query 7 output +NULL + + +-- !query 8 +select cast(a as boolean) from t +-- !query 8 schema +struct +-- !query 8 output +NULL + + +-- !query 9 +select cast(a as timestamp) from t +-- !query 9 schema +struct +-- !query 9 output +NULL + + +-- !query 10 +select cast(a as date) from t +-- !query 10 schema +struct +-- !query 10 output +NULL + + +-- !query 11 +select cast(a as binary) from t +-- !query 11 schema +struct +-- !query 11 output +aa + + +-- !query 12 +select cast(a as array) from t +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to array; line 1 pos 7 + + +-- !query 13 +select cast(a as struct) from t +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to struct; line 1 pos 7 + + +-- !query 14 +select cast(a as map) from t +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`a`' due to data type mismatch: cannot cast string to map; line 1 pos 7 + + +-- !query 15 +select to_timestamp(a) from t +-- !query 15 schema +struct +-- !query 15 output +NULL + + +-- !query 16 +select to_timestamp('2018-01-01', a) from t +-- !query 16 schema +struct +-- !query 16 output +NULL + + +-- !query 17 +select to_unix_timestamp(a) from t +-- !query 17 schema +struct +-- !query 17 output +NULL + + +-- !query 18 +select to_unix_timestamp('2018-01-01', a) from t +-- !query 18 schema +struct +-- 
!query 18 output +NULL + + +-- !query 19 +select unix_timestamp(a) from t +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select unix_timestamp('2018-01-01', a) from t +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select from_unixtime(a) from t +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select from_unixtime('2018-01-01', a) from t +-- !query 22 schema +struct +-- !query 22 output +NULL + + +-- !query 23 +select next_day(a, 'MO') from t +-- !query 23 schema +struct +-- !query 23 output +NULL + + +-- !query 24 +select next_day('2018-01-01', a) from t +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select trunc(a, 'MM') from t +-- !query 25 schema +struct +-- !query 25 output +NULL + + +-- !query 26 +select trunc('2018-01-01', a) from t +-- !query 26 schema +struct +-- !query 26 output +NULL + + +-- !query 27 +select unhex('-123') +-- !query 27 schema +struct +-- !query 27 output +NULL + + +-- !query 28 +select sha2(a, a) from t +-- !query 28 schema +struct +-- !query 28 output +NULL + + +-- !query 29 +select get_json_object(a, a) from t +-- !query 29 schema +struct +-- !query 29 output +NULL + + +-- !query 30 +select json_tuple(a, a) from t +-- !query 30 schema +struct +-- !query 30 output +NULL + + +-- !query 31 +select from_json(a, 'a INT') from t +-- !query 31 schema +struct> +-- !query 31 output +NULL From 753793bc84df805e519cf59f6804ab26bd02d76e Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 27 Dec 2017 17:31:12 -0800 Subject: [PATCH 507/712] [SPARK-22899][ML][STREAMING] Fix OneVsRestModel transform on streaming data failed. ## What changes were proposed in this pull request? Fix OneVsRestModel transform on streaming data failed. ## How was this patch tested? UT will be added soon, once #19979 merged. (Need a helper test method there) Author: WeichenXu Closes #20077 from WeichenXu123/fix_ovs_model_transform. --- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 3ab99b35ece2b..f04fde2cbbca1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -165,7 +165,7 @@ final class OneVsRestModel private[ml] ( val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. - val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val handlePersistence = !dataset.isStreaming && dataset.storageLevel == StorageLevel.NONE if (handlePersistence) { newDataset.persist(StorageLevel.MEMORY_AND_DISK) } From 5683984520cfe9e9acf49e47a84a56af155a8ad2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 28 Dec 2017 12:28:19 +0800 Subject: [PATCH 508/712] [SPARK-18016][SQL][FOLLOW-UP] Code Generation: Constant Pool Limit - reduce entries for mutable state ## What changes were proposed in this pull request? This PR addresses additional review comments in #19811 ## How was this patch tested? Existing test suites Author: Kazuaki Ishizaki Closes #20036 from kiszk/SPARK-18066-followup. 
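
As a reminder of why reducing inlined states matters (a conceptual sketch only, not the code `CodeGenerator` actually emits; the array size is an arbitrary illustrative value, the real threshold is the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT` referenced in the hunks below):

```
// Every individual field in a generated class costs constant pool entries, so
// thousands of mutable states can exceed the 64K-entry per-class limit.
class InlinedStates {
  private var pattern0: java.util.regex.Pattern = _
  private var pattern1: java.util.regex.Pattern = _
  // ... one field per mutable state
}

// Packing states of the same type into an array keeps the number of declared
// fields (and their constant pool entries) roughly constant.
class CompactedStates {
  private val patterns = new Array[java.util.regex.Pattern](8) // illustrative size
  def slot(i: Int): java.util.regex.Pattern = patterns(i)
}
```

Accordingly, the `Like`/`RLike` hunks below drop `forceInline = true` so their compiled patterns can take part in this compaction.
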
--- .../expressions/codegen/CodeGenerator.scala | 4 +- .../expressions/regexpExpressions.scala | 62 +++++++++---------- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 16 +++-- .../execution/basicPhysicalOperators.scala | 4 +- .../joins/BroadcastHashJoinExec.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 7 ++- 8 files changed, 51 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d6eccadcfb63e..2c714c228e6c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -190,7 +190,7 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the config `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { @@ -352,7 +352,7 @@ class CodegenContext { def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - val initCodes = mutableStateInitCode.distinct + val initCodes = mutableStateInitCode.distinct.map(_ + "\n") // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index fa5425c77ebba..f3e8f6de58975 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -118,9 +118,8 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - // inline mutable state since not many Like operations in a task val pattern = ctx.addMutableState(patternClass, "patternLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) + v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
val eval = left.genCode(ctx) @@ -143,9 +142,9 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi val rightStr = ctx.freshName("rightStr") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr)); - ${ev.value} = $pattern.matcher(${eval1}.toString()).matches(); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile($escapeFunc($rightStr)); + ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) } @@ -194,9 +193,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - // inline mutable state since not many RLike operations in a task val pattern = ctx.addMutableState(patternClass, "patternRLike", - v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true) + v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) @@ -219,9 +217,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress val pattern = ctx.freshName("pattern") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($rightStr); - ${ev.value} = $pattern.matcher(${eval1}.toString()).find(0); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile($rightStr); + ${ev.value} = $pattern.matcher($eval1.toString()).find(0); """ }) } @@ -338,25 +336,25 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if (!$regexp.equals($termLastRegex)) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); } - if (!$rep.equals(${termLastReplacementInUTF8})) { + if (!$rep.equals($termLastReplacementInUTF8)) { // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + $termLastReplacementInUTF8 = $rep.clone(); + $termLastReplacement = $termLastReplacementInUTF8.toString(); } - $classNameStringBuffer ${termResult} = new $classNameStringBuffer(); - java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); + $classNameStringBuffer $termResult = new $classNameStringBuffer(); + java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); - while (${matcher}.find()) { - ${matcher}.appendReplacement(${termResult}, ${termLastReplacement}); + while ($matcher.find()) { + $matcher.appendReplacement($termResult, $termLastReplacement); } - ${matcher}.appendTail(${termResult}); - ${ev.value} = UTF8String.fromString(${termResult}.toString()); - ${termResult} = null; + $matcher.appendTail($termResult); + ${ev.value} = UTF8String.fromString($termResult.toString()); + $termResult = null; $setEvNotNull """ }) @@ -425,19 +423,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if 
(!$regexp.equals($termLastRegex)) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); } - java.util.regex.Matcher ${matcher} = - ${termPattern}.matcher($subject.toString()); - if (${matcher}.find()) { - java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); - if (${matchResult}.group($idx) == null) { + java.util.regex.Matcher $matcher = + $termPattern.matcher($subject.toString()); + if ($matcher.find()) { + java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); + if ($matchResult.group($idx) == null) { ${ev.value} = UTF8String.EMPTY_UTF8; } else { - ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + ${ev.value} = UTF8String.fromString($matchResult.group($idx)); } $setEvNotNull } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index daff3c49e7517..ef1bb1c2a4468 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -138,7 +138,7 @@ case class SortExec( // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. val thisPlan = ctx.addReferenceObj("plan", this) - // inline mutable state since not many Sort operations in a task + // Inline mutable state since not many Sort operations in a task sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter", v => s"$v = $thisPlan.createSorter();", forceInline = true) val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9e7008d1e0c31..065954559e487 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -283,7 +283,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp override def doProduce(ctx: CodegenContext): String = { // Right now, InputAdapter is only used when there is one input RDD. 
- // inline mutable state since an inputAdaptor in a task + // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", forceInline = true) val row = ctx.freshName("row") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b1af360d85095..9a6f1c6dfa6a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -587,31 +587,35 @@ case class HashAggregateExec( fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) + // Inline mutable state since not many aggregation operations in a task fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedHastHashMap", - v => s"$v = new $fastHashMapClassName();") - ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter") + v => s"$v = new $fastHashMapClassName();", forceInline = true) + ctx.addMutableState(s"java.util.Iterator", "vectorizedFastHashMapIter", + forceInline = true) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, fastHashMapClassName, groupingKeySchema, bufferSchema).generate() ctx.addInnerClass(generatedMap) + // Inline mutable state since not many aggregation operations in a task fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap", v => s"$v = new $fastHashMapClassName(" + - s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());", + forceInline = true) ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - "fastHashMapIter") + "fastHashMapIter", forceInline = true) } } // Create a name for the iterator from the regular hash map. 
- // inline mutable state since not many aggregation operations in a task + // Inline mutable state since not many aggregation operations in a task val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, "mapIter", forceInline = true) // create hashMap val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap", - v => s"$v = $thisPlan.createHashMap();") + v => s"$v = $thisPlan.createHashMap();", forceInline = true) sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter", forceInline = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 78137d3f97cfc..a15a8d11aa2a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -284,7 +284,7 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") - // inline mutable state since not many Sample operations in a task + // Inline mutable state since not many Sample operations in a task val sampler = ctx.addMutableState(s"$samplerClass", "sampleReplace", v => { val initSamplerFuncName = ctx.addNewFunction(initSampler, @@ -371,7 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName - // inline mutable state since not many Range operations in a task + // Inline mutable state since not many Range operations in a task val taskContext = ctx.addMutableState("TaskContext", "taskContext", v => s"$v = TaskContext.get();", forceInline = true) val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index ee763e23415cf..1918fcc5482db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -139,7 +139,7 @@ case class BroadcastHashJoinExec( // At the end of the task, we update the avg hash probe. val avgHashProbe = metricTerm(ctx, "avgHashProbe") - // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val relationTerm = ctx.addMutableState(clsName, "relation", v => s""" | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 073730462a75f..94405410cce90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -422,7 +422,7 @@ case class SortMergeJoinExec( */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. 
- // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) @@ -440,8 +440,9 @@ case class SortMergeJoinExec( val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold + // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);") + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -576,7 +577,7 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { - // inline mutable state since not many join operations in a task + // Inline mutable state since not many join operations in a task val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", v => s"$v = inputs[0];", forceInline = true) val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", From 32ec269d08313720aae3b47cce2f5e9c19702811 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 28 Dec 2017 12:35:17 +0800 Subject: [PATCH 509/712] [SPARK-22909][SS] Move Structured Streaming v2 APIs to streaming folder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR moves Structured Streaming v2 APIs to streaming folder as following: ``` sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming ├── ContinuousReadSupport.java ├── ContinuousWriteSupport.java ├── MicroBatchReadSupport.java ├── MicroBatchWriteSupport.java ├── reader │   ├── ContinuousDataReader.java │   ├── ContinuousReader.java │   ├── MicroBatchReader.java │   ├── Offset.java │   └── PartitionOffset.java └── writer └── ContinuousWriter.java ``` ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20093 from zsxwing/move. 
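Because this patch only moves the interfaces between packages, code written against them adapts by updating import paths; the interface and method names themselves are unchanged. A sketch under that assumption, using the renamed paths from the file list below (the `MyContinuousSource` trait is hypothetical):

```scala
// Before this patch:
//   import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2}
//   import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, Offset, PartitionOffset}
//   import org.apache.spark.sql.sources.v2.writer.ContinuousWriter
//
// After this patch:
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport}
import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter

// A hypothetical downstream source keeps its declaration as-is; only the imports above move.
trait MyContinuousSource extends DataSourceV2 with ContinuousReadSupport
```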
--- .../sources/v2/{ => streaming}/ContinuousReadSupport.java | 7 ++++--- .../sources/v2/{ => streaming}/ContinuousWriteSupport.java | 6 ++++-- .../sources/v2/{ => streaming}/MicroBatchReadSupport.java | 6 ++++-- .../sources/v2/{ => streaming}/MicroBatchWriteSupport.java | 4 +++- .../v2/{ => streaming}/reader/ContinuousDataReader.java | 6 ++---- .../v2/{ => streaming}/reader/ContinuousReader.java | 4 ++-- .../v2/{ => streaming}/reader/MicroBatchReader.java | 4 ++-- .../sql/sources/v2/{ => streaming}/reader/Offset.java | 2 +- .../sources/v2/{ => streaming}/reader/PartitionOffset.java | 2 +- .../v2/{ => streaming}/writer/ContinuousWriter.java | 5 ++++- .../execution/datasources/v2/DataSourceV2ScanExec.scala | 1 + .../sql/execution/datasources/v2/WriteToDataSourceV2.scala | 1 + .../sql/execution/streaming/MicroBatchExecution.scala | 2 +- .../spark/sql/execution/streaming/RateSourceProvider.scala | 3 ++- .../spark/sql/execution/streaming/RateStreamOffset.scala | 2 +- .../spark/sql/execution/streaming/StreamingRelation.scala | 3 ++- .../streaming/continuous/ContinuousDataSourceRDDIter.scala | 1 + .../streaming/continuous/ContinuousExecution.scala | 7 ++++--- .../streaming/continuous/ContinuousRateStreamSource.scala | 3 ++- .../execution/streaming/continuous/EpochCoordinator.scala | 5 +++-- .../execution/streaming/sources/RateStreamSourceV2.scala | 1 + .../spark/sql/execution/streaming/sources/memoryV2.scala | 4 +++- .../org/apache/spark/sql/streaming/DataStreamReader.scala | 3 ++- .../apache/spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeqLogSuite.scala | 1 - .../spark/sql/execution/streaming/RateSourceSuite.scala | 1 - .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 3 ++- 27 files changed, 54 insertions(+), 35 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/ContinuousReadSupport.java (88%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/ContinuousWriteSupport.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/MicroBatchReadSupport.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/MicroBatchWriteSupport.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/ContinuousDataReader.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/ContinuousReader.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/MicroBatchReader.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/Offset.java (97%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/reader/PartitionOffset.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ => streaming}/writer/ContinuousWriter.java (87%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java similarity index 88% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index ae4f85820649f..8837bae6156b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -15,12 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; -import org.apache.spark.sql.sources.v2.reader.ContinuousReader; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java index 362d5f52b4d00..ec15e436672b4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java @@ -15,13 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.ContinuousWriter; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 442cad029d211..3c87a3db68243 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -15,12 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java index 63640779b955c..b5e3e44cd6332 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.streaming; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java index 11b99a93f1494..ca9a290e97a02 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java @@ -15,11 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.PartitionOffset; - -import java.io.IOException; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 34141d6cd85fd..f0b205869ed6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -15,10 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.PartitionOffset; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import java.util.Optional; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index bd15c07d87f6c..70ff756806032 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; -import org.apache.spark.sql.sources.v2.reader.Offset; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import java.util.Optional; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index ce1c489742054..517fdab16684c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; /** * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index 07826b6688476..729a6123034f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.streaming.reader; import java.io.Serializable; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java similarity index 87% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java index 618f47ed79ca5..723395bd1e963 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/ContinuousWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java @@ -15,9 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.writer; +package org.apache.spark.sql.sources.v2.streaming.writer; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataWriter; +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** * A {@link DataSourceV2Writer} for use with continuous stream processing. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index e4fca1b10dfad..49c506bc560cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 1862da8892cb2..f0bdf84bb7a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 20f9810faa5c8..9a7a13fcc5806 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 3f85fa913f28c..d02cf882b61ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamR import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 65d6d18936167..261d69bbd9843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.reader.Offset { + extends v2.streaming.reader.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 0ca2e7854d94b..a9d50e3a112e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 89fb2ace20917..d79e4bd65f563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, Ro import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.util.{SystemClock, ThreadUtils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1c35b06bd4b85..2843ab13bde2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,9 +29,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, ContinuousWriteSupport, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 89a8562b4b59e..c9aa78a5a2e28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -27,8 +27,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} import 
org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} case class ContinuousRateStreamPartitionOffset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 7f1e8abd79b99..98017c3ac6a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -26,8 +26,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.{ContinuousWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 1c66aed8690a7..97bada08bcd2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} import org.apache.spark.util.SystemClock diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 972248d5e4df8..da7c31cf62428 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import 
org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f17935e86f459..acd5ca17dcf76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index e808ffaa96410..b508f4406138f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 4868ba4e68934..e6cdc063c4e9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala index ceba27b26e578..03d0f63fa4d7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.reader.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import 
org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index dc833b2ccaa22..e11705a227f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport import org.apache.spark.sql.streaming.StreamTest class RateSourceV2Suite extends StreamTest { From 171f6ddadc6185ffcc6ad82e5f48952fb49095b2 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 28 Dec 2017 13:44:44 +0900 Subject: [PATCH 510/712] [SPARK-22757][KUBERNETES] Enable use of remote dependencies (http, s3, gcs, etc.) in Kubernetes mode ## What changes were proposed in this pull request? This PR expands the Kubernetes mode to be able to use remote dependencies on http/https endpoints, GCS, S3, etc. It adds steps for configuring and appending the Kubernetes init-container into the driver and executor pods for downloading remote dependencies. [Init-containers](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/), as the name suggests, are containers that are run to completion before the main containers start, and are often used to perform initialization tasks prior to starting the main containers. We use init-containers to localize remote application dependencies before the driver/executors start running. The code that the init-container runs is also included. This PR also adds a step to the driver and executors for mounting user-specified secrets that may store credentials for accessing data storage, e.g., S3 and Google Cloud Storage (GCS), into the driver and executors. ## How was this patch tested? * The patch contains unit tests which are passing. * Manual testing: `./build/mvn -Pkubernetes clean package` succeeded. * Manual testing of the following cases: * [x] Running SparkPi using container-local spark-example jar. * [x] Running SparkPi using container-local spark-example jar with user-specific secret mounted. * [x] Running SparkPi using spark-example jar hosted remotely on an https endpoint. cc rxin felixcheung mateiz (shepherd) k8s-big-data SIG members & contributors: mccheah foxish ash211 ssuchter varunkatta kimoonkim erikerlandson tnachen ifilonenko liyinan926 reviewers: vanzin felixcheung jiangxb1987 mridulm Author: Yinan Li Closes #19954 from liyinan926/init-container. 
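From the application side, the new behavior is driven purely by configuration. A rough sketch of how the knobs added here fit together is shown below; the property names are taken from the Config.scala additions in this patch, while the master URL, image name, remote jar URL, secret name, and mount paths are assumptions made only for illustration.

```scala
import org.apache.spark.SparkConf

// Sketch only: configure remote dependencies and a mounted secret for a Kubernetes job.
object KubernetesRemoteDepsSketch {
  def conf: SparkConf = new SparkConf()
    .setMaster("k8s://https://kubernetes.example.com:6443")
    // Remote jars/files come from the usual spark.jars / spark.files settings and are
    // downloaded by the init-container before the driver and executors start.
    .set("spark.jars", "https://example.com/jars/my-app.jar")
    .set("spark.kubernetes.initContainer.image", "my-registry/spark-init:latest")
    // Where the init-container places downloaded jars and files inside the pods.
    .set("spark.kubernetes.mountDependencies.jarsDownloadDir", "/var/spark-data/spark-jars")
    .set("spark.kubernetes.mountDependencies.filesDownloadDir", "/var/spark-data/spark-files")
    .set("spark.kubernetes.mountDependencies.timeout", "300s")
    .set("spark.kubernetes.mountDependencies.maxSimultaneousDownloads", "5")
    // Mount a pre-created Kubernetes secret (e.g. S3 or GCS credentials) into the driver
    // and executor pods at the given path.
    .set("spark.kubernetes.driver.secrets.object-store-creds", "/etc/secrets")
    .set("spark.kubernetes.executor.secrets.object-store-creds", "/etc/secrets")
}
```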
--- .../org/apache/spark/deploy/k8s/Config.scala | 67 +++++++- .../apache/spark/deploy/k8s/Constants.scala | 11 ++ .../deploy/k8s/InitContainerBootstrap.scala | 119 +++++++++++++ ...sFileUtils.scala => KubernetesUtils.scala} | 50 +++++- .../deploy/k8s/MountSecretsBootstrap.scala | 62 +++++++ ...ala => PodWithDetachedInitContainer.scala} | 34 ++-- .../k8s/SparkKubernetesClientFactory.scala | 2 +- ...r.scala => DriverConfigOrchestrator.scala} | 91 +++++++--- .../submit/KubernetesClientApplication.scala | 26 ++- ...ala => BasicDriverConfigurationStep.scala} | 63 ++++--- .../steps/DependencyResolutionStep.scala | 27 +-- .../steps/DriverConfigurationStep.scala | 2 +- .../DriverInitContainerBootstrapStep.scala | 95 +++++++++++ .../submit/steps/DriverMountSecretsStep.scala | 38 +++++ .../steps/DriverServiceBootstrapStep.scala | 17 +- .../BasicInitContainerConfigurationStep.scala | 67 ++++++++ .../InitContainerConfigOrchestrator.scala | 79 +++++++++ .../InitContainerConfigurationStep.scala | 25 +++ .../InitContainerMountSecretsStep.scala | 39 +++++ .../initcontainer/InitContainerSpec.scala | 37 ++++ .../rest/k8s/SparkPodInitContainer.scala | 116 +++++++++++++ .../cluster/k8s/ExecutorPodFactory.scala | 77 ++++++--- .../k8s/KubernetesClusterManager.scala | 62 ++++++- ...la => DriverConfigOrchestratorSuite.scala} | 69 ++++++-- .../deploy/k8s/submit/SecretVolumeUtils.scala | 36 ++++ ...> BasicDriverConfigurationStepSuite.scala} | 4 +- ...riverInitContainerBootstrapStepSuite.scala | 160 ++++++++++++++++++ .../steps/DriverMountSecretsStepSuite.scala | 49 ++++++ ...cInitContainerConfigurationStepSuite.scala | 95 +++++++++++ ...InitContainerConfigOrchestratorSuite.scala | 80 +++++++++ .../InitContainerMountSecretsStepSuite.scala | 57 +++++++ .../rest/k8s/SparkPodInitContainerSuite.scala | 86 ++++++++++ .../cluster/k8s/ExecutorPodFactorySuite.scala | 82 ++++++++- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 24 +++ 36 files changed, 1783 insertions(+), 171 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{submit/KubernetesFileUtils.scala => KubernetesUtils.scala} (57%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{ConfigurationUtils.scala => PodWithDetachedInitContainer.scala} (53%) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/{DriverConfigurationStepsOrchestrator.scala => DriverConfigOrchestrator.scala} (53%) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/{BaseDriverConfigurationStep.scala => BasicDriverConfigurationStep.scala} (70%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala create mode 100644 
resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/{DriverConfigurationStepsOrchestratorSuite.scala => DriverConfigOrchestratorSuite.scala} (51%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/{BaseDriverConfigurationStepSuite.scala => BasicDriverConfigurationStepSuite.scala} (97%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 45f527959cbe1..e5d79d9a9d9da 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -20,7 +20,6 @@ import java.util.concurrent.TimeUnit import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigBuilder -import org.apache.spark.network.util.ByteUnit private[spark] object Config extends Logging { @@ -132,20 +131,72 @@ private[spark] object Config extends Logging { val JARS_DOWNLOAD_LOCATION = ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using" + - " spark-submit, this directory must be empty and will be mounted as an empty directory" + - " volume on the driver and executor pod.") + .doc("Location to download jars to in the driver and executors. 
When using " + + "spark-submit, this directory must be empty and will be mounted as an empty directory " + + "volume on the driver and executor pod.") .stringConf .createWithDefault("/var/spark-data/spark-jars") val FILES_DOWNLOAD_LOCATION = ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using" + - " spark-submit, this directory must be empty and will be mounted as an empty directory" + - " volume on the driver and executor pods.") + .doc("Location to download files to in the driver and executors. When using " + + "spark-submit, this directory must be empty and will be mounted as an empty directory " + + "volume on the driver and executor pods.") .stringConf .createWithDefault("/var/spark-data/spark-files") + val INIT_CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.initContainer.image") + .doc("Image for the driver and executor's init-container for downloading dependencies.") + .stringConf + .createOptional + + val INIT_CONTAINER_MOUNT_TIMEOUT = + ConfigBuilder("spark.kubernetes.mountDependencies.timeout") + .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + + "locations into the driver and executor pods.") + .timeConf(TimeUnit.SECONDS) + .createWithDefault(300) + + val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = + ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") + .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + + "executor pod.") + .intConf + .createWithDefault(5) + + val INIT_CONTAINER_REMOTE_JARS = + ConfigBuilder("spark.kubernetes.initContainer.remoteJars") + .doc("Comma-separated list of jar URIs to download in the init-container. This is " + + "calculated from spark.jars.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_REMOTE_FILES = + ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") + .doc("Comma-separated list of file URIs to download in the init-container. This is " + + "calculated from spark.files.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_CONFIG_MAP_NAME = + ConfigBuilder("spark.kubernetes.initContainer.configMapName") + .doc("Name of the config map to use in the init-container that retrieves submitted files " + + "for the executor.") + .internal() + .stringConf + .createOptional + + val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = + ConfigBuilder("spark.kubernetes.initContainer.configMapKey") + .doc("Key for the entry in the init container config map for submitted files that " + + "corresponds to the properties for this init-container.") + .internal() + .stringConf + .createOptional + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -153,9 +204,11 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." + val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." 
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 0b91145405d3a..111cb2a3b75e5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -69,6 +69,17 @@ private[spark] object Constants { val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" + val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" + + // Bootstrapping dependencies with the init-container + val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" + val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" + val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" + val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" + val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" + val INIT_CONTAINER_PROPERTIES_FILE_PATH = + s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" + val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala new file mode 100644 index 0000000000000..dfeccf9e2bd1c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Bootstraps an init-container for downloading remote dependencies. This is separated out from + * the init-container steps API because this component can be used to bootstrap init-containers + * for both the driver and executors. 
+ */ +private[spark] class InitContainerBootstrap( + initContainerImage: String, + imagePullPolicy: String, + jarsDownloadPath: String, + filesDownloadPath: String, + configMapName: String, + configMapKey: String, + sparkRole: String, + sparkConf: SparkConf) { + + /** + * Bootstraps an init-container that downloads dependencies to be used by a main container. + */ + def bootstrapInitContainer( + original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { + val sharedVolumeMounts = Seq[VolumeMount]( + new VolumeMountBuilder() + .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) + .withMountPath(jarsDownloadPath) + .build(), + new VolumeMountBuilder() + .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) + .withMountPath(filesDownloadPath) + .build()) + + val customEnvVarKeyPrefix = sparkRole match { + case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY + case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." + case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") + } + val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { + case (key, value) => + new EnvVarBuilder() + .withName(key) + .withValue(value) + .build() + } + + val initContainer = new ContainerBuilder(original.initContainer) + .withName("spark-init") + .withImage(initContainerImage) + .withImagePullPolicy(imagePullPolicy) + .addAllToEnv(customEnvVars.asJava) + .addNewVolumeMount() + .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) + .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) + .endVolumeMount() + .addToVolumeMounts(sharedVolumeMounts: _*) + .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) + .build() + + val podWithBasicVolumes = new PodBuilder(original.pod) + .editSpec() + .addNewVolume() + .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .addNewItem() + .withKey(configMapKey) + .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) + .endItem() + .endConfigMap() + .endVolume() + .addNewVolume() + .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) + .withEmptyDir(new EmptyDirVolumeSource()) + .endVolume() + .addNewVolume() + .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) + .withEmptyDir(new EmptyDirVolumeSource()) + .endVolume() + .endSpec() + .build() + + val mainContainer = new ContainerBuilder(original.mainContainer) + .addToVolumeMounts(sharedVolumeMounts: _*) + .addNewEnv() + .withName(ENV_MOUNTED_FILES_DIR) + .withValue(filesDownloadPath) + .endEnv() + .build() + + PodWithDetachedInitContainer( + podWithBasicVolumes, + initContainer, + mainContainer) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala similarity index 57% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index a38cf55fc3d58..37331d8bbf9b7 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesFileUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -14,13 +14,49 @@ * See the License for the specific language governing permissions and * limitations under the License. 
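A usage sketch for the bootstrap above; the image, config map name, and pod/container names are illustrative only:

    import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder}
    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer}
    import org.apache.spark.deploy.k8s.Constants._

    val bootstrap = new InitContainerBootstrap(
      initContainerImage = "my-registry/spark-init:2.3.0",   // made-up image
      imagePullPolicy = "IfNotPresent",
      jarsDownloadPath = "/var/spark-data/spark-jars",
      filesDownloadPath = "/var/spark-data/spark-files",
      configMapName = "spark-pi-1234-init-config",           // made-up config map name
      configMapKey = "spark-init.properties",
      sparkRole = SPARK_POD_DRIVER_ROLE,
      sparkConf = new SparkConf(false))

    val bootstrapped = bootstrap.bootstrapInitContainer(PodWithDetachedInitContainer(
      pod = new PodBuilder().withNewMetadata().withName("spark-pi-driver").endMetadata().build(),
      initContainer = new ContainerBuilder().build(),
      mainContainer = new ContainerBuilder().withName("spark-kubernetes-driver").build()))
    // bootstrapped.pod now carries the config map volume and the two emptyDir download volumes,
    // and bootstrapped.initContainer/mainContainer mount them at the configured download paths.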
*/ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import java.io.File +import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} + +import org.apache.spark.SparkConf import org.apache.spark.util.Utils -private[spark] object KubernetesFileUtils { +private[spark] object KubernetesUtils { + + /** + * Extract and parse Spark configuration properties with a given name prefix and + * return the result as a Map. Keys must not have more than one value. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing the configuration property keys and values + */ + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String): Map[String, String] = { + sparkConf.getAllWithPrefix(prefix).toMap + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } + + /** + * Append the given init-container to a pod's list of init-containers. + * + * @param originalPodSpec original specification of the pod + * @param initContainer the init-container to add to the pod + * @return the pod with the init-container added to the list of InitContainers + */ + def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { + new PodBuilder(originalPodSpec) + .editOrNewSpec() + .addToInitContainers(initContainer) + .endSpec() + .build() + } /** * For the given collection of file URIs, resolves them as follows: @@ -47,6 +83,16 @@ private[spark] object KubernetesFileUtils { } } + /** + * Get from a given collection of file URIs the ones that represent remote files. + */ + def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { + uris.filter { uri => + val scheme = Utils.resolveURI(uri).getScheme + scheme != "file" && scheme != "local" + } + } + private def resolveFileUri( uri: String, fileDownloadPath: String, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala new file mode 100644 index 0000000000000..8286546ce0641 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} + +/** + * Bootstraps a driver or executor container or an init-container with needed secrets mounted. 
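A short sketch of the helpers above; the secret name and mount path are invented for the example:

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.k8s.KubernetesUtils

    val conf = new SparkConf(false)
      .set("spark.kubernetes.driver.secrets.gcs-creds", "/mnt/secrets/gcs") // made-up secret

    // Map("gcs-creds" -> "/mnt/secrets/gcs")
    val secrets = KubernetesUtils.parsePrefixedKeyValuePairs(conf, "spark.kubernetes.driver.secrets.")

    // Keeps only URIs whose scheme is neither "file" nor "local"; here only the hdfs:// URI survives.
    val remote = KubernetesUtils.getOnlyRemoteFiles(
      Seq("local:///opt/spark/examples/app.jar", "file:///tmp/dep.jar", "hdfs://nn:8020/jars/dep.jar"))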
+ */ +private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { + + /** + * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * + * @param pod the pod into which the secret volumes are being added. + * @param container the container into which the secret volumes are being mounted. + * @return the updated pod and container with the secrets mounted. + */ + def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + var podBuilder = new PodBuilder(pod) + secretNamesToMountPaths.keys.foreach { name => + podBuilder = podBuilder + .editOrNewSpec() + .addNewVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() + .endSpec() + } + + var containerBuilder = new ContainerBuilder(container) + secretNamesToMountPaths.foreach { case (name, path) => + containerBuilder = containerBuilder + .addNewVolumeMount() + .withName(secretVolumeName(name)) + .withMountPath(path) + .endVolumeMount() + } + + (podBuilder.build(), containerBuilder.build()) + } + + private def secretVolumeName(secretName: String): String = { + secretName + "-volume" + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala similarity index 53% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala index 01717479fddd9..0b79f8b12e806 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/ConfigurationUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala @@ -14,28 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.deploy.k8s -import org.apache.spark.SparkConf - -private[spark] object ConfigurationUtils { - - /** - * Extract and parse Spark configuration properties with a given name prefix and - * return the result as a Map. Keys must not have more than one value. - * - * @param sparkConf Spark configuration - * @param prefix the given property name prefix - * @return a Map storing the configuration property keys and values - */ - def parsePrefixedKeyValuePairs( - sparkConf: SparkConf, - prefix: String): Map[String, String] = { - sparkConf.getAllWithPrefix(prefix).toMap - } +import io.fabric8.kubernetes.api.model.{Container, Pod} - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { - opt1.foreach { _ => require(opt2.isEmpty, errMessage) } - } -} +/** + * Represents a pod with a detached init-container (not yet added to the pod). 
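A usage sketch, with an invented secret name, mount path, and pod/container names:

    import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder}
    import org.apache.spark.deploy.k8s.MountSecretsBootstrap

    val bootstrap = new MountSecretsBootstrap(Map("gcs-creds" -> "/mnt/secrets/gcs"))
    val (pod, container) = bootstrap.mountSecrets(
      new PodBuilder().withNewMetadata().withName("spark-pi-driver").endMetadata().build(),
      new ContainerBuilder().withName("spark-kubernetes-driver").build())
    // The pod gains a volume named "gcs-creds-volume" backed by the Kubernetes secret "gcs-creds",
    // and the container mounts that volume at /mnt/secrets/gcs.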
+ * + * @param pod the pod + * @param initContainer the init-container in the pod + * @param mainContainer the main container in the pod + */ +private[spark] case class PodWithDetachedInitContainer( + pod: Pod, + initContainer: Container, + mainContainer: Container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 1e3f055e05766..c47e78cbf19e3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -48,7 +48,7 @@ private[spark] object SparkKubernetesClientFactory { .map(new File(_)) .orElse(defaultServiceAccountToken) val oauthTokenValue = sparkConf.getOption(oauthTokenConf) - ConfigurationUtils.requireNandDefined( + KubernetesUtils.requireNandDefined( oauthTokenFile, oauthTokenValue, s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a " + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala similarity index 53% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 1411e6f40b468..00c9c4ee49177 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -21,25 +21,31 @@ import java.util.UUID import com.google.common.primitives.Longs import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.steps._ +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.SystemClock +import org.apache.spark.util.Utils /** - * Constructs the complete list of driver configuration steps to run to deploy the Spark driver. + * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to + * configure the Spark driver pod. The returned steps will be applied one by one in the given + * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication + * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to + * configure the driver init-container if one is needed, i.e., when there are remote dependencies + * to localize. 
*/ -private[spark] class DriverConfigurationStepsOrchestrator( - namespace: String, +private[spark] class DriverConfigOrchestrator( kubernetesAppId: String, launchTime: Long, mainAppResource: Option[MainAppResource], appName: String, mainClass: String, appArgs: Array[String], - submissionSparkConf: SparkConf) { + sparkConf: SparkConf) { // The resource name prefix is derived from the Spark application name, making it easy to connect // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the @@ -49,13 +55,14 @@ private[spark] class DriverConfigurationStepsOrchestrator( s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") } - private val imagePullPolicy = submissionSparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val jarsDownloadPath = submissionSparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = submissionSparkConf.get(FILES_DOWNLOAD_LOCATION) + private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" + private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) + private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) - def getAllConfigurationSteps(): Seq[DriverConfigurationStep] = { - val driverCustomLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( - submissionSparkConf, + def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + @@ -64,11 +71,15 @@ private[spark] class DriverConfigurationStepsOrchestrator( s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + "operations.") + val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_DRIVER_SECRETS_PREFIX) + val allDriverLabels = driverCustomLabels ++ Map( SPARK_APP_ID_LABEL -> kubernetesAppId, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - val initialSubmissionStep = new BaseDriverConfigurationStep( + val initialSubmissionStep = new BasicDriverConfigurationStep( kubernetesAppId, kubernetesResourceNamePrefix, allDriverLabels, @@ -76,16 +87,16 @@ private[spark] class DriverConfigurationStepsOrchestrator( appName, mainClass, appArgs, - submissionSparkConf) + sparkConf) - val driverAddressStep = new DriverServiceBootstrapStep( + val serviceBootstrapStep = new DriverServiceBootstrapStep( kubernetesResourceNamePrefix, allDriverLabels, - submissionSparkConf, + sparkConf, new SystemClock) val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, kubernetesResourceNamePrefix) + sparkConf, kubernetesResourceNamePrefix) val additionalMainAppJar = if (mainAppResource.nonEmpty) { val mayBeResource = mainAppResource.get match { @@ -98,28 +109,62 @@ private[spark] class DriverConfigurationStepsOrchestrator( None } - val sparkJars = submissionSparkConf.getOption("spark.jars") + val sparkJars = sparkConf.getOption("spark.jars") .map(_.split(",")) .getOrElse(Array.empty[String]) ++ additionalMainAppJar.toSeq - val sparkFiles = submissionSparkConf.getOption("spark.files") + val sparkFiles = sparkConf.getOption("spark.files") .map(_.split(",")) .getOrElse(Array.empty[String]) - val maybeDependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Some(new 
DependencyResolutionStep( + val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { + Seq(new DependencyResolutionStep( sparkJars, sparkFiles, jarsDownloadPath, filesDownloadPath)) } else { - None + Nil + } + + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { + val orchestrator = new InitContainerConfigOrchestrator( + sparkJars, + sparkFiles, + jarsDownloadPath, + filesDownloadPath, + imagePullPolicy, + initContainerConfigMapName, + INIT_CONTAINER_PROPERTIES_FILE_NAME, + sparkConf) + val bootstrapStep = new DriverInitContainerBootstrapStep( + orchestrator.getAllConfigurationSteps, + initContainerConfigMapName, + INIT_CONTAINER_PROPERTIES_FILE_NAME) + + Seq(bootstrapStep) + } else { + Nil + } + + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil } Seq( initialSubmissionStep, - driverAddressStep, + serviceBootstrapStep, kubernetesCredentialsStep) ++ - maybeDependencyResolutionStep.toSeq + dependencyResolutionStep ++ + initContainerBootstrapStep ++ + mountSecretsStep + } + + private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { + files.exists { uri => + Utils.resolveURI(uri).getScheme != "local" + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 240a1144577b0..5884348cb3e41 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -80,22 +80,22 @@ private[spark] object ClientArguments { * spark.kubernetes.submission.waitAppCompletion is true. 
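A sketch of the check that decides whether the init-container bootstrap step is added at all; the URIs are illustrative:

    import org.apache.spark.util.Utils

    // Mirrors the private existNonContainerLocalFiles check above: an init-container is only
    // bootstrapped when at least one dependency lives outside the container image.
    def needsInitContainer(uris: Seq[String]): Boolean =
      uris.exists(uri => Utils.resolveURI(uri).getScheme != "local")

    needsInitContainer(Seq("local:///opt/spark/examples/app.jar")) // false, everything is in the image
    needsInitContainer(Seq("hdfs://nn:8020/jars/dep.jar"))         // true, the step is added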
* * @param submissionSteps steps that collectively configure the driver - * @param submissionSparkConf the submission client Spark configuration + * @param sparkConf the submission client Spark configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete * @param appName the application name - * @param loggingPodStatusWatcher a watcher that monitors and logs the application status + * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( submissionSteps: Seq[DriverConfigurationStep], - submissionSparkConf: SparkConf, + sparkConf: SparkConf, kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - loggingPodStatusWatcher: LoggingPodStatusWatcher) extends Logging { + watcher: LoggingPodStatusWatcher) extends Logging { - private val driverJavaOptions = submissionSparkConf.get( + private val driverJavaOptions = sparkConf.get( org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) /** @@ -104,7 +104,7 @@ private[spark] class Client( * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(submissionSparkConf) + var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) // submissionSteps contain steps necessary to take, to resolve varying // client arguments that are passed in, created by orchestrator for (nextStep <- submissionSteps) { @@ -141,7 +141,7 @@ private[spark] class Client( kubernetesClient .pods() .withName(resolvedDriverPod.getMetadata.getName) - .watch(loggingPodStatusWatcher)) { _ => + .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { if (currentDriverSpec.otherKubernetesResources.nonEmpty) { @@ -157,7 +157,7 @@ private[spark] class Client( if (waitForAppCompletion) { logInfo(s"Waiting for application $appName to finish...") - loggingPodStatusWatcher.awaitCompletion() + watcher.awaitCompletion() logInfo(s"Application $appName finished.") } else { logInfo(s"Deployed Spark application $appName into Kubernetes.") @@ -207,11 +207,9 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val master = sparkConf.get("spark.master").substring("k8s://".length) val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None - val loggingPodStatusWatcher = new LoggingPodStatusWatcherImpl( - kubernetesAppId, loggingInterval) + val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val configurationStepsOrchestrator = new DriverConfigurationStepsOrchestrator( - namespace, + val orchestrator = new DriverConfigOrchestrator( kubernetesAppId, launchTime, clientArguments.mainAppResource, @@ -228,12 +226,12 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - configurationStepsOrchestrator.getAllConfigurationSteps(), + orchestrator.getAllConfigurationSteps, sparkConf, kubernetesClient, waitForAppCompletion, appName, - loggingPodStatusWatcher) + watcher) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala similarity index 70% 
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index c335fcce4036e..b7a69a7dfd472 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -22,49 +22,46 @@ import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarS import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesUtils import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} /** - * Represents the initial setup required for the driver. + * Performs basic configuration for the driver pod. */ -private[spark] class BaseDriverConfigurationStep( +private[spark] class BasicDriverConfigurationStep( kubernetesAppId: String, - kubernetesResourceNamePrefix: String, + resourceNamePrefix: String, driverLabels: Map[String, String], imagePullPolicy: String, appName: String, mainClass: String, appArgs: Array[String], - submissionSparkConf: SparkConf) extends DriverConfigurationStep { + sparkConf: SparkConf) extends DriverConfigurationStep { - private val kubernetesDriverPodName = submissionSparkConf.get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$kubernetesResourceNamePrefix-driver") + private val driverPodName = sparkConf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"$resourceNamePrefix-driver") - private val driverExtraClasspath = submissionSparkConf.get( - DRIVER_CLASS_PATH) + private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - private val driverContainerImage = submissionSparkConf + private val driverContainerImage = sparkConf .get(DRIVER_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the driver container image")) // CPU settings - private val driverCpuCores = submissionSparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = submissionSparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) + private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") + private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) // Memory settings - private val driverMemoryMiB = submissionSparkConf.get( - DRIVER_MEMORY) - private val driverMemoryString = submissionSparkConf.get( - DRIVER_MEMORY.key, - DRIVER_MEMORY.defaultValueString) - private val memoryOverheadMiB = submissionSparkConf + private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) + private val driverMemoryString = sparkConf.get( + DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) + private val memoryOverheadMiB = sparkConf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val driverContainerMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def 
configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => @@ -74,15 +71,13 @@ private[spark] class BaseDriverConfigurationStep( .build() } - val driverCustomAnnotations = ConfigurationUtils - .parsePrefixedKeyValuePairs( - submissionSparkConf, - KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + " Spark bookkeeping operations.") - val driverCustomEnvs = submissionSparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq + val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq .map { env => new EnvVarBuilder() .withName(env._1) @@ -90,10 +85,10 @@ private[spark] class BaseDriverConfigurationStep( .build() } - val allDriverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) + val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - val nodeSelector = ConfigurationUtils.parsePrefixedKeyValuePairs( - submissionSparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) val driverCpuQuantity = new QuantityBuilder(false) .withAmount(driverCpuCores) @@ -102,7 +97,7 @@ private[spark] class BaseDriverConfigurationStep( .withAmount(s"${driverMemoryMiB}Mi") .build() val driverMemoryLimitQuantity = new QuantityBuilder(false) - .withAmount(s"${driverContainerMemoryWithOverheadMiB}Mi") + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) @@ -142,9 +137,9 @@ private[spark] class BaseDriverConfigurationStep( val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() - .withName(kubernetesDriverPodName) + .withName(driverPodName) .addToLabels(driverLabels.asJava) - .addToAnnotations(allDriverAnnotations.asJava) + .addToAnnotations(driverAnnotations.asJava) .endMetadata() .withNewSpec() .withRestartPolicy("Never") @@ -153,9 +148,9 @@ private[spark] class BaseDriverConfigurationStep( .build() val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, kubernetesDriverPodName) + .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, kubernetesResourceNamePrefix) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) driverSpec.copy( driverPod = baseDriverPod, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala index 44e0ecffc0e93..d4b83235b4e3b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -21,7 +21,8 @@ import java.io.File import io.fabric8.kubernetes.api.model.ContainerBuilder import org.apache.spark.deploy.k8s.Constants._ -import 
org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, KubernetesFileUtils} +import org.apache.spark.deploy.k8s.KubernetesUtils +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec /** * Step that configures the classpath, spark.jars, and spark.files for the driver given that the @@ -31,21 +32,22 @@ private[spark] class DependencyResolutionStep( sparkJars: Seq[String], sparkFiles: Seq[String], jarsDownloadPath: String, - localFilesDownloadPath: String) extends DriverConfigurationStep { + filesDownloadPath: String) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesFileUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesFileUtils.resolveFileUris( - sparkFiles, localFilesDownloadPath) - val sparkConfResolvedSparkDependencies = driverSpec.driverSparkConf.clone() + val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) + val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) + + val sparkConf = driverSpec.driverSparkConf.clone() if (resolvedSparkJars.nonEmpty) { - sparkConfResolvedSparkDependencies.set("spark.jars", resolvedSparkJars.mkString(",")) + sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) } if (resolvedSparkFiles.nonEmpty) { - sparkConfResolvedSparkDependencies.set("spark.files", resolvedSparkFiles.mkString(",")) + sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) } - val resolvedClasspath = KubernetesFileUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val driverContainerWithResolvedClasspath = if (resolvedClasspath.nonEmpty) { + + val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) + val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { new ContainerBuilder(driverSpec.driverContainer) .addNewEnv() .withName(ENV_MOUNTED_CLASSPATH) @@ -55,8 +57,9 @@ private[spark] class DependencyResolutionStep( } else { driverSpec.driverContainer } + driverSpec.copy( - driverContainer = driverContainerWithResolvedClasspath, - driverSparkConf = sparkConfResolvedSparkDependencies) + driverContainer = resolvedDriverContainer, + driverSparkConf = sparkConf) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala index c99c0436cf25f..17614e040e587 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec /** - * Represents a step in preparing the Kubernetes driver. + * Represents a step in configuring the Spark driver pod. 
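To make the step contract concrete, here is a hypothetical step that is not part of this patch; it only adds a label to the driver pod and returns an updated copy of the spec:

    import io.fabric8.kubernetes.api.model.PodBuilder
    import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
    import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep

    // Hypothetical example step: label the driver pod with a team name.
    private[spark] class AddTeamLabelStep(team: String) extends DriverConfigurationStep {
      override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
        val labeledPod = new PodBuilder(driverSpec.driverPod)
          .editOrNewMetadata()
            .addToLabels("team", team)
            .endMetadata()
          .build()
        driverSpec.copy(driverPod = labeledPod)
      }
    }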
*/ private[spark] trait DriverConfigurationStep { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala new file mode 100644 index 0000000000000..9fb3dafdda540 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.StringWriter +import java.util.Properties + +import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} + +/** + * Configures the driver init-container that localizes remote dependencies into the driver pod. + * It applies the given InitContainerConfigurationSteps in the given order to produce a final + * InitContainerSpec that is then used to configure the driver pod with the init-container attached. + * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries + * configuration properties for the init-container. 
+ */ +private[spark] class DriverInitContainerBootstrapStep( + steps: Seq[InitContainerConfigurationStep], + configMapName: String, + configMapKey: String) + extends DriverConfigurationStep { + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + var initContainerSpec = InitContainerSpec( + properties = Map.empty[String, String], + driverSparkConf = Map.empty[String, String], + initContainer = new ContainerBuilder().build(), + driverContainer = driverSpec.driverContainer, + driverPod = driverSpec.driverPod, + dependentResources = Seq.empty[HasMetadata]) + for (nextStep <- steps) { + initContainerSpec = nextStep.configureInitContainer(initContainerSpec) + } + + val configMap = buildConfigMap( + configMapName, + configMapKey, + initContainerSpec.properties) + val resolvedDriverSparkConf = driverSpec.driverSparkConf + .clone() + .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) + .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) + .setAll(initContainerSpec.driverSparkConf) + val resolvedDriverPod = KubernetesUtils.appendInitContainer( + initContainerSpec.driverPod, initContainerSpec.initContainer) + + driverSpec.copy( + driverPod = resolvedDriverPod, + driverContainer = initContainerSpec.driverContainer, + driverSparkConf = resolvedDriverSparkConf, + otherKubernetesResources = + driverSpec.otherKubernetesResources ++ + initContainerSpec.dependentResources ++ + Seq(configMap)) + } + + private def buildConfigMap( + configMapName: String, + configMapKey: String, + config: Map[String, String]): ConfigMap = { + val properties = new Properties() + config.foreach { entry => + properties.setProperty(entry._1, entry._2) + } + val propertiesWriter = new StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName " + + s"and config map key: $configMapKey") + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .endMetadata() + .addToData(configMapKey, propertiesWriter.toString) + .build() + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala new file mode 100644 index 0000000000000..f872e0f4b65d1 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
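A sketch of the step in isolation, using a hypothetical single-property init-container step and an invented config map name:

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
    import org.apache.spark.deploy.k8s.submit.steps.DriverInitContainerBootstrapStep
    import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec}

    // Hypothetical init-container step that contributes a single property.
    val addRemoteJar = new InitContainerConfigurationStep {
      override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec =
        spec.copy(properties = spec.properties +
          ("spark.kubernetes.initContainer.remoteJars" -> "hdfs://nn:8020/jars/dep.jar"))
    }

    val step = new DriverInitContainerBootstrapStep(
      steps = Seq(addRemoteJar),
      configMapName = "spark-pi-1234-init-config",  // made-up name
      configMapKey = "spark-init.properties")

    val resolved = step.configureDriver(KubernetesDriverSpec.initialSpec(new SparkConf(false)))
    // resolved.otherKubernetesResources now includes a ConfigMap whose "spark-init.properties"
    // entry carries the property added above, the driver Spark conf records the config map name
    // and key, and the init-container from the spec is appended to the driver pod.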
+ */ +package org.apache.spark.deploy.k8s.submit.steps + +import org.apache.spark.deploy.k8s.MountSecretsBootstrap +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec + +/** + * A driver configuration step for mounting user-specified secrets onto user-specified paths. + * + * @param bootstrap a utility actually handling mounting of the secrets. + */ +private[spark] class DriverMountSecretsStep( + bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { + + override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { + val (pod, container) = bootstrap.mountSecrets( + driverSpec.driverPod, driverSpec.driverContainer) + driverSpec.copy( + driverPod = pod, + driverContainer = container + ) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala index 696d11f15ed95..eb594e4f16ec0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -32,21 +32,22 @@ import org.apache.spark.util.Clock * ports should correspond to the ports that the executor will reach the pod at for RPC. */ private[spark] class DriverServiceBootstrapStep( - kubernetesResourceNamePrefix: String, + resourceNamePrefix: String, driverLabels: Map[String, String], - submissionSparkConf: SparkConf, + sparkConf: SparkConf, clock: Clock) extends DriverConfigurationStep with Logging { + import DriverServiceBootstrapStep._ override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(submissionSparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + "address is managed and set to the driver pod's IP address.") - require(submissionSparkConf.getOption(DRIVER_HOST_KEY).isEmpty, + require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + "managed via a Kubernetes service.") - val preferredServiceName = s"$kubernetesResourceNamePrefix$DRIVER_SVC_POSTFIX" + val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { preferredServiceName } else { @@ -58,8 +59,8 @@ private[spark] class DriverServiceBootstrapStep( shorterServiceName } - val driverPort = submissionSparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = submissionSparkConf.getInt( + val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverBlockManagerPort = sparkConf.getInt( org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) val driverService = new ServiceBuilder() .withNewMetadata() @@ -81,7 +82,7 @@ private[spark] class DriverServiceBootstrapStep( .endSpec() .build() - val namespace = submissionSparkConf.get(KUBERNETES_NAMESPACE) + val namespace = sparkConf.get(KUBERNETES_NAMESPACE) val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" val resolvedSparkConf = driverSpec.driverSparkConf.clone() 
.set(DRIVER_HOST_KEY, driverHostname) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala new file mode 100644 index 0000000000000..01469853dacc2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils + +/** + * Performs basic configuration for the driver init-container with most of the work delegated to + * the given InitContainerBootstrap. 
+ */ +private[spark] class BasicInitContainerConfigurationStep( + sparkJars: Seq[String], + sparkFiles: Seq[String], + jarsDownloadPath: String, + filesDownloadPath: String, + bootstrap: InitContainerBootstrap) + extends InitContainerConfigurationStep { + + override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { + val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) + val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) + val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { + Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) + } else { + Map() + } + val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { + Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) + } else { + Map() + } + + val baseInitContainerConfig = Map( + JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, + FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ + remoteJarsConf ++ + remoteFilesConf + + val bootstrapped = bootstrap.bootstrapInitContainer( + PodWithDetachedInitContainer( + spec.driverPod, + spec.initContainer, + spec.driverContainer)) + + spec.copy( + initContainer = bootstrapped.initContainer, + driverContainer = bootstrapped.mainContainer, + driverPod = bootstrapped.pod, + properties = spec.properties ++ baseInitContainerConfig) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala new file mode 100644 index 0000000000000..f2c29c7ce1076 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to + * configure the driver init-container. The returned steps will be applied in the given order to + * produce a final InitContainerSpec that is used to construct the driver init-container in + * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., + * when there are remote application dependencies to localize. 
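A sketch with illustrative URIs; initContainerBootstrap stands for an InitContainerBootstrap instance such as the one sketched earlier:

    import org.apache.spark.deploy.k8s.submit.steps.initcontainer.BasicInitContainerConfigurationStep

    // initContainerBootstrap: an InitContainerBootstrap (see the earlier sketch).
    val basicStep = new BasicInitContainerConfigurationStep(
      sparkJars = Seq("local:///opt/spark/examples/app.jar", "hdfs://nn:8020/jars/dep.jar"),
      sparkFiles = Nil,
      jarsDownloadPath = "/var/spark-data/spark-jars",
      filesDownloadPath = "/var/spark-data/spark-files",
      bootstrap = initContainerBootstrap)

    // configureInitContainer then extends spec.properties with roughly:
    //   spark.kubernetes.mountDependencies.jarsDownloadDir  -> /var/spark-data/spark-jars
    //   spark.kubernetes.mountDependencies.filesDownloadDir -> /var/spark-data/spark-files
    //   spark.kubernetes.initContainer.remoteJars           -> hdfs://nn:8020/jars/dep.jar
    // (the local:// jar is filtered out because it already lives inside the image)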
+ */ +private[spark] class InitContainerConfigOrchestrator( + sparkJars: Seq[String], + sparkFiles: Seq[String], + jarsDownloadPath: String, + filesDownloadPath: String, + imagePullPolicy: String, + configMapName: String, + configMapKey: String, + sparkConf: SparkConf) { + + private val initContainerImage = sparkConf + .get(INIT_CONTAINER_IMAGE) + .getOrElse(throw new SparkException( + "Must specify the init-container image when there are remote dependencies")) + + def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { + val initContainerBootstrap = new InitContainerBootstrap( + initContainerImage, + imagePullPolicy, + jarsDownloadPath, + filesDownloadPath, + configMapName, + configMapKey, + SPARK_POD_DRIVER_ROLE, + sparkConf) + val baseStep = new BasicInitContainerConfigurationStep( + sparkJars, + sparkFiles, + jarsDownloadPath, + filesDownloadPath, + initContainerBootstrap) + + val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_DRIVER_SECRETS_PREFIX) + // Mount user-specified driver secrets also into the driver's init-container. The + // init-container may need credentials in the secrets to be able to download remote + // dependencies. The driver's main container and its init-container share the secrets + // because the init-container is sort of an implementation details and this sharing + // avoids introducing a dedicated configuration property just for the init-container. + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + + Seq(baseStep) ++ mountSecretsStep + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala new file mode 100644 index 0000000000000..0372ad5270951 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +/** + * Represents a step in configuring the driver init-container. 
+ */ +private[spark] trait InitContainerConfigurationStep { + + def configureInitContainer(spec: InitContainerSpec): InitContainerSpec +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala new file mode 100644 index 0000000000000..c0e7bb20cce8c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.deploy.k8s.MountSecretsBootstrap + +/** + * An init-container configuration step for mounting user-specified secrets onto user-specified + * paths. + * + * @param bootstrap a utility actually handling mounting of the secrets + */ +private[spark] class InitContainerMountSecretsStep( + bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { + + override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { + val (driverPod, initContainer) = bootstrap.mountSecrets( + spec.driverPod, + spec.initContainer) + spec.copy( + driverPod = driverPod, + initContainer = initContainer + ) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala new file mode 100644 index 0000000000000..b52c343f0c0ed --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} + +/** + * Represents a specification of the init-container for the driver pod. + * + * @param properties properties that should be set on the init-container + * @param driverSparkConf Spark configuration properties that will be carried back to the driver + * @param initContainer the init-container object + * @param driverContainer the driver container object + * @param driverPod the driver pod object + * @param dependentResources resources the init-container depends on to work + */ +private[spark] case class InitContainerSpec( + properties: Map[String, String], + driverSparkConf: Map[String, String], + initContainer: Container, + driverContainer: Container, + driverPod: Pod, + dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala new file mode 100644 index 0000000000000..4a4b628aedbbf --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.rest.k8s + +import java.io.File +import java.util.concurrent.TimeUnit + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Process that fetches files from a resource staging server and/or arbitrary remote locations. + * + * The init-container can handle fetching files from any of those sources, but not all of the + * sources need to be specified. This allows for composing multiple instances of this container + * with different configurations for different download sources, or using the same container to + * download everything at once. 
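A sketch of the spec's shape and of how a sequence of steps is folded over it; the container name is illustrative:

    import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder}
    import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec}

    // Seed spec of the shape the steps receive; DriverInitContainerBootstrapStep starts from an
    // equivalent empty spec and applies the configured steps one by one.
    val seedSpec = InitContainerSpec(
      properties = Map.empty[String, String],
      driverSparkConf = Map.empty[String, String],
      initContainer = new ContainerBuilder().build(),
      driverContainer = new ContainerBuilder().withName("spark-kubernetes-driver").build(),
      driverPod = new PodBuilder().build(),
      dependentResources = Seq.empty[HasMetadata])

    // steps: e.g. the list returned by InitContainerConfigOrchestrator.getAllConfigurationSteps
    def applyAll(steps: Seq[InitContainerConfigurationStep]): InitContainerSpec =
      steps.foldLeft(seedSpec) { (spec, step) => step.configureInitContainer(spec) }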
+ */ +private[spark] class SparkPodInitContainer( + sparkConf: SparkConf, + fileFetcher: FileFetcher) extends Logging { + + private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) + private implicit val downloadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) + + private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) + private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) + + private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) + private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) + + private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) + + def run(): Unit = { + logInfo(s"Downloading remote jars: $remoteJars") + downloadFiles( + remoteJars, + jarsDownloadDir, + s"Remote jars download directory specified at $jarsDownloadDir does not exist " + + "or is not a directory.") + + logInfo(s"Downloading remote files: $remoteFiles") + downloadFiles( + remoteFiles, + filesDownloadDir, + s"Remote files download directory specified at $filesDownloadDir does not exist " + + "or is not a directory.") + + downloadExecutor.shutdown() + downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) + } + + private def downloadFiles( + filesCommaSeparated: Option[String], + downloadDir: File, + errMessage: String): Unit = { + filesCommaSeparated.foreach { files => + require(downloadDir.isDirectory, errMessage) + Utils.stringToSeq(files).foreach { file => + Future[Unit] { + fileFetcher.fetchFile(file, downloadDir) + } + } + } + } +} + +private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { + + def fetchFile(uri: String, targetDir: File): Unit = { + Utils.fetchFile( + url = uri, + targetDir = targetDir, + conf = sparkConf, + securityMgr = securityManager, + hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), + timestamp = System.currentTimeMillis(), + useCache = false) + } +} + +object SparkPodInitContainer extends Logging { + + def main(args: Array[String]): Unit = { + logInfo("Starting init-container to download Spark application dependencies.") + val sparkConf = new SparkConf(true) + if (args.nonEmpty) { + Utils.loadDefaultSparkProperties(sparkConf, args(0)) + } + + val securityManager = new SparkSecurityManager(sparkConf) + val fileFetcher = new FileFetcher(sparkConf, securityManager) + new SparkPodInitContainer(sparkConf, fileFetcher).run() + logInfo("Finished downloading application dependencies.") + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 70226157dd68b..ba5d891f4c77e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -21,35 +21,35 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.ConfigurationUtils import 
org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} import org.apache.spark.util.Utils /** - * A factory class for configuring and creating executor pods. + * A factory class for bootstrapping and creating executor pods with the given bootstrapping + * components. + * + * @param sparkConf Spark configuration + * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto + * user-specified paths into the executor container + * @param initContainerBootstrap an optional component for bootstrapping the executor init-container + * if one is needed, i.e., when there are remote dependencies to + * localize + * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified + * secrets onto user-specified paths into the executor + * init-container */ -private[spark] trait ExecutorPodFactory { - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod -} - -private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) - extends ExecutorPodFactory { +private[spark] class ExecutorPodFactory( + sparkConf: SparkConf, + mountSecretsBootstrap: Option[MountSecretsBootstrap], + initContainerBootstrap: Option[InitContainerBootstrap], + initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - private val executorLabels = ConfigurationUtils.parsePrefixedKeyValuePairs( + private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) require( @@ -64,11 +64,11 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") private val executorAnnotations = - ConfigurationUtils.parsePrefixedKeyValuePairs( + KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) private val nodeSelector = - ConfigurationUtils.parsePrefixedKeyValuePairs( + KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) @@ -94,7 +94,10 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - override def createExecutorPod( + /** + * Configure and construct an executor pod with the given parameters. 
+ */ + def createExecutorPod( executorId: String, applicationId: String, driverUrl: String, @@ -198,7 +201,7 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) .endSpec() .build() - val containerWithExecutorLimitCores = executorLimitCores.map { limitCores => + val containerWithLimitCores = executorLimitCores.map { limitCores => val executorCpuLimitQuantity = new QuantityBuilder(false) .withAmount(limitCores) .build() @@ -209,9 +212,33 @@ private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) .build() }.getOrElse(executorContainer) - new PodBuilder(executorPod) + val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = + mountSecretsBootstrap.map { bootstrap => + bootstrap.mountSecrets(executorPod, containerWithLimitCores) + }.getOrElse((executorPod, containerWithLimitCores)) + + val (bootstrappedPod, bootstrappedContainer) = + initContainerBootstrap.map { bootstrap => + val podWithInitContainer = bootstrap.bootstrapInitContainer( + PodWithDetachedInitContainer( + maybeSecretsMountedPod, + new ContainerBuilder().build(), + maybeSecretsMountedContainer)) + + val (pod, mayBeSecretsMountedInitContainer) = + initContainerMountSecretsBootstrap.map { bootstrap => + bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) + + val bootstrappedPod = KubernetesUtils.appendInitContainer( + pod, mayBeSecretsMountedInitContainer) + + (bootstrappedPod, podWithInitContainer.mainContainer) + }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) + + new PodBuilder(bootstrappedPod) .editSpec() - .addToContainers(containerWithExecutorLimitCores) + .addToContainers(bootstrappedContainer) .endSpec() .build() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index b8bb152d17910..a942db6ae02db 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,9 +21,9 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.util.ThreadUtils @@ -45,6 +45,59 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { val sparkConf = sc.getConf + val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) + val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) + + if (initContainerConfigMap.isEmpty) { + logWarning("The executor's init-container config map is not specified. 
Executors will " + + "therefore not attempt to fetch remote or submitted dependencies.") + } + + if (initContainerConfigMapKey.isEmpty) { + logWarning("The executor's init-container config map key is not specified. Executors will " + + "therefore not attempt to fetch remote or submitted dependencies.") + } + + // Only set up the bootstrap if they've provided both the config map key and the config map + // name. The config map might not be provided if init-containers aren't being used to + // bootstrap dependencies. + val initContainerBootstrap = for { + configMap <- initContainerConfigMap + configMapKey <- initContainerConfigMapKey + } yield { + val initContainerImage = sparkConf + .get(INIT_CONTAINER_IMAGE) + .getOrElse(throw new SparkException( + "Must specify the init-container image when there are remote dependencies")) + new InitContainerBootstrap( + initContainerImage, + sparkConf.get(CONTAINER_IMAGE_PULL_POLICY), + sparkConf.get(JARS_DOWNLOAD_LOCATION), + sparkConf.get(FILES_DOWNLOAD_LOCATION), + configMap, + configMapKey, + SPARK_POD_EXECUTOR_ROLE, + sparkConf) + } + + val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { + Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) + } else { + None + } + // Mount user-specified executor secrets also into the executor's init-container. The + // init-container may need credentials in the secrets to be able to download remote + // dependencies. The executor's main container and its init-container share the secrets + // because the init-container is sort of an implementation details and this sharing + // avoids introducing a dedicated configuration property just for the init-container. 
+ val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && + executorSecretNamesToMountPaths.nonEmpty) { + Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) + } else { + None + } val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, @@ -54,7 +107,12 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val executorPodFactory = new ExecutorPodFactory( + sparkConf, + mountSecretBootstrap, + initContainerBootstrap, + initContainerMountSecretsBootstrap) + val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala similarity index 51% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 98f9f27da5cde..f193b1f4d3664 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigurationStepsOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -17,25 +17,27 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config.DRIVER_CONTAINER_IMAGE +import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.steps._ -class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { +class DriverConfigOrchestratorSuite extends SparkFunSuite { - private val NAMESPACE = "default" private val DRIVER_IMAGE = "driver-image" + private val IC_IMAGE = "init-container-image" private val APP_ID = "spark-app-id" private val LAUNCH_TIME = 975256L private val APP_NAME = "spark" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2") + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("Base submission steps with a main app resource.") { val sparkConf = new SparkConf(false) .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigurationStepsOrchestrator( - NAMESPACE, + val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, Some(mainAppResource), @@ -45,7 +47,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { sparkConf) validateStepTypes( orchestrator, - classOf[BaseDriverConfigurationStep], + classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep], classOf[DependencyResolutionStep] @@ -55,8 +57,7 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { test("Base submission steps without a main app resource.") { val sparkConf 
= new SparkConf(false) .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigurationStepsOrchestrator( - NAMESPACE, + val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, Option.empty, @@ -66,16 +67,62 @@ class DriverConfigurationStepsOrchestratorSuite extends SparkFunSuite { sparkConf) validateStepTypes( orchestrator, - classOf[BaseDriverConfigurationStep], + classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep] ) } + test("Submission steps with an init-container.") { + val sparkConf = new SparkConf(false) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(INIT_CONTAINER_IMAGE, IC_IMAGE) + .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") + val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") + val orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(mainAppResource), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BasicDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep], + classOf[DependencyResolutionStep], + classOf[DriverInitContainerBootstrapStep]) + } + + test("Submission steps with driver secrets to mount") { + val sparkConf = new SparkConf(false) + .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) + val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") + val orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(mainAppResource), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + validateStepTypes( + orchestrator, + classOf[BasicDriverConfigurationStep], + classOf[DriverServiceBootstrapStep], + classOf[DriverKubernetesCredentialsStep], + classOf[DependencyResolutionStep], + classOf[DriverMountSecretsStep]) + } + private def validateStepTypes( - orchestrator: DriverConfigurationStepsOrchestrator, + orchestrator: DriverConfigOrchestrator, types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps() + val steps = orchestrator.getAllConfigurationSteps assert(steps.size === types.size) assert(steps.map(_.getClass) === types) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala new file mode 100644 index 0000000000000..8388c16ded268 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, Pod} + +private[spark] object SecretVolumeUtils { + + def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { + driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + } + + def containerHasVolume( + driverContainer: Container, + volumeName: String, + mountPath: String): Boolean = { + driverContainer.getVolumeMounts.asScala.exists(volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala similarity index 97% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index f7c1b3142cf71..e864c6a16eeb1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BaseDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -class BaseDriverConfigurationStepSuite extends SparkFunSuite { +class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val APP_ID = "spark-app-id" private val RESOURCE_NAME_PREFIX = "spark" @@ -52,7 +52,7 @@ class BaseDriverConfigurationStepSuite extends SparkFunSuite { .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - val submissionStep = new BaseDriverConfigurationStep( + val submissionStep = new BasicDriverConfigurationStep( APP_ID, RESOURCE_NAME_PREFIX, DRIVER_LABELS, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala new file mode 100644 index 0000000000000..758871e2ba356 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import java.io.StringReader +import java.util.Properties + +import scala.collection.JavaConverters._ + +import com.google.common.collect.Maps +import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} +import org.apache.spark.util.Utils + +class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { + + private val CONFIG_MAP_NAME = "spark-init-config-map" + private val CONFIG_MAP_KEY = "spark-init-config-map-key" + + test("The init container bootstrap step should use all of the init container steps") { + val baseDriverSpec = KubernetesDriverSpec( + driverPod = new PodBuilder().build(), + driverContainer = new ContainerBuilder().build(), + driverSparkConf = new SparkConf(false), + otherKubernetesResources = Seq.empty[HasMetadata]) + val initContainerSteps = Seq( + FirstTestInitContainerConfigurationStep, + SecondTestInitContainerConfigurationStep) + val bootstrapStep = new DriverInitContainerBootstrapStep( + initContainerSteps, + CONFIG_MAP_NAME, + CONFIG_MAP_KEY) + + val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) + + assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === + FirstTestInitContainerConfigurationStep.additionalLabels) + val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala + assert(additionalDriverEnv.size === 1) + assert(additionalDriverEnv.head.getName === + FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) + assert(additionalDriverEnv.head.getValue === + FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) + + assert(preparedDriverSpec.otherKubernetesResources.size === 2) + assert(preparedDriverSpec.otherKubernetesResources.contains( + FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) + assert(preparedDriverSpec.otherKubernetesResources.exists { + case configMap: ConfigMap => + val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME + val configMapData = configMap.getData.asScala + val hasCorrectNumberOfEntries = configMapData.size == 1 + val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) + val initContainerProperties = new Properties() + Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) { + initContainerProperties.load(_) + } + val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala + val expectedInitContainerProperties = Map( + SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> + SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) + val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties + hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties + + case _ => false + }) + + val 
initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers + assert(initContainers.size() === 1) + val initContainerEnv = initContainers.get(0).getEnv.asScala + assert(initContainerEnv.size === 1) + assert(initContainerEnv.head.getName === + SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) + assert(initContainerEnv.head.getValue === + SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) + + val expectedSparkConf = Map( + INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, + SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> + SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) + assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) + } +} + +private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { + + val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") + val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" + val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" + val additionalKubernetesResource = new SecretBuilder() + .withNewMetadata() + .withName("test-secret") + .endMetadata() + .addToData("secret-key", "secret-value") + .build() + + override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { + val driverPod = new PodBuilder(initContainerSpec.driverPod) + .editOrNewMetadata() + .addToLabels(additionalLabels.asJava) + .endMetadata() + .build() + val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) + .addNewEnv() + .withName(additionalMainContainerEnvKey) + .withValue(additionalMainContainerEnvValue) + .endEnv() + .build() + initContainerSpec.copy( + driverPod = driverPod, + driverContainer = mainContainer, + dependentResources = initContainerSpec.dependentResources ++ + Seq(additionalKubernetesResource)) + } +} + +private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { + val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" + val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" + val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" + val additionalInitContainerPropertyValue = "testvalue" + val additionalDriverSparkConfKey = "spark.driver.testkey" + val additionalDriverSparkConfValue = "spark.driver.testvalue" + + override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { + val initContainer = new ContainerBuilder(initContainerSpec.initContainer) + .addNewEnv() + .withName(additionalInitContainerEnvKey) + .withValue(additionalInitContainerEnvValue) + .endEnv() + .build() + val initContainerProperties = initContainerSpec.properties ++ + Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) + val driverSparkConf = initContainerSpec.driverSparkConf ++ + Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue) + initContainerSpec.copy( + initContainer = initContainer, + properties = initContainerProperties, + driverSparkConf = driverSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala new file mode 100644 index 0000000000000..9ec0cb55de5aa --- /dev/null +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.MountSecretsBootstrap +import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} + +class DriverMountSecretsStepSuite extends SparkFunSuite { + + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/driver" + + test("mounts all given secrets") { + val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val secretNamesToMountPaths = Map( + SECRET_FOO -> SECRET_MOUNT_PATH, + SECRET_BAR -> SECRET_MOUNT_PATH) + + val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) + val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) + val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) + val driverPodWithSecretsMounted = configuredDriverSpec.driverPod + val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => + assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) + } + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => + assert(SecretVolumeUtils.containerHasVolume( + driverContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala new file mode 100644 index 0000000000000..4553f9f6b1d45 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.Config._ + +class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SPARK_JARS = Seq( + "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") + private val SPARK_FILES = Seq( + "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") + private val JARS_DOWNLOAD_PATH = "/var/data/jars" + private val FILES_DOWNLOAD_PATH = "/var/data/files" + private val POD_LABEL = Map("bootstrap" -> "true") + private val INIT_CONTAINER_NAME = "init-container" + private val DRIVER_CONTAINER_NAME = "driver-container" + + @Mock + private var podAndInitContainerBootstrap : InitContainerBootstrap = _ + + before { + MockitoAnnotations.initMocks(this) + when(podAndInitContainerBootstrap.bootstrapInitContainer( + any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { + override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { + val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) + pod.copy( + pod = new PodBuilder(pod.pod) + .withNewMetadata() + .addToLabels("bootstrap", "true") + .endMetadata() + .withNewSpec().endSpec() + .build(), + initContainer = new ContainerBuilder() + .withName(INIT_CONTAINER_NAME) + .build(), + mainContainer = new ContainerBuilder() + .withName(DRIVER_CONTAINER_NAME) + .build() + )}}) + } + + test("additionalDriverSparkConf with mix of remote files and jars") { + val baseInitStep = new BasicInitContainerConfigurationStep( + SPARK_JARS, + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + podAndInitContainerBootstrap) + val expectedDriverSparkConf = Map( + JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, + INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", + INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") + val initContainerSpec = InitContainerSpec( + Map.empty[String, String], + Map.empty[String, String], + new Container(), + new Container(), + new Pod, + Seq.empty[HasMetadata]) + val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) + assert(expectedDriverSparkConf === returnContainerSpec.properties) + assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) + assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) + assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) + } +} diff --git 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala new file mode 100644 index 0000000000000..20f2e5bc15df3 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class InitContainerConfigOrchestratorSuite extends SparkFunSuite { + + private val DOCKER_IMAGE = "init-container" + private val SPARK_JARS = Seq( + "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") + private val SPARK_FILES = Seq( + "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") + private val JARS_DOWNLOAD_PATH = "/var/data/jars" + private val FILES_DOWNLOAD_PATH = "/var/data/files" + private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" + private val CUSTOM_LABEL_KEY = "customLabel" + private val CUSTOM_LABEL_VALUE = "customLabelValue" + private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" + private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" + + test("including basic configuration step") { + val sparkConf = new SparkConf(true) + .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) + + val orchestrator = new InitContainerConfigOrchestrator( + SPARK_JARS.take(1), + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + DOCKER_IMAGE_PULL_POLICY, + INIT_CONTAINER_CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY, + sparkConf) + val initSteps = orchestrator.getAllConfigurationSteps + assert(initSteps.lengthCompare(1) == 0) + assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) + } + + test("including step to mount user-specified secrets") { + val sparkConf = new SparkConf(false) + .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) + .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) + + val orchestrator = new InitContainerConfigOrchestrator( + SPARK_JARS.take(1), + SPARK_FILES, + JARS_DOWNLOAD_PATH, + FILES_DOWNLOAD_PATH, + DOCKER_IMAGE_PULL_POLICY, + 
INIT_CONTAINER_CONFIG_MAP_NAME, + INIT_CONTAINER_CONFIG_MAP_KEY, + sparkConf) + val initSteps = orchestrator.getAllConfigurationSteps + assert(initSteps.length === 2) + assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) + assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala new file mode 100644 index 0000000000000..eab4e17659456 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit.steps.initcontainer + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.MountSecretsBootstrap +import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils + +class InitContainerMountSecretsStepSuite extends SparkFunSuite { + + private val SECRET_FOO = "foo" + private val SECRET_BAR = "bar" + private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" + + test("mounts all given secrets") { + val baseInitContainerSpec = InitContainerSpec( + Map.empty, + Map.empty, + new ContainerBuilder().build(), + new ContainerBuilder().build(), + new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), + Seq.empty) + val secretNamesToMountPaths = Map( + SECRET_FOO -> SECRET_MOUNT_PATH, + SECRET_BAR -> SECRET_MOUNT_PATH) + + val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) + val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) + val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( + baseInitContainerSpec) + + val podWithSecretsMounted = configuredInitContainerSpec.driverPod + val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer + + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => + assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) + Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => + assert(SecretVolumeUtils.containerHasVolume( + initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala new file mode 100644 index 0000000000000..6c557ec4a7c9a --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.rest.k8s + +import java.io.File +import java.util.UUID + +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter +import org.scalatest.mockito.MockitoSugar._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.util.Utils + +class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { + + private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") + private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") + + private var downloadJarsDir: File = _ + private var downloadFilesDir: File = _ + private var downloadJarsSecretValue: String = _ + private var downloadFilesSecretValue: String = _ + private var fileFetcher: FileFetcher = _ + + override def beforeAll(): Unit = { + downloadJarsSecretValue = Files.toString( + new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) + downloadFilesSecretValue = Files.toString( + new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) + } + + before { + downloadJarsDir = Utils.createTempDir() + downloadFilesDir = Utils.createTempDir() + fileFetcher = mock[FileFetcher] + } + + after { + downloadJarsDir.delete() + downloadFilesDir.delete() + } + + test("Downloads from remote server should invoke the file fetcher") { + val sparkConf = getSparkConfForRemoteFileDownloads + val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) + initContainerUnderTest.run() + Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) + Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) + Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) + } + + private def getSparkConfForRemoteFileDownloads: SparkConf = { + new SparkConf(true) + .set(INIT_CONTAINER_REMOTE_JARS, + "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") + .set(INIT_CONTAINER_REMOTE_FILES, + "http://localhost:9000/file.txt") + .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) + .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) + } + + private def createTempFile(extension: String): String = { + val dir = Utils.createTempDir() + val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") + 
Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) + file.getAbsolutePath + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 3a55d7cb37b1f..7121a802c69c1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -18,15 +18,19 @@ package org.apache.spark.scheduler.cluster.k8s import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{Pod, _} -import org.mockito.MockitoAnnotations +import io.fabric8.kubernetes.api.model._ +import org.mockito.{AdditionalAnswers, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + private val driverPodName: String = "driver-pod" private val driverPodUid: String = "driver-uid" private val executorPrefix: String = "base" @@ -54,7 +58,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef } test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactoryImpl(baseConf) + val factory = new ExecutorPodFactory(baseConf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -85,7 +89,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - val factory = new ExecutorPodFactoryImpl(conf) + val factory = new ExecutorPodFactory(conf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -97,7 +101,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - val factory = new ExecutorPodFactoryImpl(conf) + val factory = new ExecutorPodFactory(conf, None, None, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) @@ -108,6 +112,74 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor secrets get mounted") { + val conf = baseConf.clone() + + val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) + val factory = new ExecutorPodFactory( + conf, + Some(secretsBootstrap), + None, + None) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) + 
assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName + === "secret1-volume") + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) + .getMountPath === "/var/secret1") + + // check volume mounted. + assert(executor.getSpec.getVolumes.size() === 1) + assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") + + checkOwnerReferences(executor, driverPodUid) + } + + test("init-container bootstrap step adds an init container") { + val conf = baseConf.clone() + val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) + when(initContainerBootstrap.bootstrapInitContainer( + any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + + val factory = new ExecutorPodFactory( + conf, + None, + Some(initContainerBootstrap), + None) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getInitContainers.size() === 1) + checkOwnerReferences(executor, driverPodUid) + } + + test("init-container with secrets mount bootstrap") { + val conf = baseConf.clone() + val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) + when(initContainerBootstrap.bootstrapInitContainer( + any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) + + val factory = new ExecutorPodFactory( + conf, + None, + Some(initContainerBootstrap), + Some(secretsBootstrap)) + val executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getInitContainers.size() === 1) + assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName + === "secret1-volume") + assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) + .getMountPath === "/var/secret1") + + checkOwnerReferences(executor, driverPodUid) + } + // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 9b682f8673c69..45fbcd9cd0deb 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . COPY examples /opt/spark/examples @@ -31,4 +31,5 @@ CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ + if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." 
.; fi && \ ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 168cd4cb6c57a..0f806cf7e148e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -22,7 +22,7 @@ FROM spark-base # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . COPY examples /opt/spark/examples @@ -31,4 +31,5 @@ CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ + if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile new file mode 100644 index 0000000000000..055493188fcb7 --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FROM spark-base + +# If this docker file is being used in the context of building your images from a Spark distribution, the docker build +# command should be invoked from the top level directory of the Spark distribution. E.g.: +# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . 
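+
+# The entrypoint below runs org.apache.spark.deploy.rest.k8s.SparkPodInitContainer, which downloads
+# the application's remote jars and files into the configured download directories before the
+# driver or executor main container starts.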
+ +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] From ded6d27e4eb02e4530015a95794e6ed0586faaa7 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 28 Dec 2017 13:53:04 +0900 Subject: [PATCH 511/712] [SPARK-22648][K8S] Add documentation covering init containers and secrets ## What changes were proposed in this pull request? This PR updates the Kubernetes documentation corresponding to the following features/changes in #19954. * Ability to use remote dependencies through the init-container. * Ability to mount user-specified secrets into the driver and executor pods. vanzin jiangxb1987 foxish Author: Yinan Li Closes #20059 from liyinan926/doc-update. --- docs/running-on-kubernetes.md | 194 ++++++++++++++++++++++--------- sbin/build-push-docker-images.sh | 3 +- 2 files changed, 143 insertions(+), 54 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 0048bd90b48ae..e491329136a3c 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -69,17 +69,17 @@ building using the supplied script, or manually. To launch Spark Pi in cluster mode, -{% highlight bash %} +```bash $ bin/spark-submit \ --master k8s://https://: \ --deploy-mode cluster \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.docker.image= \ - --conf spark.kubernetes.executor.docker.image= \ + --conf spark.kubernetes.driver.container.image= \ + --conf spark.kubernetes.executor.container.image= \ local:///path/to/examples.jar -{% endhighlight %} +``` The Spark master, specified either via passing the `--master` command line argument to `spark-submit` or by setting `spark.master` in the application's configuration, must be a URL with the format `k8s://`. Prefixing the @@ -120,6 +120,54 @@ by their appropriate remote URIs. Also, application dependencies can be pre-moun Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the `SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. +### Using Remote Dependencies +When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods +need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading +the dependencies so the driver and executor containers can use them locally. This requires users to specify the container +image for the init-container using the configuration property `spark.kubernetes.initContainer.image`. For example, users +simply add the following option to the `spark-submit` command to specify the init-container image: + +``` +--conf spark.kubernetes.initContainer.image= +``` + +The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and +`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., +the main application jar. 
The following shows an example of using remote dependencies with the `spark-submit` command: + +```bash +$ bin/spark-submit \ + --master k8s://https://: \ + --deploy-mode cluster \ + --name spark-pi \ + --class org.apache.spark.examples.SparkPi \ + --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar + --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 + --conf spark.executor.instances=5 \ + --conf spark.kubernetes.driver.container.image= \ + --conf spark.kubernetes.executor.container.image= \ + --conf spark.kubernetes.initContainer.image= + https://path/to/examples.jar +``` + +## Secret Management +Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a +Spark application to access secured services. To mount a user-specified secret into the driver container, users can use +the configuration property of the form `spark.kubernetes.driver.secrets.[SecretName]=`. Similarly, the +configuration property of the form `spark.kubernetes.executor.secrets.[SecretName]=` can be used to mount a +user-specified secret into the executor containers. Note that it is assumed that the secret to be mounted is in the same +namespace as that of the driver and executor pods. For example, to mount a secret named `spark-secret` onto the path +`/etc/secrets` in both the driver and executor containers, add the following options to the `spark-submit` command: + +``` +--conf spark.kubernetes.driver.secrets.spark-secret=/etc/secrets +--conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets +``` + +Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the +init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the +init-container of the executor. + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -275,7 +323,7 @@ specific to Spark on Kubernetes. (none) Container image to use for the driver. - This is usually of the form `example.com/repo/spark-driver:v1.0.0`. + This is usually of the form example.com/repo/spark-driver:v1.0.0. This configuration is required and must be provided by the user. @@ -284,7 +332,7 @@ specific to Spark on Kubernetes. (none) Container image to use for the executors. - This is usually of the form `example.com/repo/spark-executor:v1.0.0`. + This is usually of the form example.com/repo/spark-executor:v1.0.0. This configuration is required and must be provided by the user. @@ -528,51 +576,91 @@ specific to Spark on Kubernetes. - spark.kubernetes.driver.limit.cores - (none) - - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. - - - - spark.kubernetes.executor.limit.cores - (none) - - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. - - - - spark.kubernetes.node.selector.[labelKey] - (none) - - Adds to the node selector of the driver pod and executor pods, with key labelKey and the value as the - configuration's value. 
For example, setting spark.kubernetes.node.selector.identifier to myIdentifier - will result in the driver pod and executors having a node selector with key identifier and value - myIdentifier. Multiple node selector keys can be added by setting multiple configurations with this prefix. - - - - spark.kubernetes.driverEnv.[EnvironmentVariableName] - (none) - - Add the environment variable specified by EnvironmentVariableName to - the Driver process. The user can specify multiple of these to set multiple environment variables. - - - - spark.kubernetes.mountDependencies.jarsDownloadDir - /var/spark-data/spark-jars - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.filesDownloadDir - /var/spark-data/spark-files - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - + spark.kubernetes.driver.limit.cores + (none) + + Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + + + + spark.kubernetes.executor.limit.cores + (none) + + Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + + + + spark.kubernetes.node.selector.[labelKey] + (none) + + Adds to the node selector of the driver pod and executor pods, with key labelKey and the value as the + configuration's value. For example, setting spark.kubernetes.node.selector.identifier to myIdentifier + will result in the driver pod and executors having a node selector with key identifier and value + myIdentifier. Multiple node selector keys can be added by setting multiple configurations with this prefix. + + + + spark.kubernetes.driverEnv.[EnvironmentVariableName] + (none) + + Add the environment variable specified by EnvironmentVariableName to + the Driver process. The user can specify multiple of these to set multiple environment variables. + + + + spark.kubernetes.mountDependencies.jarsDownloadDir + /var/spark-data/spark-jars + + Location to download jars to in the driver and executors. + This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. + + + + spark.kubernetes.mountDependencies.filesDownloadDir + /var/spark-data/spark-files + + Location to download jars to in the driver and executors. + This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. + + + + spark.kubernetes.mountDependencies.timeout + 300s + + Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into + the driver and executor pods. + + + + spark.kubernetes.mountDependencies.maxSimultaneousDownloads + 5 + + Maximum number of remote dependencies to download simultaneously in a driver or executor pod. + + + + spark.kubernetes.initContainer.image + (none) + + Container image for the init-container of the driver and executors for downloading dependencies. This is usually of the form example.com/repo/spark-init:v1.0.0. 
This configuration is optional and must be provided by the user if any non-container local dependency is used and must be downloaded remotely. + + + + spark.kubernetes.driver.secrets.[SecretName] + (none) + + Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, + the secret will also be added to the init-container in the driver pod. + + + + spark.kubernetes.executor.secrets.[SecretName] + (none) + + Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, + the secret will also be added to the init-container in the executor pod. + + \ No newline at end of file diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index 4546e98dc2074..b3137598692d8 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -20,7 +20,8 @@ # with Kubernetes support. declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile ) + [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ + [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) function build { docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . From 76e8a1d7e2619c1e6bd75c399314d2583a86b93b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 28 Dec 2017 20:17:26 +0900 Subject: [PATCH 512/712] [SPARK-22843][R] Adds localCheckpoint in R ## What changes were proposed in this pull request? This PR proposes to add `localCheckpoint(..)` in R API. ```r df <- localCheckpoint(createDataFrame(iris)) ``` ## How was this patch tested? Unit tests added in `R/pkg/tests/fulltests/test_sparkSQL.R` Author: hyukjinkwon Closes #20073 from HyukjinKwon/SPARK-22843. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 27 +++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 22 ++++++++++++++++++++++ 4 files changed, 54 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index dce64e1e607c8..4b699de558b4f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -133,6 +133,7 @@ exportMethods("arrange", "isStreaming", "join", "limit", + "localCheckpoint", "merge", "mutate", "na.omit", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index b8d732a485862..ace49daf9cd83 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3782,6 +3782,33 @@ setMethod("checkpoint", dataFrame(df) }) +#' localCheckpoint +#' +#' Returns a locally checkpointed version of this SparkDataFrame. Checkpointing can be used to +#' truncate the logical plan, which is especially useful in iterative algorithms where the plan +#' may grow exponentially. Local checkpoints are stored in the executors using the caching +#' subsystem and therefore they are not reliable. 
+#' +#' @param x A SparkDataFrame +#' @param eager whether to locally checkpoint this SparkDataFrame immediately +#' @return a new locally checkpointed SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases localCheckpoint,SparkDataFrame-method +#' @rdname localCheckpoint +#' @name localCheckpoint +#' @export +#' @examples +#'\dontrun{ +#' df <- localCheckpoint(df) +#' } +#' @note localCheckpoint since 2.3.0 +setMethod("localCheckpoint", + signature(x = "SparkDataFrame"), + function(x, eager = TRUE) { + df <- callJMethod(x@sdf, "localCheckpoint", as.logical(eager)) + dataFrame(df) + }) + #' cube #' #' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5ddaa669f9205..d5d0bc9d8a97d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -611,6 +611,10 @@ setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) +#' @rdname localCheckpoint +#' @export +setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) + #' @rdname merge #' @export setGeneric("merge") diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 6cc0188dae95f..650e7c05f46e8 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -957,6 +957,28 @@ test_that("setCheckpointDir(), checkpoint() on a DataFrame", { } }) +test_that("localCheckpoint() on a DataFrame", { + if (windows_with_hadoop()) { + # Checkpoint directory shouldn't matter in localCheckpoint. + checkpointDir <- file.path(tempdir(), "lcproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE, recursive = TRUE)) == 0) + setCheckpointDir(checkpointDir) + + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + writeLines(mockLines, textPath) + # Read it lazily and then locally checkpoint eagerly. + df <- read.df(textPath, "text") + df <- localCheckpoint(df, eager = TRUE) + # Here, we remove the source path to check eagerness. + unlink(textPath) + expect_is(df, "SparkDataFrame") + expect_equal(colnames(df), c("value")) + expect_equal(count(df), 3) + + expect_true(length(list.files(path = checkpointDir, all.files = TRUE, recursive = TRUE)) == 0) + } +}) + test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- read.json(jsonPath) testSchema <- schema(df) From 1eebfbe192060af3c81cd086bc5d5a7e80d09e77 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 28 Dec 2017 20:18:47 +0900 Subject: [PATCH 513/712] [SPARK-21208][R] Adds setLocalProperty and getLocalProperty in R ## What changes were proposed in this pull request? This PR adds `setLocalProperty` and `getLocalProperty`in R. ```R > df <- createDataFrame(iris) > setLocalProperty("spark.job.description", "Hello world!") > count(df) > setLocalProperty("spark.job.description", "Hi !!") > count(df) ``` 2017-12-25 4 18 07 ```R > print(getLocalProperty("spark.job.description")) NULL > setLocalProperty("spark.job.description", "Hello world!") > print(getLocalProperty("spark.job.description")) [1] "Hello world!" > setLocalProperty("spark.job.description", "Hi !!") > print(getLocalProperty("spark.job.description")) [1] "Hi !!" ``` ## How was this patch tested? Manually tested and a test in `R/pkg/tests/fulltests/test_context.R`. Author: hyukjinkwon Closes #20075 from HyukjinKwon/SPARK-21208. 
--- R/pkg/NAMESPACE | 4 ++- R/pkg/R/sparkR.R | 45 ++++++++++++++++++++++++++++ R/pkg/tests/fulltests/test_context.R | 33 +++++++++++++++++++- 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 4b699de558b4f..ce3eec069efe5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -76,7 +76,9 @@ exportMethods("glm", export("setJobGroup", "clearJobGroup", "cancelJobGroup", - "setJobDescription") + "setJobDescription", + "setLocalProperty", + "getLocalProperty") # Export Utility methods export("setLogLevel") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index fb5f1d21fc723..965471f3b07a0 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -560,10 +560,55 @@ cancelJobGroup <- function(sc, groupId) { #'} #' @note setJobDescription since 2.3.0 setJobDescription <- function(value) { + if (!is.null(value)) { + value <- as.character(value) + } sc <- getSparkContext() invisible(callJMethod(sc, "setJobDescription", value)) } +#' Set a local property that affects jobs submitted from this thread, such as the +#' Spark fair scheduler pool. +#' +#' @param key The key for a local property. +#' @param value The value for a local property. +#' @rdname setLocalProperty +#' @name setLocalProperty +#' @examples +#'\dontrun{ +#' setLocalProperty("spark.scheduler.pool", "poolA") +#'} +#' @note setLocalProperty since 2.3.0 +setLocalProperty <- function(key, value) { + if (is.null(key) || is.na(key)) { + stop("key should not be NULL or NA.") + } + if (!is.null(value)) { + value <- as.character(value) + } + sc <- getSparkContext() + invisible(callJMethod(sc, "setLocalProperty", as.character(key), value)) +} + +#' Get a local property set in this thread, or \code{NULL} if it is missing. See +#' \code{setLocalProperty}. +#' +#' @param key The key for a local property. 
+#' @rdname getLocalProperty +#' @name getLocalProperty +#' @examples +#'\dontrun{ +#' getLocalProperty("spark.scheduler.pool") +#'} +#' @note getLocalProperty since 2.3.0 +getLocalProperty <- function(key) { + if (is.null(key) || is.na(key)) { + stop("key should not be NULL or NA.") + } + sc <- getSparkContext() + callJMethod(sc, "getLocalProperty", as.character(key)) +} + sparkConfToSubmitOps <- new.env() sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index 77635c5a256b9..f0d0a5114f89f 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -100,7 +100,6 @@ test_that("job group functions can be called", { setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() - setJobDescription("job description") suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) suppressWarnings(cancelJobGroup(sc, "groupId")) @@ -108,6 +107,38 @@ test_that("job group functions can be called", { sparkR.session.stop() }) +test_that("job description and local properties can be set and got", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + setJobDescription("job description") + expect_equal(getLocalProperty("spark.job.description"), "job description") + setJobDescription(1234) + expect_equal(getLocalProperty("spark.job.description"), "1234") + setJobDescription(NULL) + expect_equal(getLocalProperty("spark.job.description"), NULL) + setJobDescription(NA) + expect_equal(getLocalProperty("spark.job.description"), NULL) + + setLocalProperty("spark.scheduler.pool", "poolA") + expect_equal(getLocalProperty("spark.scheduler.pool"), "poolA") + setLocalProperty("spark.scheduler.pool", NULL) + expect_equal(getLocalProperty("spark.scheduler.pool"), NULL) + setLocalProperty("spark.scheduler.pool", NA) + expect_equal(getLocalProperty("spark.scheduler.pool"), NULL) + + setLocalProperty(4321, 1234) + expect_equal(getLocalProperty(4321), "1234") + setLocalProperty(4321, NULL) + expect_equal(getLocalProperty(4321), NULL) + setLocalProperty(4321, NA) + expect_equal(getLocalProperty(4321), NULL) + + expect_error(setLocalProperty(NULL, "should fail"), "key should not be NULL or NA") + expect_error(getLocalProperty(NULL), "key should not be NULL or NA") + expect_error(setLocalProperty(NA, "should fail"), "key should not be NULL or NA") + expect_error(getLocalProperty(NA), "key should not be NULL or NA") + sparkR.session.stop() +}) + test_that("utility function can be called", { sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") From 755f2f5189a08597fddc90b7f9df4ad0ec6bd2ef Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 28 Dec 2017 21:33:03 +0800 Subject: [PATCH 514/712] [SPARK-20392][SQL][FOLLOWUP] should not add extra AnalysisBarrier ## What changes were proposed in this pull request? I found this problem while auditing the analyzer code. It's dangerous to introduce extra `AnalysisBarrer` during analysis, as the plan inside it will bypass all analysis afterward, which may not be expected. We should only preserve `AnalysisBarrer` but not introduce new ones. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20094 from cloud-fan/barrier. 
--- .../sql/catalyst/analysis/Analyzer.scala | 191 ++++++++---------- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 2 files changed, 84 insertions(+), 109 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 10b237fb22b96..7f2128ed36a08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -665,14 +664,18 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): LogicalPlan = { - // Remove analysis barrier if any. - val right = EliminateBarriers(originalRight) + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") right.collect { + // For `AnalysisBarrier`, recursively de-duplicate its child. + case oldVersion: AnalysisBarrier + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = dedupRight(left, oldVersion.child) + (oldVersion, AnalysisBarrier(newVersion)) + // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -710,10 +713,10 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - originalRight + right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { + right transformUp { case r if r == oldRelation => newRelation } transformUp { case other => other transformExpressions { @@ -723,7 +726,6 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - AnalysisBarrier(newRight) } } @@ -958,7 +960,8 @@ class Analyzer( protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, - throws: Boolean = false) = { + throws: Boolean = false): Expression = { + if (expr.resolved) return expr // Resolve expression in one round. // If throws == false or the desired attribute doesn't exist // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. @@ -1079,100 +1082,74 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, originalChild) if !s.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) - val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away after the sort. 
- Project(child.output, - Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) - } else if (newOrder != order) { - s.copy(order = newOrder) - } else { - s - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. - // Users will see an AnalysisException for resolution failure of missing attributes - // in Sort - case ae: AnalysisException => s + case s @ Sort(order, _, child) if !s.resolved && child.resolved => + val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) + val ordering = newOrder.map(_.asInstanceOf[SortOrder]) + if (child.output == newChild.output) { + s.copy(order = ordering) + } else { + // Add missing attributes and then project them away. + val newSort = s.copy(order = ordering, child = newChild) + Project(child.output, newSort) } - case f @ Filter(cond, originalChild) if !f.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newCond = resolveExpressionRecursively(cond, child) - val requiredAttrs = newCond.references.filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away. - Project(child.output, - Filter(newCond, addMissingAttr(child, missingAttrs))) - } else if (newCond != cond) { - f.copy(condition = newCond) - } else { - f - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. - // Users will see an AnalysisException for resolution failure of missing attributes - case ae: AnalysisException => f + case f @ Filter(cond, child) if !f.resolved && child.resolved => + val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) + if (child.output == newChild.output) { + f.copy(condition = newCond.head) + } else { + // Add missing attributes and then project them away. + val newFilter = Filter(newCond.head, newChild) + Project(child.output, newFilter) } } - /** - * Add the missing attributes into projectList of Project/Window or aggregateExpressions of - * Aggregate. - */ - private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { - if (missingAttrs.isEmpty) { - return AnalysisBarrier(plan) - } - plan match { - case p: Project => - val missing = missingAttrs -- p.child.outputSet - Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) - case a: Aggregate => - // all the missing attributes should be grouping expressions - // TODO: push down AggregateExpression - missingAttrs.foreach { attr => - if (!a.groupingExpressions.exists(_.semanticEquals(attr))) { - throw new AnalysisException(s"Can't add $attr to ${a.simpleString}") - } - } - val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs - a.copy(aggregateExpressions = newAggregateExpressions) - case g: Generate => - // If join is false, we will convert it to true for getting from the child the missing - // attributes that its child might have or could have. 
- val missing = missingAttrs -- g.child.outputSet - g.copy(join = true, child = addMissingAttr(g.child, missing)) - case d: Distinct => - throw new AnalysisException(s"Can't add $missingAttrs to $d") - case u: UnaryNode => - u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) - case other => - throw new AnalysisException(s"Can't add $missingAttrs to $other") - } - } - - /** - * Resolve the expression on a specified logical plan and it's child (recursively), until - * the expression is resolved or meet a non-unary node or Subquery. - */ - @tailrec - private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { - val resolved = resolveExpression(expr, plan) - if (resolved.resolved) { - resolved + private def resolveExprsAndAddMissingAttrs( + exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + if (exprs.forall(_.resolved)) { + // All given expressions are resolved, no need to continue anymore. + (exprs, plan) } else { plan match { - case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => - resolveExpressionRecursively(resolved, u.child) - case other => resolved + // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via + // its child. + case barrier: AnalysisBarrier => + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) + (newExprs, AnalysisBarrier(newChild)) + + case p: Project => + val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + (newExprs, Project(p.projectList ++ missingAttrs, newChild)) + + case a @ Aggregate(groupExprs, aggExprs, child) => + val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { + // All the missing attributes are grouping expressions, valid case. + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + } else { + // Need to add non-grouping attributes, invalid case. + (exprs, a) + } + + case g: Generate => + val maybeResolvedExprs = exprs.map(resolveExpression(_, g)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) + (newExprs, g.copy(join = true, child = newChild)) + + // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes + // via its children. + case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => + val maybeResolvedExprs = exprs.map(resolveExpression(_, u)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) + (newExprs, u.withNewChildren(Seq(newChild))) + + // For other operators, we can't recursively resolve and add attributes via its children. 
+ case other => + (exprs.map(resolveExpression(_, other)), other) } } } @@ -1404,18 +1381,16 @@ class Analyzer( */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => - apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) - case filter @ Filter(havingCondition, - aggregate @ Aggregate(grouping, originalAggExprs, child)) - if aggregate.resolved => + case Filter(cond, AnalysisBarrier(agg: Aggregate)) => + apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause try { val aggregatedCondition = Aggregate( grouping, - Alias(havingCondition, "havingCondition")() :: Nil, + Alias(cond, "havingCondition")() :: Nil, child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = @@ -1436,7 +1411,7 @@ class Analyzer( // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. case e: Expression if grouping.exists(_.semanticEquals(e)) && !ResolveGroupingAnalytics.hasGroupingFunction(e) && - !aggregate.output.exists(_.semanticEquals(e)) => + !agg.output.exists(_.semanticEquals(e)) => e match { case ne: NamedExpression => aggregateExpressions += ne @@ -1450,22 +1425,22 @@ class Analyzer( // Push the aggregate expressions into the aggregate (if any). if (aggregateExpressions.nonEmpty) { - Project(aggregate.output, + Project(agg.output, Filter(transformedAggregateFilter, - aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) + agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) } else { - filter + f } } else { - filter + f } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. - case ae: AnalysisException => filter + case ae: AnalysisException => f } - case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c11e37a516646..07ae3ae945848 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1562,7 +1562,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("multi-insert with lateral view") { - withTempView("t1") { + withTempView("source") { spark.range(10) .select(array($"id", $"id" + 1).as("arr"), $"id") .createOrReplaceTempView("source") From 28778174208664327b75915e83ae5e611360eef3 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Thu, 28 Dec 2017 21:49:37 +0800 Subject: [PATCH 515/712] [SPARK-22917][SQL] Should not try to generate histogram for empty/null columns ## What changes were proposed in this pull request? For empty/null column, the result of `ApproximatePercentile` is null. 
Then in `ApproxCountDistinctForIntervals`, a `MatchError` (for `endpoints`) will be thrown if we try to generate histogram for that column. Besides, there is no need to generate histogram for such column. In this patch, we exclude such column when generating histogram. ## How was this patch tested? Enhanced test cases for empty/null columns. Author: Zhenhua Wang Closes #20102 from wzhfy/no_record_hgm_bug. --- .../command/AnalyzeColumnCommand.scala | 7 ++++++- .../spark/sql/StatisticsCollectionSuite.scala | 21 +++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index e3bb4d357b395..1122522ccb4cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -143,7 +143,12 @@ case class AnalyzeColumnCommand( val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation)) .executedPlan.executeTake(1).head attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) => - attributePercentiles += attr -> percentilesRow.getArray(i) + val percentiles = percentilesRow.getArray(i) + // When there is no non-null value, `percentiles` is null. In such case, there is no + // need to generate histogram. + if (percentiles != null) { + attributePercentiles += attr -> percentiles + } } } AttributeMap(attributePercentiles.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index fba5d2652d3f5..b11e798532056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -85,13 +85,24 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("analyze empty table") { val table = "emptyTable" withTable(table) { - sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET") + val df = Seq.empty[Int].toDF("key") + df.write.format("json").saveAsTable(table) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan") val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) assert(fetchedStats1.get.sizeInBytes == 0) sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) assert(fetchedStats2.get.sizeInBytes == 0) + + val expectedColStat = + "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + + // There won't be histogram for empty column. + Seq("true", "false").foreach { histogramEnabled => + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> histogramEnabled) { + checkColStats(df, mutable.LinkedHashMap(expectedColStat)) + } + } } } @@ -178,7 +189,13 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val expectedColStats = dataTypes.map { case (tpe, idx) => (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) } - checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) + + // There won't be histograms for null columns. 
+ Seq("true", "false").foreach { histogramEnabled => + withSQLConf(SQLConf.HISTOGRAM_ENABLED.key -> histogramEnabled) { + checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) + } + } } test("number format in statistics") { From 5536f3181c1e77c70f01d6417407d218ea48b961 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 28 Dec 2017 09:43:50 -0600 Subject: [PATCH 516/712] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up a few Java linter errors for Apache Spark 2.3 release. ## How was this patch tested? ```bash $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` We can see the result from [Travis CI](https://travis-ci.org/dongjoon-hyun/spark/builds/322470787), too. Author: Dongjoon Hyun Closes #20101 from dongjoon-hyun/fix-java-lint. --- .../org/apache/spark/memory/MemoryConsumer.java | 3 ++- .../streaming/kinesis/KinesisInitialPositions.java | 14 ++++++++------ .../parquet/VectorizedColumnReader.java | 3 +-- .../parquet/VectorizedParquetRecordReader.java | 4 ++-- .../sql/execution/vectorized/ColumnarRow.java | 3 ++- .../spark/sql/sources/v2/SessionConfigSupport.java | 3 --- .../v2/streaming/ContinuousReadSupport.java | 5 ++++- .../v2/streaming/ContinuousWriteSupport.java | 6 +++--- .../sql/sources/v2/streaming/reader/Offset.java | 3 ++- .../v2/streaming/reader/PartitionOffset.java | 1 - .../hive/service/cli/operation/SQLOperation.java | 1 - 11 files changed, 24 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index a7bd4b3799a25..115e1fbb79a2e 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -154,6 +154,7 @@ private void throwOom(final MemoryBlock page, final long required) { taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); - throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + + got); } } diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java index 206e1e4699030..b5f5ab0e90540 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java @@ -67,9 +67,10 @@ public Date getTimestamp() { /** - * Returns instance of [[KinesisInitialPosition]] based on the passed [[InitialPositionInStream]]. - * This method is used in KinesisUtils for translating the InitialPositionInStream - * to InitialPosition. This function would be removed when we deprecate the KinesisUtils. + * Returns instance of [[KinesisInitialPosition]] based on the passed + * [[InitialPositionInStream]]. This method is used in KinesisUtils for translating the + * InitialPositionInStream to InitialPosition. This function would be removed when we deprecate + * the KinesisUtils. * * @return [[InitialPosition]] */ @@ -83,9 +84,10 @@ public static KinesisInitialPosition fromKinesisInitialPosition( // InitialPositionInStream.AT_TIMESTAMP is not supported. // Use InitialPosition.atTimestamp(timestamp) instead. 
throw new UnsupportedOperationException( - "Only InitialPositionInStream.LATEST and InitialPositionInStream.TRIM_HORIZON " + - "supported in initialPositionInStream(). Please use the initialPosition() from " + - "builder API in KinesisInputDStream for using InitialPositionInStream.AT_TIMESTAMP"); + "Only InitialPositionInStream.LATEST and InitialPositionInStream." + + "TRIM_HORIZON supported in initialPositionInStream(). Please use " + + "the initialPosition() from builder API in KinesisInputDStream for " + + "using InitialPositionInStream.AT_TIMESTAMP"); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 3ba180860c325..c120863152a96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -31,7 +31,6 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; -import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -96,7 +95,7 @@ public class VectorizedColumnReader { private final OriginalType originalType; // The timezone conversion to apply to int96 timestamps. Null if no conversion. private final TimeZone convertTz; - private final static TimeZone UTC = DateTimeUtils.TimeZoneUTC(); + private static final TimeZone UTC = DateTimeUtils.TimeZoneUTC(); public VectorizedColumnReader( ColumnDescriptor descriptor, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 14f2a58d638f0..6c157e85d411f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -79,8 +79,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private boolean[] missingColumns; /** - * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to workaround - * incompatibilities between different engines when writing timestamp values. + * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to + * workaround incompatibilities between different engines when writing timestamp values. */ private TimeZone convertTz = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java index 95c0d09873d67..8bb33ed5b78c0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java @@ -28,7 +28,8 @@ * to be reused, callers should copy the data out if it needs to be stored. */ public final class ColumnarRow extends InternalRow { - // The data for this row. E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // The data for this row. + // E.g. 
the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. private final ColumnVector data; private final int rowId; private final int numFields; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 0b5b6ac675f2c..3cb020d2e0836 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -19,9 +19,6 @@ import org.apache.spark.annotation.InterfaceStability; -import java.util.List; -import java.util.Map; - /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 8837bae6156b1..3136cee1f655f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -39,5 +39,8 @@ public interface ContinuousReadSupport extends DataSourceV2 { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - ContinuousReader createContinuousReader(Optional schema, String checkpointLocation, DataSourceV2Options options); + ContinuousReader createContinuousReader( + Optional schema, + String checkpointLocation, + DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java index ec15e436672b4..dee493cadb71e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java @@ -39,9 +39,9 @@ public interface ContinuousWriteSupport extends BaseStreamingSink { * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done. * - * @param queryId A unique string for the writing query. It's possible that there are many writing - * queries running at the same time, and the returned {@link DataSourceV2Writer} - * can use this id to distinguish itself from others. + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link DataSourceV2Writer} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the output mode which determines what successive epoch output means to this * sink, please refer to {@link OutputMode} for more details. 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index 517fdab16684c..60b87f2ac0756 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -42,7 +42,8 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of @Override public boolean equals(Object obj) { if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) { - return this.json().equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); + return this.json() + .equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); } else { return false; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index 729a6123034f0..eca0085c8a8ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -26,5 +26,4 @@ * These offsets must be serializable. */ public interface PartitionOffset extends Serializable { - } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index fd9108eb53ca9..70c27948de61b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -42,7 +42,6 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse; -import org.apache.hadoop.hive.ql.session.OperationLog; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.SerDe; From 8f6d5734d504be89249fb6744ef4067be2e889fc Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Thu, 28 Dec 2017 11:44:06 -0600 Subject: [PATCH 517/712] [SPARK-22875][BUILD] Assembly build fails for a high user id ## What changes were proposed in this pull request? Add tarLongFileMode=posix configuration for the assembly plugin ## How was this patch tested? Reran build successfully ``` ./build/mvn package -Pbigtop-dist -DskipTests -rf :spark-assembly_2.11 [INFO] Spark Project Assembly ............................. SUCCESS [ 23.082 s] ``` Author: Gera Shegalov Closes #20055 from gerashegalov/gera/tarLongFileMode. --- pom.xml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pom.xml b/pom.xml index 92f897095f087..34793b0c41708 100644 --- a/pom.xml +++ b/pom.xml @@ -2295,6 +2295,9 @@ org.apache.maven.plugins maven-assembly-plugin 3.1.0 + + posix + org.apache.maven.plugins From 9c21ece35edc175edde949175a4ce701679824c8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 28 Dec 2017 15:41:16 -0600 Subject: [PATCH 518/712] [SPARK-22836][UI] Show driver logs in UI when available. Port code from the old executors listener to the new one, so that the driver logs present in the application start event are kept. Author: Marcelo Vanzin Closes #20038 from vanzin/SPARK-22836. 
--- .../spark/status/AppStatusListener.scala | 11 +++++++++++ .../spark/status/AppStatusListenerSuite.scala | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 5253297137323..487a782e865e8 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -119,6 +119,17 @@ private[spark] class AppStatusListener( kvstore.write(new ApplicationInfoWrapper(appInfo)) kvstore.write(appSummary) + + // Update the driver block manager with logs from this event. The SparkContext initialization + // code registers the driver before this event is sent. + event.driverLogs.foreach { logs => + val driver = liveExecutors.get(SparkContext.DRIVER_IDENTIFIER) + .orElse(liveExecutors.get(SparkContext.LEGACY_DRIVER_IDENTIFIER)) + driver.foreach { d => + d.executorLogs = logs.toMap + update(d, System.nanoTime()) + } + } } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index c0b3a79fe981e..997c7de8dd02b 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -942,6 +942,24 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("driver logs") { + val listener = new AppStatusListener(store, conf, true) + + val driver = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "localhost", 42) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(time, driver, 42L)) + listener.onApplicationStart(SparkListenerApplicationStart( + "name", + Some("id"), + time, + "user", + Some("attempt"), + Some(Map("stdout" -> "file.txt")))) + + check[ExecutorSummaryWrapper](SparkContext.DRIVER_IDENTIFIER) { d => + assert(d.info.executorLogs("stdout") === "file.txt") + } + } + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { From 613b71a1230723ed556239b8b387722043c67252 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 29 Dec 2017 06:58:38 +0800 Subject: [PATCH 519/712] [SPARK-22890][TEST] Basic tests for DateTimeOperations ## What changes were proposed in this pull request? Test Coverage for `DateTimeOperations`, this is a Sub-tasks for [SPARK-22722](https://issues.apache.org/jira/browse/SPARK-22722). ## How was this patch tested? N/A Author: Yuming Wang Closes #20061 from wangyum/SPARK-22890. 
--- .../native/dateTimeOperations.sql | 60 +++ .../native/dateTimeOperations.sql.out | 349 ++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 13 +- 3 files changed, 410 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql new file mode 100644 index 0000000000000..1e98221867965 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/dateTimeOperations.sql @@ -0,0 +1,60 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE TEMPORARY VIEW t AS SELECT 1; + +select cast(1 as tinyint) + interval 2 day; +select cast(1 as smallint) + interval 2 day; +select cast(1 as int) + interval 2 day; +select cast(1 as bigint) + interval 2 day; +select cast(1 as float) + interval 2 day; +select cast(1 as double) + interval 2 day; +select cast(1 as decimal(10, 0)) + interval 2 day; +select cast('2017-12-11' as string) + interval 2 day; +select cast('2017-12-11 09:30:00' as string) + interval 2 day; +select cast('1' as binary) + interval 2 day; +select cast(1 as boolean) + interval 2 day; +select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day; +select cast('2017-12-11 09:30:00' as date) + interval 2 day; + +select interval 2 day + cast(1 as tinyint); +select interval 2 day + cast(1 as smallint); +select interval 2 day + cast(1 as int); +select interval 2 day + cast(1 as bigint); +select interval 2 day + cast(1 as float); +select interval 2 day + cast(1 as double); +select interval 2 day + cast(1 as decimal(10, 0)); +select interval 2 day + cast('2017-12-11' as string); +select interval 2 day + cast('2017-12-11 09:30:00' as string); +select interval 2 day + cast('1' as binary); +select interval 2 day + cast(1 as boolean); +select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp); +select interval 2 day + cast('2017-12-11 09:30:00' as date); + +select cast(1 as tinyint) - interval 2 day; +select cast(1 as smallint) - interval 2 day; +select cast(1 as int) - interval 2 day; +select cast(1 as bigint) - interval 2 day; +select cast(1 as float) - interval 2 day; +select cast(1 as double) - interval 2 day; +select cast(1 as decimal(10, 0)) - interval 2 day; +select cast('2017-12-11' as string) - interval 2 day; +select cast('2017-12-11 09:30:00' as string) - interval 2 day; +select cast('1' as binary) - interval 2 day; +select cast(1 as boolean) - interval 2 day; +select cast('2017-12-11 09:30:00.0' as timestamp) - 
interval 2 day; +select cast('2017-12-11 09:30:00' as date) - interval 2 day; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out new file mode 100644 index 0000000000000..12c1d1617679f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/dateTimeOperations.sql.out @@ -0,0 +1,349 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 40 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS SELECT 1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select cast(1 as tinyint) + interval 2 day +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) + interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7 + + +-- !query 2 +select cast(1 as smallint) + interval 2 day +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) + interval 2 days)' (smallint and calendarinterval).; line 1 pos 7 + + +-- !query 3 +select cast(1 as int) + interval 2 day +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) + interval 2 days)' (int and calendarinterval).; line 1 pos 7 + + +-- !query 4 +select cast(1 as bigint) + interval 2 day +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) + interval 2 days)' (bigint and calendarinterval).; line 1 pos 7 + + +-- !query 5 +select cast(1 as float) + interval 2 day +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) + interval 2 days)' (float and calendarinterval).; line 1 pos 7 + + +-- !query 6 +select cast(1 as double) + interval 2 day +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) + interval 2 days)' (double and calendarinterval).; line 1 pos 7 + + +-- !query 7 +select cast(1 as decimal(10, 0)) + interval 2 day +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) + interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7 + + +-- !query 8 +select cast('2017-12-11' as string) + interval 2 day +-- !query 8 schema +struct +-- !query 8 output +2017-12-13 00:00:00 + + +-- !query 9 +select cast('2017-12-11 09:30:00' as string) + interval 2 day +-- !query 9 schema +struct +-- !query 9 output +2017-12-13 09:30:00 + + +-- !query 10 +select cast('1' as binary) + interval 2 day +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) + interval 2 days)' due to data type 
mismatch: differing types in '(CAST('1' AS BINARY) + interval 2 days)' (binary and calendarinterval).; line 1 pos 7 + + +-- !query 11 +select cast(1 as boolean) + interval 2 day +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) + interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) + interval 2 days)' (boolean and calendarinterval).; line 1 pos 7 + + +-- !query 12 +select cast('2017-12-11 09:30:00.0' as timestamp) + interval 2 day +-- !query 12 schema +struct +-- !query 12 output +2017-12-13 09:30:00 + + +-- !query 13 +select cast('2017-12-11 09:30:00' as date) + interval 2 day +-- !query 13 schema +struct +-- !query 13 output +2017-12-13 + + +-- !query 14 +select interval 2 day + cast(1 as tinyint) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS TINYINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS TINYINT))' (calendarinterval and tinyint).; line 1 pos 7 + + +-- !query 15 +select interval 2 day + cast(1 as smallint) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS SMALLINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS SMALLINT))' (calendarinterval and smallint).; line 1 pos 7 + + +-- !query 16 +select interval 2 day + cast(1 as int) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS INT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS INT))' (calendarinterval and int).; line 1 pos 7 + + +-- !query 17 +select interval 2 day + cast(1 as bigint) +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS BIGINT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BIGINT))' (calendarinterval and bigint).; line 1 pos 7 + + +-- !query 18 +select interval 2 day + cast(1 as float) +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS FLOAT))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS FLOAT))' (calendarinterval and float).; line 1 pos 7 + + +-- !query 19 +select interval 2 day + cast(1 as double) +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS DOUBLE))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DOUBLE))' (calendarinterval and double).; line 1 pos 7 + + +-- !query 20 +select interval 2 day + cast(1 as decimal(10, 0)) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS DECIMAL(10,0)))' (calendarinterval and decimal(10,0)).; line 1 pos 7 + + +-- !query 21 +select interval 2 day + cast('2017-12-11' as string) +-- !query 21 schema +struct +-- !query 21 output +2017-12-13 00:00:00 + + +-- !query 22 +select interval 2 day + cast('2017-12-11 09:30:00' as string) +-- !query 22 schema +struct +-- !query 22 output +2017-12-13 09:30:00 + + +-- !query 23 +select interval 2 day + cast('1' as binary) +-- !query 23 schema +struct<> 
+-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST('1' AS BINARY))' due to data type mismatch: differing types in '(interval 2 days + CAST('1' AS BINARY))' (calendarinterval and binary).; line 1 pos 7 + + +-- !query 24 +select interval 2 day + cast(1 as boolean) +-- !query 24 schema +struct<> +-- !query 24 output +org.apache.spark.sql.AnalysisException +cannot resolve '(interval 2 days + CAST(1 AS BOOLEAN))' due to data type mismatch: differing types in '(interval 2 days + CAST(1 AS BOOLEAN))' (calendarinterval and boolean).; line 1 pos 7 + + +-- !query 25 +select interval 2 day + cast('2017-12-11 09:30:00.0' as timestamp) +-- !query 25 schema +struct +-- !query 25 output +2017-12-13 09:30:00 + + +-- !query 26 +select interval 2 day + cast('2017-12-11 09:30:00' as date) +-- !query 26 schema +struct +-- !query 26 output +2017-12-13 + + +-- !query 27 +select cast(1 as tinyint) - interval 2 day +-- !query 27 schema +struct<> +-- !query 27 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS TINYINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS TINYINT) - interval 2 days)' (tinyint and calendarinterval).; line 1 pos 7 + + +-- !query 28 +select cast(1 as smallint) - interval 2 day +-- !query 28 schema +struct<> +-- !query 28 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS SMALLINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS SMALLINT) - interval 2 days)' (smallint and calendarinterval).; line 1 pos 7 + + +-- !query 29 +select cast(1 as int) - interval 2 day +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS INT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS INT) - interval 2 days)' (int and calendarinterval).; line 1 pos 7 + + +-- !query 30 +select cast(1 as bigint) - interval 2 day +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BIGINT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BIGINT) - interval 2 days)' (bigint and calendarinterval).; line 1 pos 7 + + +-- !query 31 +select cast(1 as float) - interval 2 day +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS FLOAT) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS FLOAT) - interval 2 days)' (float and calendarinterval).; line 1 pos 7 + + +-- !query 32 +select cast(1 as double) - interval 2 day +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DOUBLE) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DOUBLE) - interval 2 days)' (double and calendarinterval).; line 1 pos 7 + + +-- !query 33 +select cast(1 as decimal(10, 0)) - interval 2 day +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS DECIMAL(10,0)) - interval 2 days)' (decimal(10,0) and calendarinterval).; line 1 pos 7 + + +-- !query 34 +select cast('2017-12-11' as string) - interval 2 day +-- !query 34 schema +struct +-- !query 34 output +2017-12-09 00:00:00 + + +-- !query 35 +select cast('2017-12-11 09:30:00' as string) - interval 2 day +-- !query 35 schema +struct 
+-- !query 35 output +2017-12-09 09:30:00 + + +-- !query 36 +select cast('1' as binary) - interval 2 day +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST('1' AS BINARY) - interval 2 days)' due to data type mismatch: differing types in '(CAST('1' AS BINARY) - interval 2 days)' (binary and calendarinterval).; line 1 pos 7 + + +-- !query 37 +select cast(1 as boolean) - interval 2 day +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve '(CAST(1 AS BOOLEAN) - interval 2 days)' due to data type mismatch: differing types in '(CAST(1 AS BOOLEAN) - interval 2 days)' (boolean and calendarinterval).; line 1 pos 7 + + +-- !query 38 +select cast('2017-12-11 09:30:00.0' as timestamp) - interval 2 day +-- !query 38 schema +struct +-- !query 38 output +2017-12-09 09:30:00 + + +-- !query 39 +select cast('2017-12-11 09:30:00' as date) - interval 2 day +-- !query 39 schema +struct +-- !query 39 output +2017-12-09 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1972dec7b86ce..5e077285ade55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext import java.net.{MalformedURLException, URL} -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} @@ -2760,17 +2760,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("SPARK-22894: DateTimeOperations should accept SQL like string type") { - val date = "2017-12-24" - val str = sql(s"SELECT CAST('$date' as STRING) + interval 2 months 2 seconds") - val dt = sql(s"SELECT CAST('$date' as DATE) + interval 2 months 2 seconds") - val ts = sql(s"SELECT CAST('$date' as TIMESTAMP) + interval 2 months 2 seconds") - - checkAnswer(str, Row("2018-02-24 00:00:02") :: Nil) - checkAnswer(dt, Row(Date.valueOf("2018-02-24")) :: Nil) - checkAnswer(ts, Row(Timestamp.valueOf("2018-02-24 00:00:02")) :: Nil) - } - // Only New OrcFileFormat supports this Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, "parquet").foreach { format => From cfcd746689c2b84824745fa6d327ffb584c7a17d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 28 Dec 2017 17:00:49 -0600 Subject: [PATCH 520/712] [SPARK-11035][CORE] Add in-process Spark app launcher. This change adds a new launcher that allows applications to be run in a separate thread in the same process as the calling code. To achieve that, some code from the child process implementation was moved to abstract classes that implement the common functionality, and the new launcher inherits from those. The new launcher was added as a new class, instead of being implemented as a new option to the existing SparkLauncher, to avoid ambiguous APIs. For example, SparkLauncher has ways to set the child app's environment, modify SPARK_HOME, or control the logging of the child process, none of which apply to in-process apps. The in-process launcher has limitations: it needs Spark in the context class loader of the calling thread, and it's bound by Spark's current limitation of a single client-mode application per JVM.
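To make the intended usage concrete, here is a minimal sketch of how a caller might start an application with the new in-process launcher. It is illustrative only: the master, deploy mode, application jar, and main class are placeholders, and the builder methods are assumed to be the ones shared with the existing launcher through the new abstract classes, not a verified final API.

```
import org.apache.spark.launcher.{InProcessLauncher, SparkAppHandle}

object InProcessLaunchSketch {
  def main(args: Array[String]): Unit = {
    // Spark classes must already be visible on the context class loader of this thread.
    val handle: SparkAppHandle = new InProcessLauncher()
      .setMaster("yarn")                    // placeholder master
      .setDeployMode("cluster")             // placeholder deploy mode
      .setAppResource("/path/to/app.jar")   // placeholder application jar
      .setMainClass("com.example.MyApp")    // placeholder main class
      .startApplication(new SparkAppHandle.Listener {
        override def stateChanged(h: SparkAppHandle): Unit =
          println(s"Application state: ${h.getState}")
        override def infoChanged(h: SparkAppHandle): Unit = ()
      })
    // The handle can then be used to monitor or stop the application,
    // just like a handle obtained from the child-process launcher.
  }
}
```

From the caller's point of view, the in-process launcher hands back the same kind of handle as the existing child-process launcher.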
It also relies on the recently added SparkApplication trait to make sure different apps don't mess up each other's configuration, so config isolation is currently limited to cluster mode. I also chose to keep the same socket-based communication for in-process apps, even though it might be possible to avoid it for in-process mode. That helps both implementations share more code. Tested with new and existing unit tests, and with a simple app that uses the launcher; also made sure the app ran fine with older launcher jar to check binary compatibility. Author: Marcelo Vanzin Closes #19591 from vanzin/SPARK-11035. --- core/pom.xml | 8 + .../spark/launcher/LauncherBackend.scala | 11 +- .../cluster/StandaloneSchedulerBackend.scala | 1 + .../local/LocalSchedulerBackend.scala | 1 + .../spark/launcher/SparkLauncherSuite.java | 70 +++- .../spark/launcher/AbstractAppHandle.java | 129 +++++++ .../spark/launcher/AbstractLauncher.java | 307 +++++++++++++++ .../spark/launcher/ChildProcAppHandle.java | 117 +----- .../spark/launcher/InProcessAppHandle.java | 83 +++++ .../spark/launcher/InProcessLauncher.java | 110 ++++++ .../spark/launcher/LauncherProtocol.java | 6 + .../apache/spark/launcher/LauncherServer.java | 113 +++--- .../apache/spark/launcher/SparkLauncher.java | 351 +++++------------- .../launcher/SparkSubmitCommandBuilder.java | 2 +- .../apache/spark/launcher/package-info.java | 28 +- .../org/apache/spark/launcher/BaseSuite.java | 35 +- .../launcher/ChildProcAppHandleSuite.java | 16 - .../launcher/InProcessLauncherSuite.java | 170 +++++++++ .../spark/launcher/LauncherServerSuite.java | 82 ++-- .../MesosCoarseGrainedSchedulerBackend.scala | 2 + .../org/apache/spark/deploy/yarn/Client.scala | 2 + 21 files changed, 1139 insertions(+), 505 deletions(-) create mode 100644 launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java create mode 100644 launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java create mode 100644 launcher/src/test/java/org/apache/spark/launcher/InProcessLauncherSuite.java diff --git a/core/pom.xml b/core/pom.xml index fa138d3e7a4e0..0a5bd958fc9c5 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -351,6 +351,14 @@ spark-tags_${scala.binary.version}
    + + org.apache.spark + spark-launcher_${scala.binary.version} + ${project.version} + tests + test + + 0.10.2 - 4.5.2 - 4.4.4 + 4.5.4 + 4.4.8 3.1 3.4.1 From ea0a5eef2238daa68a15b60a6f1a74c361216140 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 31 Dec 2017 02:50:00 +0900 Subject: [PATCH 543/712] [SPARK-22924][SPARKR] R API for sortWithinPartitions ## What changes were proposed in this pull request? Add to `arrange` the option to sort only within partition ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #20118 from felixcheung/rsortwithinpartition. --- R/pkg/R/DataFrame.R | 14 ++++++++++---- R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ace49daf9cd83..fe238f6dd4eb0 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param ... additional sorting fields #' @param decreasing a logical argument indicating sorting order for columns when #' a character vector is specified for col +#' @param withinPartitions a logical argument indicating whether to sort only within each partition #' @return A SparkDataFrame where all elements are sorted. #' @family SparkDataFrame functions #' @aliases arrange,SparkDataFrame,Column-method @@ -2312,16 +2313,21 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) #' arrange(df, "col1", "col2", decreasing = c(TRUE, FALSE)) +#' arrange(df, "col1", "col2", withinPartitions = TRUE) #' } #' @note arrange(SparkDataFrame, Column) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "Column"), - function(x, col, ...) { + function(x, col, ..., withinPartitions = FALSE) { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", jcols) + if (withinPartitions) { + sdf <- callJMethod(x@sdf, "sortWithinPartitions", jcols) + } else { + sdf <- callJMethod(x@sdf, "sort", jcols) + } dataFrame(sdf) }) @@ -2332,7 +2338,7 @@ setMethod("arrange", #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), - function(x, col, ..., decreasing = FALSE) { + function(x, col, ..., decreasing = FALSE, withinPartitions = FALSE) { # all sorting columns by <- list(col, ...) @@ -2356,7 +2362,7 @@ setMethod("arrange", } }) - do.call("arrange", c(x, jcols)) + do.call("arrange", c(x, jcols, withinPartitions = withinPartitions)) }) #' @rdname arrange diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 1b7d53fd73a08..5197838eaac66 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2130,6 +2130,11 @@ test_that("arrange() and orderBy() on a DataFrame", { sorted7 <- arrange(df, "name", decreasing = FALSE) expect_equal(collect(sorted7)[2, "age"], 19) + + df <- createDataFrame(cars, numPartitions = 10) + expect_equal(getNumPartitions(df), 10) + sorted8 <- arrange(df, "dist", withinPartitions = TRUE) + expect_equal(collect(sorted8)[5:6, "dist"], c(22, 10)) }) test_that("filter() on a DataFrame", { From ee3af15fea18356a9223d61cfe6aaa98ab4dc733 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Sun, 31 Dec 2017 14:47:23 +0800 Subject: [PATCH 544/712] [SPARK-22363][SQL][TEST] Add unit test for Window spilling ## What changes were proposed in this pull request? 
There is already test using window spilling, but the test coverage is not ideal. In this PR the already existing test was fixed and additional cases added. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20022 from gaborgsomogyi/SPARK-22363. --- .../sql/DataFrameWindowFunctionsSuite.scala | 44 ++++++++++++++++++- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index ea725af8d1ad8..01c988ecc3726 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -518,9 +519,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + test("Window spill with less than the inMemoryThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "2", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { + assertNotSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + + test("Window spill with more than the inMemoryThreshold but less than the spillThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") { + assertNotSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + + test("Window spill with more than the inMemoryThreshold and spillThreshold") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { + assertSpilled(sparkContext, "select") { + df.select($"key", sum("value").over(window)).collect() + } + } + } + test("SPARK-21258: complex object in combination with spilling") { // Make sure we trigger the spilling path. - withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1", + SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { val sampleSchema = new StructType(). add("f0", StringType). add("f1", LongType). 
@@ -558,7 +596,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ - spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + assertSpilled(sparkContext, "select") { + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } } } } From cfbe11e8164c04cd7d388e4faeded21a9331dac4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 31 Dec 2017 15:06:54 +0800 Subject: [PATCH 545/712] [SPARK-22895][SQL] Push down the deterministic predicates that are after the first non-deterministic ## What changes were proposed in this pull request? Currently, we do not guarantee an order evaluation of conjuncts in either Filter or Join operator. This is also true to the mainstream RDBMS vendors like DB2 and MS SQL Server. Thus, we should also push down the deterministic predicates that are after the first non-deterministic, if possible. ## How was this patch tested? Updated the existing test cases. Author: gatorsmile Closes #20069 from gatorsmile/morePushDown. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/optimizer/Optimizer.scala | 40 ++++++++----------- .../optimizer/FilterPushdownSuite.scala | 33 +++++++-------- .../v2/PushDownOperatorsToDataSource.scala | 10 ++--- .../execution/python/ExtractPythonUDFs.scala | 6 +-- .../StreamingSymmetricHashJoinHelper.scala | 5 +-- .../python/BatchEvalPythonExecSuite.scala | 10 +++-- ...treamingSymmetricHashJoinHelperSuite.scala | 14 +++---- 8 files changed, 54 insertions(+), 65 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4b5f56c44444d..dc3e384008d27 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1636,6 +1636,7 @@ options. - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. + - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. 
The conflict resolution follows the table below: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index eeb1b130c578b..0d4b02c6e7d8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -805,15 +805,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // For each filter, expand the alias and check if the filter can be evaluated using // attributes produced by the aggregate operator's child operator. - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => val replaced = replaceAlias(cond, aliasMap) cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) @@ -835,14 +835,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => cond.references.subsetOf(partitionAttrs) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) @@ -854,7 +854,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic) if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) @@ -878,13 +878,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } case filter @ Filter(condition, watermark: EventTimeWatermark) => - // We can only push deterministic predicates which don't reference the watermark attribute. - // We could in theory span() only on determinism and pull out deterministic predicates - // on the watermark separately. But it seems unnecessary and a bit confusing to not simply - // use the prefix as we do for nondeterminism in other cases. - - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span( - p => p.deterministic && !p.references.contains(watermark.eventTime)) + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p => + p.deterministic && !p.references.contains(watermark.eventTime) + } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduceLeft(And) @@ -925,14 +921,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // come from grandchild. // TODO: non-deterministic predicates could be pushed through some operators that do not change // the rows. 
- val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => cond.references.subsetOf(grandchild.outputSet) } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ nonDeterministic if (pushDown.nonEmpty) { val newChild = insertFilter(pushDown.reduceLeft(And)) @@ -975,23 +971,19 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions or filter predicates (on a given join's output) into three * categories based on the attributes required to evaluate them. Note that we explicitly exclude - * on-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or + * non-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or * canEvaluateInRight to prevent pushing these predicates on either side of the join. * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { - // Note: In order to ensure correctness, it's important to not change the relative ordering of - // any deterministic expression that follows a non-deterministic expression. To achieve this, - // we only consider pushing down those expressions that precede the first non-deterministic - // expression in the condition. - val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic) + val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic) val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic) + (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index a9c2306e37fd5..85a5e979f6021 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -831,9 +831,9 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Union(Seq( - testRelation.where('a === 2L), - testRelation2.where('d === 2L))) - .where('b + Rand(10).as("rnd") === 3 && 'c > 5L) + testRelation.where('a === 2L && 'c > 5L), + testRelation2.where('d === 2L && 'f > 5L))) + .where('b + Rand(10).as("rnd") === 3) .analyze comparePlans(optimized, correctAnswer) @@ -1134,12 +1134,13 @@ class FilterPushdownSuite extends PlanTest { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - // Verify that all conditions preceding the first non-deterministic condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. 
val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 && "x.a".attr === Rand(10) && "y.b".attr === 5)) - val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), - condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) + val correctAnswer = + x.where("x.a".attr === 5).join(y.where("y.a".attr === 5 && "y.b".attr === 5), + condition = Some("x.a".attr === Rand(10))) // CheckAnalysis will ensure nondeterministic expressions not appear in join condition. // TODO support nondeterministic expressions in join condition. @@ -1147,16 +1148,16 @@ class FilterPushdownSuite extends PlanTest { checkAnalysis = false) } - test("watermark pushdown: no pushdown on watermark attribute") { + test("watermark pushdown: no pushdown on watermark attribute #1") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('b, interval, testRelation) .where('a === 5 && 'b === 10 && 'c === 5) val correctAnswer = EventTimeWatermark( - 'b, interval, testRelation.where('a === 5)) - .where('b === 10 && 'c === 5) + 'b, interval, testRelation.where('a === 5 && 'c === 5)) + .where('b === 10) comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) @@ -1165,7 +1166,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: no pushdown for nondeterministic filter") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('c, interval, testRelation) .where('a === 5 && 'b === Rand(10) && 'c === 5) @@ -1180,7 +1181,7 @@ class FilterPushdownSuite extends PlanTest { test("watermark pushdown: full pushdown") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down + // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. val originalQuery = EventTimeWatermark('c, interval, testRelation) .where('a === 5 && 'b === 10) @@ -1191,15 +1192,15 @@ class FilterPushdownSuite extends PlanTest { checkAnalysis = false) } - test("watermark pushdown: empty pushdown") { + test("watermark pushdown: no pushdown on watermark attribute #2") { val interval = new CalendarInterval(2, 2000L) - // Verify that all conditions preceding the first watermark touching condition are pushed down - // by the optimizer and others are not. 
val originalQuery = EventTimeWatermark('a, interval, testRelation) .where('a === 5 && 'b === 10) + val correctAnswer = EventTimeWatermark( + 'a, interval, testRelation.where('b === 10)).where('a === 5) - comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze, + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, checkAnalysis = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 0c1708131ae46..df034adf1e7d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -40,12 +40,8 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - // Non-deterministic expressions are stateful and we must keep the input sequence unchanged - // to avoid changing the result. This means, we can't evaluate the filter conditions that - // are after the first non-deterministic condition ahead. Here we only try to push down - // deterministic conditions that are before the first non-deterministic condition. - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) val stayUpFilters: Seq[Expression] = reader match { case r: SupportsPushDownCatalystFilters => @@ -74,7 +70,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel case _ => candidates } - val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) if (withFilter.output == fields) { withFilter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index f5a4cbc4793e3..2f53fe788c7d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -202,12 +202,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def trySplitFilter(plan: SparkPlan): SparkPlan = { plan match { case filter: FilterExec => - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (candidates, nonDeterministic) = + splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) if (pushDown.nonEmpty) { val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) - FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild) + FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) } else { filter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 167e991ca62f8..217e98a6419e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -72,8 +72,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { * left AND right AND joined is equivalent to full. * * Note that left and right do not necessarily contain *all* conjuncts which satisfy - * their condition. Any conjuncts after the first nondeterministic one are treated as - * nondeterministic for purposes of the split. + * their condition. * * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join. * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join. @@ -111,7 +110,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { // Span rather than partition, because nondeterministic expressions don't commute // across AND. val (deterministicConjuncts, nonDeterministicConjuncts) = - splitConjunctivePredicates(condition.get).span(_.deterministic) + splitConjunctivePredicates(condition.get).partition(_.deterministic) val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond => cond.references.subsetOf(left.outputSet) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 9e4a2e8776956..d456c931f5275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -75,13 +75,17 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { assert(qualifiedPlanNodes.size == 2) } - test("Python UDF: no push down on predicates starting from the first non-deterministic") { + test("Python UDF: push down on deterministic predicates after the first non-deterministic") { val df = Seq(("Hello", 4)).toDF("a", "b") .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { - case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f + case f @ FilterExec( + And(_: AttributeReference, _: GreaterThan), + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b } - assert(qualifiedPlanNodes.size == 1) + assert(qualifiedPlanNodes.size == 2) } test("Python UDF refers to the attributes from more than one child") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala index 2a854e37bf0df..69b7154895341 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.streaming import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.{LeafExecNode, 
LocalTableScanExec, SparkPlan} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates import org.apache.spark.sql.types._ @@ -95,19 +93,17 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest { } test("conjuncts after nondeterministic") { - // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't - // commute across it. val predicate = - (rand() > lit(0) + (rand(9) > lit(0) && leftColA > leftColB && rightColC > rightColD && leftColA === rightColC && lit(1) === lit(1)).expr val split = JoinConditionSplitPredicates(Some(predicate), left, right) - assert(split.leftSideOnly.isEmpty) - assert(split.rightSideOnly.isEmpty) - assert(split.bothSides.contains(predicate)) + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC && rand(9) > lit(0)).expr)) assert(split.full.contains(predicate)) } From 3d8837e59aadd726805371041567ceff375194c0 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 31 Dec 2017 14:39:24 +0200 Subject: [PATCH 546/712] [SPARK-22397][ML] add multiple columns support to QuantileDiscretizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Add multiple columns support to QuantileDiscretizer. When calculating the splits, we can either merge together all the probabilities into one array by calculating approxQuantiles on multiple columns at once, or compute approxQuantiles separately for each column. After doing the performance comparison, we found it’s better to calculate approxQuantiles on multiple columns at once. Here is how we measured the performance: ``` var duration = 0.0 for (i<- 0 until 10) { val start = System.nanoTime() discretizer.fit(df) val end = System.nanoTime() duration += (end - start) / 1e9 } println(duration/10) ``` Here is the performance test result: |numCols |NumRows | compute each approxQuantiles separately|compute multiple columns approxQuantiles at one time| |--------|----------|--------------------------------|-------------------------------------------| |10 |60 |0.3623195839 |0.1626658607 | |10 |6000 |0.7537239841 |0.3869370046 | |22 |6000 |1.6497598557 |0.4767903059 | |50 |6000 |3.2268305752 |0.7217818396 | ## How was this patch tested? Add UT in QuantileDiscretizerSuite to test multiple columns support. Author: Huaxin Gao Closes #19715 from huaxingao/spark_22397.
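For reference, a minimal sketch of the new multiple-column usage is shown below; the DataFrame contents, column names, and bucket counts are made-up examples, while the setters mirror the ones added in the diff that follows.

```
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.sql.SparkSession

object MultiColumnDiscretizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Two numeric input columns with arbitrary sample values.
    val df = Seq((1.0, 10.0), (2.0, 30.0), (3.0, 20.0), (4.0, 40.0), (5.0, 50.0))
      .toDF("input1", "input2")

    // Bucketize both columns in a single fit; each column gets its own bucket count.
    val discretizer = new QuantileDiscretizer()
      .setInputCols(Array("input1", "input2"))
      .setOutputCols(Array("result1", "result2"))
      .setNumBucketsArray(Array(2, 3))

    discretizer.fit(df).transform(df).show()
    spark.stop()
  }
}
```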
--- .../ml/feature/QuantileDiscretizer.scala | 120 ++++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 265 ++++++++++++++++++ 2 files changed, 369 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 95e8830283dee..1ec5f8cb6139b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -50,10 +50,28 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getNumBuckets: Int = getOrDefault(numBuckets) + /** + * Array of number of buckets (quantiles, or categories) into which data points are grouped. + * Each value must be greater than or equal to 2 + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * + * @group param + */ + val numBucketsArray = new IntArrayParam(this, "numBucketsArray", "Array of number of buckets " + + "(quantiles, or categories) into which data points are grouped. This is for multiple " + + "columns input. If transforming multiple columns and numBucketsArray is not set, but " + + "numBuckets is set, then numBuckets will be applied across all columns.", + (arrayOfNumBuckets: Array[Int]) => arrayOfNumBuckets.forall(ParamValidators.gtEq(2))) + + /** @group getParam */ + def getNumBucketsArray: Array[Int] = $(numBucketsArray) + /** * Relative error (see documentation for * `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile` for description) * Must be in the range [0, 1]. + * Note that in multiple columns case, relative error is applied to all columns. * default: 0.001 * @group param */ @@ -68,7 +86,9 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * Param for how to handle invalid entries. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special - * additional bucket). + * additional bucket). Note that in the multiple columns case, the invalid handling is applied + * to all columns. That said for 'error' it will throw an error if any invalids are found in + * any column, for 'skip' it will skip rows with any invalids in any columns, etc. * Default: "error" * @group param */ @@ -86,6 +106,11 @@ private[feature] trait QuantileDiscretizerBase extends Params * categorical features. The number of bins can be set using the `numBuckets` parameter. It is * possible that the number of buckets used will be smaller than this value, for example, if there * are too few distinct values of the input to create enough distinct quantiles. + * Since 2.3.0, `QuantileDiscretizer` can map multiple columns at once by setting the `inputCols` + * parameter. If both of the `inputCol` and `inputCols` parameters are set, an Exception will be + * thrown. 
To specify the number of buckets for each column, the `numBucketsArray` parameter can + * be set, or if the number of buckets should be the same across columns, `numBuckets` can be + * set as a convenience. * * NaN handling: * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This @@ -104,7 +129,8 @@ private[feature] trait QuantileDiscretizerBase extends Params */ @Since("1.6.0") final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable + with HasInputCols with HasOutputCols { @Since("1.6.0") def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -129,34 +155,96 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.1.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("2.3.0") + def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + private[feature] def getInOutCols: (Array[String], Array[String]) = { + require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) || + (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)), + "QuantileDiscretizer only supports setting either inputCol/outputCol or" + + "inputCols/outputCols." + ) + + if (isSet(inputCol)) { + (Array($(inputCol)), Array($(outputCol))) + } else { + require($(inputCols).length == $(outputCols).length, + "inputCols number do not match outputCols") + ($(inputCols), $(outputCols)) + } + } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkNumericType(schema, $(inputCol)) - val inputFields = schema.fields - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val (inputColNames, outputColNames) = getInOutCols + val existingFields = schema.fields + var outputFields = existingFields + inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => + SchemaUtils.checkNumericType(schema, inputColName) + require(existingFields.forall(_.name != outputColName), + s"Output column ${outputColName} already exists.") + val attr = NominalAttribute.defaultAttr.withName(outputColName) + outputFields :+= attr.toStructField() + } StructType(outputFields) } @Since("2.0.0") override def fit(dataset: Dataset[_]): Bucketizer = { transformSchema(dataset.schema, logging = true) - val splits = dataset.stat.approxQuantile($(inputCol), - (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid)) + if (isSet(inputCols)) { + val splitsArray = if (isSet(numBucketsArray)) { + val probArrayPerCol = $(numBucketsArray).map { numOfBuckets => + (0.0 to 1.0 by 1.0 / numOfBuckets).toArray + } + + val probabilityArray = probArrayPerCol.flatten.sorted.distinct + val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols), + probabilityArray, $(relativeError)) + + 
splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) => + val probSet = probs.toSet + val idxSet = probabilityArray.zipWithIndex.collect { + case (p, idx) if probSet(p) => + idx + }.toSet + splits.zipWithIndex.collect { + case (s, idx) if idxSet(idx) => + s + } + } + } else { + dataset.stat.approxQuantile($(inputCols), + (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + } + bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits)) + } else { + val splits = dataset.stat.approxQuantile($(inputCol), + (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + bucketizer.setSplits(getDistinctSplits(splits)) + } + copyValues(bucketizer.setParent(this)) + } + + private def getDistinctSplits(splits: Array[Double]): Array[Double] = { splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity - val distinctSplits = splits.distinct if (splits.length != distinctSplits.length) { log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + s" buckets as a result.") } - val bucketizer = new Bucketizer(uid) - .setSplits(distinctSplits.sorted) - .setHandleInvalid($(handleInvalid)) - copyValues(bucketizer.setParent(this)) + distinctSplits.sorted } @Since("1.6.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index f219f775b2186..e9a75e931e6a8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ @@ -146,4 +147,268 @@ class QuantileDiscretizerSuite val model = discretizer.fit(df) assert(model.hasParent) } + + test("Multiple Columns: Test observed number of buckets and their sizes match expected values") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 100000 + val numBuckets = 5 + val data1 = Array.range(1, 100001, 1).map(_.toDouble) + val data2 = Array.range(1, 200000, 2).map(_.toDouble) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + + val relativeError = discretizer.getRelativeError + val isGoodBucket = udf { + (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + } + + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } + } + + test("Multiple Columns: Test on data with high proportion of duplicated values") { + val spark = this.spark + import spark.implicits._ + + val numBuckets = 5 + val expectedNumBucket = 3 + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 
1.0, 1.0, 3.0, 2.0, 3.0, 1.0, 2.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets == expectedNumBucket, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBucket but found ($observedNumBuckets") + } + } + + test("Multiple Columns: Test transform on data with NaN value") { + val spark = this.spark + import spark.implicits._ + + val numBuckets = 3 + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0) + + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(numBuckets) + + withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { + val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2") + intercept[SparkException] { + discretizer.fit(dataFrame).transform(dataFrame).collect() + } + } + + List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach { + case (u, v, w) => + discretizer.setHandleInvalid(u) + val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map { + case (((a, b), c), d) => (a, b, c, d) + }.toSeq.toDF("input1", "input2", "expected1", "expected2") + val result = discretizer.fit(dataFrame).transform(dataFrame) + result.select("result1", "expected1", "result2", "expected2").collect().foreach { + case Row(x: Double, y: Double, z: Double, w: Double) => + assert(x === y && w === z) + } + } + } + + test("Multiple Columns: Test numBucketsArray") { + val spark = this.spark + import spark.implicits._ + + val numBucketsArray: Array[Int] = Array(2, 5, 10) + val data1 = Array.range(1, 21, 1).map(_.toDouble) + val expected1 = Array (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) + val data2 = Array.range(1, 40, 2).map(_.toDouble) + val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, + 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) + val data3 = Array.range(1, 60, 3).map(_.toDouble) + val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, + 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) + val data = (0 until 20).map { idx => + (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx)) + } + val df = + data.toDF("input1", "input2", "input3", "expected1", "expected2", "expected3") + + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBucketsArray(numBucketsArray) + + discretizer.fit(df).transform(df). 
+ select("result1", "expected1", "result2", "expected2", "result3", "expected3") + .collect().foreach { + case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => + assert(r1 === e1, + s"The result value is not correct after bucketing. Expected $e1 but found $r1") + assert(r2 === e2, + s"The result value is not correct after bucketing. Expected $e2 but found $r2") + assert(r3 === e3, + s"The result value is not correct after bucketing. Expected $e3 but found $r3") + } + } + + test("Multiple Columns: Compare single/multiple column(s) QuantileDiscretizer in pipeline") { + val spark = this.spark + import spark.implicits._ + + val numBucketsArray: Array[Int] = Array(2, 5, 10) + val data1 = Array.range(1, 21, 1).map(_.toDouble) + val data2 = Array.range(1, 40, 2).map(_.toDouble) + val data3 = Array.range(1, 60, 3).map(_.toDouble) + val data = (0 until 20).map { idx => + (data1(idx), data2(idx), data3(idx)) + } + val df = + data.toDF("input1", "input2", "input3") + + val multiColsDiscretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBucketsArray(numBucketsArray) + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsDiscretizer)) + .fit(df) + + val discretizerForCol1 = new QuantileDiscretizer() + .setInputCol("input1") + .setOutputCol("result1") + .setNumBuckets(numBucketsArray(0)) + + val discretizerForCol2 = new QuantileDiscretizer() + .setInputCol("input2") + .setOutputCol("result2") + .setNumBuckets(numBucketsArray(1)) + + val discretizerForCol3 = new QuantileDiscretizer() + .setInputCol("input3") + .setOutputCol("result3") + .setNumBuckets(numBucketsArray(2)) + + val plForSingleCol = new Pipeline() + .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3)) + .fit(df) + + val resultForMultiCols = plForMultiCols.transform(df) + .select("result1", "result2", "result3") + .collect() + + val resultForSingleCol = plForSingleCol.transform(df) + .select("result1", "result2", "result3") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && + rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && + rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2)) + } + } + + test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " + + "explicitly with identical values") { + val spark = this.spark + import spark.implicits._ + + val data1 = Array.range(1, 21, 1).map(_.toDouble) + val data2 = Array.range(1, 40, 2).map(_.toDouble) + val data3 = Array.range(1, 60, 3).map(_.toDouble) + val data = (0 until 20).map { idx => + (data1(idx), data2(idx), data3(idx)) + } + val df = + data.toDF("input1", "input2", "input3") + + val discretizerSingleNumBuckets = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBuckets(10) + + val discretizerNumBucketsArray = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("result1", "result2", "result3")) + .setNumBucketsArray(Array(10, 10, 10)) + + val result1 = discretizerSingleNumBuckets.fit(df).transform(df) + .select("result1", "result2", "result3") + .collect() + val result2 = discretizerNumBucketsArray.fit(df).transform(df) + .select("result1", "result2", "result3") + .collect() + + result1.zip(result2).foreach { + 
case (row1, row2) => + assert(row1.getDouble(0) == row2.getDouble(0) && + row1.getDouble(1) == row2.getDouble(1) && + row1.getDouble(2) == row2.getDouble(2)) + } + } + + test("Multiple Columns: read/write") { + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBucketsArray(Array(5, 10)) + testDefaultReadWrite(discretizer) + } + + test("Multiple Columns: Both inputCol and inputCols are set") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(3) + .setInputCols(Array("input1", "input2")) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + // When both inputCol and inputCols are set, we throw Exception. + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + + test("Multiple Columns: Mismatched sizes of inputCols / outputCols") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(3) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } } From 028ee40165315337e2a349b19731764d64e4f51d Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Sun, 31 Dec 2017 14:51:38 +0200 Subject: [PATCH 547/712] [SPARK-22801][ML][PYSPARK] Allow FeatureHasher to treat numeric columns as categorical Previously, `FeatureHasher` always treats numeric type columns as numbers and never as categorical features. It is quite common to have categorical features represented as numbers or codes in data sources. In order to hash these features as categorical, users must first explicitly convert them to strings which is cumbersome. Add a new param `categoricalCols` which specifies the numeric columns that should be treated as categorical features. ## How was this patch tested? New unit tests. Author: Nick Pentreath Closes #19991 from MLnick/hasher-num-cat. --- .../spark/ml/feature/FeatureHasher.scala | 38 +++++++++++++++---- .../spark/ml/feature/FeatureHasherSuite.scala | 25 ++++++++++++ python/pyspark/ml/feature.py | 34 +++++++++++++---- 3 files changed, 83 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index 4615daed20fb1..a918dd4c075da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} @@ -40,9 +40,9 @@ import org.apache.spark.util.collection.OpenHashMap * The [[FeatureHasher]] transformer operates on multiple columns. 
Each column may contain either * numeric or categorical features. Behavior and handling of column data types is as follows: * -Numeric columns: For numeric features, the hash value of the column name is used to map the - * feature value to its index in the feature vector. Numeric features are never - * treated as categorical, even when they are integers. You must explicitly - * convert numeric columns containing categorical features to strings first. + * feature value to its index in the feature vector. By default, numeric features + * are not treated as categorical (even when they are integers). To treat them + * as categorical, specify the relevant columns in `categoricalCols`. * -String columns: For categorical features, the hash value of the string "column_name=value" * is used to map to the vector index, with an indicator value of `1.0`. * Thus, categorical features are "one-hot" encoded @@ -86,6 +86,17 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") def this() = this(Identifiable.randomUID("featureHasher")) + /** + * Numeric columns to treat as categorical features. By default only string and boolean + * columns are treated as categorical, so this param can be used to explicitly specify the + * numerical columns to treat as categorical. Note, the relevant columns must also be set in + * `inputCols`. + * @group param + */ + @Since("2.3.0") + val categoricalCols = new StringArrayParam(this, "categoricalCols", + "numeric columns to treat as categorical") + /** * Number of features. Should be greater than 0. * (default = 2^18^) @@ -117,15 +128,28 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group getParam */ + @Since("2.3.0") + def getCategoricalCols: Array[String] = $(categoricalCols) + + /** @group setParam */ + @Since("2.3.0") + def setCategoricalCols(value: Array[String]): this.type = set(categoricalCols, value) + @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { val hashFunc: Any => Int = OldHashingTF.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) + val catCols = if (isSet(categoricalCols)) { + $(categoricalCols).toSet + } else { + Set[String]() + } val outputSchema = transformSchema(dataset.schema) val realFields = outputSchema.fields.filter { f => - f.dataType.isInstanceOf[NumericType] + f.dataType.isInstanceOf[NumericType] && !catCols.contains(f.name) }.map(_.name).toSet def getDouble(x: Any): Double = { @@ -149,8 +173,8 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme val hash = hashFunc(colName) (hash, value) } else { - // string and boolean values are treated as categorical, with an indicator value of 1.0 - // and vector index based on hash of "column_name=value" + // string, boolean and numeric values that are in catCols are treated as categorical, + // with an indicator value of 1.0 and vector index based on hash of "column_name=value" val value = row.get(fieldIndex).toString val fieldName = s"$colName=$value" val hash = hashFunc(fieldName) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 407371a826666..3fc3cbb62d5b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -78,6 +78,31 @@ 
class FeatureHasherSuite extends SparkFunSuite assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) } + test("setting explicit numerical columns to treat as categorical") { + val df = Seq( + (2.0, 1, "foo"), + (3.0, 2, "bar") + ).toDF("real", "int", "string") + + val n = 100 + val hasher = new FeatureHasher() + .setInputCols("real", "int", "string") + .setCategoricalCols(Array("real")) + .setOutputCol("features") + .setNumFeatures(n) + val output = hasher.transform(df) + + val features = output.select("features").as[Vector].collect() + // Assume perfect hash on field names + def idx: Any => Int = murmur3FeatureIdx(n) + // check expected indices + val expected = Seq( + Vectors.sparse(n, Seq((idx("real=2.0"), 1.0), (idx("int"), 1.0), (idx("string=foo"), 1.0))), + Vectors.sparse(n, Seq((idx("real=3.0"), 1.0), (idx("int"), 2.0), (idx("string=bar"), 1.0))) + ) + assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) + } + test("hashing works for all numeric types") { val df = Seq(5.0, 10.0, 15.0).toDF("real") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5094324e5c1fe..13bf95cce40be 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -714,9 +714,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, * Numeric columns: For numeric features, the hash value of the column name is used to map the - feature value to its index in the feature vector. Numeric features are never - treated as categorical, even when they are integers. You must explicitly - convert numeric columns containing categorical features to strings first. + feature value to its index in the feature vector. By default, numeric features + are not treated as categorical (even when they are integers). To treat them + as categorical, specify the relevant columns in `categoricalCols`. * String columns: For categorical features, the hash value of the string "column_name=value" @@ -741,6 +741,8 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + >>> hasher.setCategoricalCols(["real"]).transform(df).head().features + SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) @@ -752,10 +754,14 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, .. 
versionadded:: 2.3.0 """ + categoricalCols = Param(Params._dummy(), "categoricalCols", + "numeric columns to treat as categorical", + typeConverter=TypeConverters.toListString) + @keyword_only - def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None): """ - __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None) """ super(FeatureHasher, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid) @@ -765,14 +771,28 @@ def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None): @keyword_only @since("2.3.0") - def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None): """ - setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None) Sets params for this FeatureHasher. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.3.0") + def setCategoricalCols(self, value): + """ + Sets the value of :py:attr:`categoricalCols`. + """ + return self._set(categoricalCols=value) + + @since("2.3.0") + def getCategoricalCols(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.categoricalCols) + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, From 5955a2d0fb8f84560925d72afe36055089f44a05 Mon Sep 17 00:00:00 2001 From: Jirka Kremser Date: Sun, 31 Dec 2017 15:38:10 -0600 Subject: [PATCH 548/712] [MINOR][DOCS] s/It take/It takes/g ## What changes were proposed in this pull request? Fixing three small typos in the docs, in particular: It take a `RDD` -> It takes an `RDD` (twice) It take an `JavaRDD` -> It takes a `JavaRDD` I didn't create any Jira issue for this minor thing, I hope it's ok. ## How was this patch tested? visually by clicking on 'preview' Author: Jirka Kremser Closes #20108 from Jiri-Kremser/docs-typo. --- docs/mllib-frequent-pattern-mining.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index c9cd7cc85e754..0d3192c6b1d9c 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -41,7 +41,7 @@ We refer users to the papers for more details. [`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. -It take a `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +It takes an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. The following @@ -60,7 +60,7 @@ Refer to the [`FPGrowth` Scala docs](api/scala/index.html#org.apache.spark.mllib [`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the FP-growth algorithm. -It take an `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. 
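(Editorial aside, not part of this patch: a minimal FP-growth invocation is sketched below, in
Scala for brevity; the `sc` SparkContext, the sample transactions and the 0.2 minimum support
are illustrative assumptions rather than values taken from this commit.)

{% highlight scala %}
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD

// Hypothetical transactions: each one is an Array of item identifiers.
val transactions: RDD[Array[String]] = sc.parallelize(Seq(
  Array("r", "z", "h", "k", "p"),
  Array("z", "y", "x", "w", "v", "u", "t", "s"),
  Array("s", "x", "o", "n", "r")))

val fpg = new FPGrowth()
  .setMinSupport(0.2)      // illustrative threshold
  .setNumPartitions(10)
val model = fpg.run(transactions)

// Each frequent itemset is reported with its absolute frequency.
model.freqItemsets.collect().foreach { itemset =>
  println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
}
{% endhighlight %}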
Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) that stores the frequent itemsets with their frequencies. The following @@ -79,7 +79,7 @@ Refer to the [`FPGrowth` Java docs](api/java/org/apache/spark/mllib/fpm/FPGrowth [`FPGrowth`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. -It take an `RDD` of transactions, where each transaction is an `List` of items of a generic type. +It takes an `RDD` of transactions, where each transaction is an `List` of items of a generic type. Calling `FPGrowth.train` with transactions returns an [`FPGrowthModel`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. From 994065d891a23ed89a09b3f95bc3f1f986793e0d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 31 Dec 2017 15:28:59 -0800 Subject: [PATCH 549/712] [SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator ## What changes were proposed in this pull request? This patch adds a new class `OneHotEncoderEstimator` which extends `Estimator`. The `fit` method returns `OneHotEncoderModel`. Common methods between existing `OneHotEncoder` and new `OneHotEncoderEstimator`, such as transforming schema, are extracted and put into `OneHotEncoderCommon` to reduce code duplication. ### Multi-column support `OneHotEncoderEstimator` adds simpler multi-column support because it is new API and can be free from backward compatibility. ### handleInvalid Param support `OneHotEncoderEstimator` supports `handleInvalid` Param. It supports `error` and `keep`. ## How was this patch tested? Added new test suite `OneHotEncoderEstimatorSuite`. Author: Liang-Chi Hsieh Closes #19527 from viirya/SPARK-13030. --- .../spark/ml/feature/OneHotEncoder.scala | 83 +-- .../ml/feature/OneHotEncoderEstimator.scala | 522 ++++++++++++++++++ .../feature/OneHotEncoderEstimatorSuite.scala | 421 ++++++++++++++ 3 files changed, 960 insertions(+), 66 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index a669da183e2c8..5ab6c2dde667a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The output vectors are sparse. * * @see `StringIndexer` for converting categorical values into category indices + * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder` + * will be removed in 3.0.0. 
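+ *             (Editorial note: `OneHotEncoderEstimator`, added by this patch, is the intended
+ *             multi-column replacement; a usage sketch appears in its class documentation below.)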
*/ @Since("1.4.0") +@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" + + " will be removed in 3.0.0.", "2.3.0") class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { @@ -78,56 +82,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val outputColName = $(outputCol) + val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - val inputFields = schema.fields require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") - val inputAttr = Attribute.fromStructField(schema(inputColName)) - val outputAttrNames: Option[Array[String]] = inputAttr match { - case nominal: NominalAttribute => - if (nominal.values.isDefined) { - nominal.values - } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(_.toString)) - } else { - None - } - case binary: BinaryAttribute => - if (binary.values.isDefined) { - binary.values - } else { - Some(Array.tabulate(2)(_.toString)) - } - case _: NumericAttribute => - throw new RuntimeException( - s"The input column $inputColName cannot be numeric.") - case _ => - None // optimistic about unknown attributes - } - - val filteredOutputAttrNames = outputAttrNames.map { names => - if ($(dropLast)) { - require(names.length > 1, - s"The input column $inputColName should have at least two distinct values.") - names.dropRight(1) - } else { - names - } - } - - val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { - val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => - BinaryAttribute.defaultAttr.withName(name) - } - new AttributeGroup($(outputCol), attrs) - } else { - new AttributeGroup($(outputCol)) - } - - val outputFields = inputFields :+ outputAttrGroup.toStructField() + val outputField = OneHotEncoderCommon.transformOutputColumnSchema( + schema(inputColName), outputColName, $(dropLast)) + val outputFields = inputFields :+ outputField StructType(outputFields) } @@ -136,30 +100,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) - val shouldDropLast = $(dropLast) - var outputAttrGroup = AttributeGroup.fromStructField( + + val outputAttrGroupFromSchema = AttributeGroup.fromStructField( transformSchema(dataset.schema)(outputColName)) - if (outputAttrGroup.size < 0) { - // If the number of attributes is unknown, we check the values from the input column. 
- val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) - .treeAggregate(0.0)( - (m, x) => { - assert(x <= Int.MaxValue, - s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") - assert(x >= 0.0 && x == x.toInt, - s"Values from column $inputColName must be indices, but got $x.") - math.max(m, x) - }, - (m0, m1) => { - math.max(m0, m1) - } - ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(_.toString) - val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames - val outputAttrs: Array[Attribute] = - filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) - outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + + val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0) + } else { + outputAttrGroupFromSchema } + val metadata = outputAttrGroup.toMetadata() // data transformation diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala new file mode 100644 index 0000000000000..074622d41e28d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -0,0 +1,522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} + +/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ +private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid + with HasInputCols with HasOutputCols { + + /** + * Param for how to handle invalid data. + * Options are 'keep' (invalid data presented as an extra categorical feature) or + * 'error' (throw an error). 
+ * Default: "error" + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data " + + "Options are 'keep' (invalid data presented as an extra categorical feature) " + + "or error (throw an error).", + ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) + + setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) + + /** + * Whether to drop the last category in the encoded vector (default: true) + * @group param + */ + @Since("2.3.0") + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) + + /** @group getParam */ + @Since("2.3.0") + def getDropLast: Boolean = $(dropLast) + + protected def validateAndTransformSchema( + schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + val existingFields = schema.fields + + require(inputColNames.length == outputColNames.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"output columns ${outputColNames.length}.") + + // Input columns must be NumericType. + inputColNames.foreach(SchemaUtils.checkNumericType(schema, _)) + + // Prepares output columns with proper attributes by examining input columns. + val inputFields = $(inputCols).map(schema(_)) + + val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => + OneHotEncoderCommon.transformOutputColumnSchema( + inputField, outputColName, dropLast, keepInvalid) + } + outputFields.foldLeft(schema) { case (newSchema, outputField) => + SchemaUtils.appendColumn(newSchema, outputField) + } + } +} + +/** + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via `dropLast`), + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + * vector. + * + * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols + * come in pairs, specified by the order in the arrays, and each pair is treated independently. 
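+ *
+ * A minimal usage sketch (editorial illustration; the DataFrame `df` and the column names are
+ * assumptions, not part of this patch):
+ * {{{
+ *   val encoder = new OneHotEncoderEstimator()
+ *     .setInputCols(Array("categoryIndex1", "categoryIndex2"))
+ *     .setOutputCols(Array("categoryVec1", "categoryVec2"))
+ *     .setHandleInvalid("keep")
+ *   val model = encoder.fit(df)
+ *   val encoded = model.transform(df)
+ * }}}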
+ * + * @see `StringIndexer` for converting categorical values into category indices + */ +@Since("2.3.0") +class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String) + extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable { + + @Since("2.3.0") + def this() = this(Identifiable.randomUID("oneHotEncoder")) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) + } + + @Since("2.3.0") + override def fit(dataset: Dataset[_]): OneHotEncoderModel = { + transformSchema(dataset.schema) + + // Compute the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false, + keepInvalid = false) + val categorySizes = new Array[Int]($(outputCols).length) + + val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => + val numOfAttrs = AttributeGroup.fromStructField( + transformedSchema(outputColName)).size + if (numOfAttrs < 0) { + Some(idx) + } else { + categorySizes(idx) = numOfAttrs + None + } + } + + // Some input columns don't have attributes or their attributes don't have necessary info. + // We need to scan the data to get the number of values for each column. + if (columnToScanIndices.length > 0) { + val inputColNames = columnToScanIndices.map($(inputCols)(_)) + val outputColNames = columnToScanIndices.map($(outputCols)(_)) + + // When fitting data, we want the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, inputColNames, outputColNames, dropLast = false) + attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => + categorySizes(idx) = attrGroup.size + } + } + + val model = new OneHotEncoderModel(uid, categorySizes).setParent(this) + copyValues(model) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra) +} + +@Since("2.3.0") +object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] { + + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID) + + @Since("2.3.0") + override def load(path: String): OneHotEncoderEstimator = super.load(path) +} + +@Since("2.3.0") +class OneHotEncoderModel private[ml] ( + @Since("2.3.0") override val uid: String, + @Since("2.3.0") val categorySizes: Array[Int]) + extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable { + + import OneHotEncoderModel._ + + // Returns the category size for a given index with `dropLast` and `handleInvalid` + // taken into account. 
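+  // (Editorial summary) For an original size N this evaluates to N + 1 when `handleInvalid`
+  // is "keep" and `dropLast` is false, N - 1 when `dropLast` is true and `handleInvalid` is
+  // not "keep", and N when both or neither option is in effect.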
+ private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + val dropLast = getDropLast + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + + if (!dropLast && keepInvalid) { + // When `handleInvalid` is "keep", an extra category is added as last category + // for invalid data. + orgCategorySize + 1 + } else if (dropLast && !keepInvalid) { + // When `dropLast` is true, the last category is removed. + orgCategorySize - 1 + } else { + // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid + // data is removed. Thus, it is the same as the plain number of categories. + orgCategorySize + } + } + + private def encoder: UserDefinedFunction = { + val oneValue = Array(1.0) + val emptyValues = Array.empty[Double] + val emptyIndices = Array.empty[Int] + val dropLast = getDropLast + val handleInvalid = getHandleInvalid + val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + + // The udf performed on input data. The first parameter is the input value. The second + // parameter is the index of input. + udf { (label: Double, idx: Int) => + val plainNumCategories = categorySizes(idx) + val size = configedCategorySize(plainNumCategories, idx) + + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative.") + } else if (label == size && dropLast && !keepInvalid) { + // When `dropLast` is true and `handleInvalid` is not "keep", + // the last category is removed. + Vectors.sparse(size, emptyIndices, emptyValues) + } else if (label >= plainNumCategories && keepInvalid) { + // When `handleInvalid` is "keep", encodes invalid data to last category (and removed + // if `dropLast` is true) + if (dropLast) { + Vectors.sparse(size, emptyIndices, emptyValues) + } else { + Vectors.sparse(size, Array(size - 1), oneValue) + } + } else if (label < plainNumCategories) { + Vectors.sparse(size, Array(label.toInt), oneValue) + } else { + assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } + } + } + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + + require(inputColNames.length == categorySizes.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"features ${categorySizes.length} during fitting.") + + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) + verifyNumOfValues(transformedSchema) + } + + /** + * If the metadata of input columns also specifies the number of categories, we need to + * compare with expected category number with `handleInvalid` and `dropLast` taken into + * account. Mismatched numbers will cause exception. 
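+   * For example (editorial note), metadata that implies a different number of categories than
+   * the fitted model expects is rejected here rather than silently producing misaligned vectors.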
+ */ + private def verifyNumOfValues(schema: StructType): StructType = { + $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => + val inputColName = $(inputCols)(idx) + val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) + + // If the input metadata specifies number of category for output column, + // comparing with expected category number with `handleInvalid` and + // `dropLast` taken into account. + if (attrGroup.attributes.nonEmpty) { + val numCategories = configedCategorySize(categorySizes(idx), idx) + require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + + s"$numCategories categorical values for input column ${inputColName}, " + + s"but the input column had metadata specifying ${attrGroup.size} values.") + } + } + schema + } + + @Since("2.3.0") + override def transform(dataset: Dataset[_]): DataFrame = { + val transformedSchema = transformSchema(dataset.schema, logging = true) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + + val encodedColumns = (0 until $(inputCols).length).map { idx => + val inputColName = $(inputCols)(idx) + val outputColName = $(outputCols)(idx) + + val outputAttrGroupFromSchema = + AttributeGroup.fromStructField(transformedSchema(outputColName)) + + val metadata = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, + categorySizes(idx), $(dropLast), keepInvalid).toMetadata() + } else { + outputAttrGroupFromSchema.toMetadata() + } + + encoder(col(inputColName).cast(DoubleType), lit(idx)) + .as(outputColName, metadata) + } + dataset.withColumns($(outputCols), encodedColumns) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderModel = { + val copied = new OneHotEncoderModel(uid, categorySizes) + copyValues(copied, extra).setParent(parent) + } + + @Since("2.3.0") + override def write: MLWriter = new OneHotEncoderModelWriter(this) +} + +@Since("2.3.0") +object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { + + private[OneHotEncoderModel] + class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter { + + private case class Data(categorySizes: Array[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.categorySizes) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] { + + private val className = classOf[OneHotEncoderModel].getName + + override def load(path: String): OneHotEncoderModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath) + .select("categorySizes") + .head() + val categorySizes = data.getAs[Seq[Int]](0).toArray + val model = new OneHotEncoderModel(metadata.uid, categorySizes) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("2.3.0") + override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader + + @Since("2.3.0") + override def load(path: String): OneHotEncoderModel = super.load(path) +} + +/** + * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`. 
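+ * (Editorial note: the shared pieces are `transformOutputColumnSchema` for building output
+ * column metadata, `getOutputAttrGroupFromData` for scanning the data when metadata is missing,
+ * and `createAttrGroupForAttrNames` for constructing the attribute group.)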
+ */ +private[feature] object OneHotEncoderCommon { + + private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = { + val inputAttr = Attribute.fromStructField(inputCol) + inputAttr match { + case nominal: NominalAttribute => + if (nominal.values.isDefined) { + nominal.values + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values + } else { + Some(Array.tabulate(2)(_.toString)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column ${inputCol.name} cannot be continuous-value.") + case _ => + None // optimistic about unknown attributes + } + } + + /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */ + private def genOutputAttrGroup( + outputAttrNames: Option[Array[String]], + outputColName: String): AttributeGroup = { + outputAttrNames.map { attrNames => + val attrs: Array[Attribute] = attrNames.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup(outputColName, attrs) + }.getOrElse{ + new AttributeGroup(outputColName) + } + } + + /** + * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column. + */ + def transformOutputColumnSchema( + inputCol: StructField, + outputColName: String, + dropLast: Boolean, + keepInvalid: Boolean = false): StructField = { + val outputAttrNames = genOutputAttrNames(inputCol) + val filteredOutputAttrNames = outputAttrNames.map { names => + if (dropLast && !keepInvalid) { + require(names.length > 1, + s"The input column ${inputCol.name} should have at least two distinct values.") + names.dropRight(1) + } else if (!dropLast && keepInvalid) { + names ++ Seq("invalidValues") + } else { + names + } + } + + genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField() + } + + /** + * This method is called when we want to generate `AttributeGroup` from actual data for + * one-hot encoder. + */ + def getOutputAttrGroupFromData( + dataset: Dataset[_], + inputColNames: Seq[String], + outputColNames: Seq[String], + dropLast: Boolean): Seq[AttributeGroup] = { + // The RDD approach has advantage of early-stop if any values are invalid. It seems that + // DataFrame ops don't have equivalent functions. + val columns = inputColNames.map { inputColName => + col(inputColName).cast(DoubleType) + } + val numOfColumns = columns.length + + val numAttrsArray = dataset.select(columns: _*).rdd.map { row => + (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray + }.treeAggregate(new Array[Double](numOfColumns))( + (maxValues, curValues) => { + (0 until numOfColumns).foreach { idx => + val x = curValues(idx) + assert(x <= Int.MaxValue, + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.") + assert(x >= 0.0 && x == x.toInt, + s"Values from column ${inputColNames(idx)} must be indices, but got $x.") + maxValues(idx) = math.max(maxValues(idx), x) + } + maxValues + }, + (m0, m1) => { + (0 until numOfColumns).foreach { idx => + m0(idx) = math.max(m0(idx), m1(idx)) + } + m0 + } + ).map(_.toInt + 1) + + outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => + createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false) + } + } + + /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. 
*/ + def createAttrGroupForAttrNames( + outputColName: String, + numAttrs: Int, + dropLast: Boolean, + keepInvalid: Boolean): AttributeGroup = { + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) + val filtered = if (dropLast && !keepInvalid) { + outputAttrNames.dropRight(1) + } else if (!dropLast && keepInvalid) { + outputAttrNames ++ Seq("invalidValues") + } else { + outputAttrNames + } + genOutputAttrGroup(Some(filtered), outputColName) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala new file mode 100644 index 0000000000000..1d3f845586426 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +class OneHotEncoderEstimatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("params") { + ParamsSuite.checkParams(new OneHotEncoderEstimator) + } + + test("OneHotEncoderEstimator dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderEstimator dropLast = true") { + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, 
Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } + + test("read/write") { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + testDefaultReadWrite(encoder) + } + + test("OneHotEncoderModel read/write") { + val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3)) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.categorySizes === instance.categorySizes) + } + + test("OneHotEncoderEstimator with varying types") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val dfWithTypes = df + .withColumn("shortInput", df("input").cast(ShortType)) + .withColumn("longInput", df("input").cast(LongType)) + .withColumn("intInput", df("input").cast(IntegerType)) + .withColumn("floatInput", df("input").cast(FloatType)) + .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) + + val cols = Array("input", "shortInput", "longInput", "intInput", + "floatInput", "decimalInput") + for (col <- cols) { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array(col)) + .setOutputCols(Array("output")) + .setDropLast(false) + + val model = encoder.fit(dfWithTypes) + val encoded = 
model.transform(dfWithTypes) + + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + } + + test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) + }.collect().foreach { case (vec1, vec2, vec3, vec4) => + assert(vec1 === vec2) + assert(vec3 === vec4) + } + } + + test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") { + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())), + Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) + }.collect().foreach { case (vec1, vec2, vec3, vec4) => + assert(vec1 === vec2) + assert(vec3 === vec4) + } + } + + test("Throw error on invalid values") { + val trainingData = Seq((0, 0), (1, 1), (2, 2)) + val trainingDF = trainingData.toDF("id", "a") + val testData = Seq((0, 0), (1, 2), (1, 3)) + val testDF = testData.toDF("id", "a") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + + val model = encoder.fit(trainingDF) + val err = intercept[SparkException] { + model.transform(testDF).show + } + err.getMessage.contains("Unseen value: 3.0. 
To handle unseen values") + } + + test("Can't transform on negative input") { + val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b") + val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + + val model = encoder.fit(trainingDF) + val err = intercept[SparkException] { + model.transform(testDF).collect() + } + err.getMessage.contains("Negative value: -1.0. Input can't be negative") + } + + test("Keep on invalid values: dropLast = false") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(false) + + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("Keep on invalid values: dropLast = true") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(3, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(true) + + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderModel changes dropLast") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected1", new VectorUDT), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + + model.setDropLast(false) + val encoded1 = model.transform(df) + encoded1.select("output", "expected1").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + + model.setDropLast(true) + val encoded2 = model.transform(df) + encoded2.select("output", "expected2").rdd.map { r => + 
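+      // Editorial note: after flipping the same model to dropLast = true, the output is
+      // compared against the shorter "expected2" vectors prepared above.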
(r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderModel changes handleInvalid") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(trainingDF) + model.setHandleInvalid("error") + + val err = intercept[SparkException] { + model.transform(testDF).collect() + } + err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + + model.setHandleInvalid("keep") + model.transform(testDF).collect() + } + + test("Transforming on mismatched attributes") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + + val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") + val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", testAttr.toMetadata())) + val err = intercept[Exception] { + model.transform(testDF).collect() + } + err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + } +} From f5b7714e0eed4530519df97eb4faca8b05dde161 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 Jan 2018 08:47:12 -0600 Subject: [PATCH 550/712] [BUILD] Close stale PRs Closes #18916 Closes #19520 Closes #19613 Closes #19739 Closes #19936 Closes #19919 Closes #19933 Closes #19917 Closes #20027 Closes #19035 Closes #20044 Closes #20104 Author: Sean Owen Closes #20130 from srowen/StalePRs. From 7a702d8d5ed830de5d2237f136b08bd18deae037 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 2 Jan 2018 07:00:31 +0900 Subject: [PATCH 551/712] [SPARK-21616][SPARKR][DOCS] update R migration guide and vignettes ## What changes were proposed in this pull request? update R migration guide and vignettes ## How was this patch tested? manually Author: Felix Cheung Closes #20106 from felixcheung/rreleasenote23. --- R/pkg/tests/fulltests/test_Windows.R | 1 + R/pkg/vignettes/sparkr-vignettes.Rmd | 3 +-- docs/sparkr.md | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/R/pkg/tests/fulltests/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R index b2ec6c67311db..209827d9fdc2f 100644 --- a/R/pkg/tests/fulltests/test_Windows.R +++ b/R/pkg/tests/fulltests/test_Windows.R @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 8c4ea2f2db188..2e662424b25f2 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -391,8 +391,7 @@ We convert `mpg` to `kmpg` (kilometers per gallon). 
`carsSubDF` is a `SparkDataF ```{r} carsSubDF <- select(carsDF, "model", "mpg") -schema <- structType(structField("model", "string"), structField("mpg", "double"), - structField("kmpg", "double")) +schema <- "model STRING, mpg DOUBLE, kmpg DOUBLE" out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` diff --git a/docs/sparkr.md b/docs/sparkr.md index a3254e7654134..997ea60fb6cf0 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -657,3 +657,9 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`. - `spark.lda` was not setting the optimizer correctly. It has been corrected. - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`. + +## Upgrading to SparkR 2.3.0 + + - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected. + - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. + - A warning can be raised if versions of SparkR package and the Spark JVM do not match. From c284c4e1f6f684ca8db1cc446fdcc43b46e3413c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 31 Dec 2017 17:00:41 -0600 Subject: [PATCH 552/712] [MINOR] Fix a bunch of typos --- bin/find-spark-home | 2 +- .../java/org/apache/spark/util/kvstore/LevelDBIterator.java | 2 +- .../org/apache/spark/network/protocol/MessageWithHeader.java | 4 ++-- .../java/org/apache/spark/network/sasl/SaslEncryption.java | 4 ++-- .../org/apache/spark/network/util/TransportFrameDecoder.java | 2 +- .../network/shuffle/ExternalShuffleBlockResolverSuite.java | 2 +- .../main/java/org/apache/spark/util/sketch/BloomFilter.java | 2 +- .../java/org/apache/spark/unsafe/array/ByteArrayMethods.java | 2 +- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- core/src/main/scala/org/apache/spark/status/storeTypes.scala | 2 +- .../test/scala/org/apache/spark/util/FileAppenderSuite.scala | 2 +- dev/github_jira_sync.py | 2 +- dev/lint-python | 2 +- examples/src/main/python/ml/linearsvc.py | 2 +- .../scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala | 2 +- .../scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala | 2 +- .../apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java | 2 +- .../org/apache/spark/streaming/kinesis/KinesisUtils.scala | 4 ++-- .../java/org/apache/spark/launcher/ChildProcAppHandle.java | 2 +- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 2 +- python/pyspark/ml/image.py | 2 +- .../org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 2 +- .../spark/sql/catalyst/expressions/UnsafeArrayData.java | 2 +- .../scala/org/apache/spark/sql/catalyst/analysis/view.scala | 2 +- .../spark/sql/catalyst/expressions/objects/objects.scala | 5 +++-- .../expressions/aggregate/CountMinSketchAggSuite.scala | 2 +- .../sql/sources/v2/streaming/MicroBatchWriteSupport.java | 2 +- .../apache/spark/sql/execution/ui/static/spark-sql-viz.css | 2 +- .../apache/spark/sql/execution/datasources/FileFormat.scala | 2 +- .../spark/sql/execution/datasources/csv/CSVInferSchema.scala | 2 +- .../apache/spark/sql/execution/joins/HashedRelation.scala | 2 +- 
.../streaming/StreamingSymmetricHashJoinHelper.scala | 2 +- .../apache/spark/sql/execution/ui/SQLAppStatusListener.scala | 2 +- .../scala/org/apache/spark/sql/expressions/Aggregator.scala | 2 +- .../main/scala/org/apache/spark/sql/streaming/progress.scala | 2 +- .../src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java | 2 +- .../inputs/typeCoercion/native/implicitTypeCasts.sql | 2 +- .../execution/streaming/CompactibleFileStreamLogSuite.scala | 2 +- .../org/apache/spark/sql/sources/fakeExternalSources.scala | 2 +- .../org/apache/spark/sql/streaming/FileStreamSinkSuite.scala | 2 +- .../test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../src/main/scala/org/apache/spark/sql/hive/HiveShim.scala | 2 +- sql/hive/src/test/resources/data/conf/hive-log4j.properties | 2 +- .../org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 2 +- .../scala/org/apache/spark/streaming/util/StateMap.scala | 2 +- 45 files changed, 50 insertions(+), 49 deletions(-) diff --git a/bin/find-spark-home b/bin/find-spark-home index fa78407d4175a..617dbaa4fff86 100755 --- a/bin/find-spark-home +++ b/bin/find-spark-home @@ -21,7 +21,7 @@ FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" -# Short cirtuit if the user already has this set. +# Short circuit if the user already has this set. if [ ! -z "${SPARK_HOME}" ]; then exit 0 elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index b3ba76ba58052..f62e85d435318 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -86,7 +86,7 @@ class LevelDBIterator implements KVStoreIterator { end = index.start(parent, params.last); } if (it.hasNext()) { - // When descending, the caller may have set up the start of iteration at a non-existant + // When descending, the caller may have set up the start of iteration at a non-existent // entry that is guaranteed to be after the desired entry. For example, if you have a // compound key (a, b) where b is a, integer, you may seek to the end of the elements that // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index 897d0f9e4fb89..a5337656cbd84 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -47,7 +47,7 @@ class MessageWithHeader extends AbstractFileRegion { /** * When the write buffer size is larger than this limit, I/O will be done in chunks of this size. * The size should not be too large as it will waste underlying memory copy. e.g. If network - * avaliable buffer is smaller than this limit, the data cannot be sent within one single write + * available buffer is smaller than this limit, the data cannot be sent within one single write * operation while it still will make memory copy with this size. */ private static final int NIO_BUFFER_LIMIT = 256 * 1024; @@ -100,7 +100,7 @@ public long transferred() { * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. 
* * The contract is that the caller will ensure position is properly set to the total number - * of bytes transferred so far (i.e. value returned by transfered()). + * of bytes transferred so far (i.e. value returned by transferred()). */ @Override public long transferTo(final WritableByteChannel target, final long position) throws IOException { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 16ab4efcd4f5f..3ac9081d78a75 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -38,7 +38,7 @@ import org.apache.spark.network.util.NettyUtils; /** - * Provides SASL-based encription for transport channels. The single method exposed by this + * Provides SASL-based encryption for transport channels. The single method exposed by this * class installs the needed channel handlers on a connected channel. */ class SaslEncryption { @@ -166,7 +166,7 @@ static class EncryptedMessage extends AbstractFileRegion { * This makes assumptions about how netty treats FileRegion instances, because there's no way * to know beforehand what will be the size of the encrypted message. Namely, it assumes * that netty will try to transfer data from this message while - * transfered() < count(). So these two methods return, technically, wrong data, + * transferred() < count(). So these two methods return, technically, wrong data, * but netty doesn't know better. */ @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 50d9651ccbbb2..8e73ab077a5c1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -29,7 +29,7 @@ /** * A customized frame decoder that allows intercepting raw data. *

    - * This behaves like Netty's frame decoder (with harcoded parameters that match this library's + * This behaves like Netty's frame decoder (with hard coded parameters that match this library's * needs), except it allows an interceptor to be installed to read data directly before it's * framed. *

    diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 23438a08fa094..6d201b8fe8d7d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -127,7 +127,7 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); assertEquals(parsedShuffleInfo, shuffleInfo); - // Intentionally keep these hard-coded strings in here, to check backwards-compatability. + // Intentionally keep these hard-coded strings in here, to check backwards-compatibility. // its not legacy yet, but keeping this here in case anybody changes it String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index c0b425e729595..37803c7a3b104 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -34,7 +34,7 @@ *

 *   <li>{@link String}</li>
  • * * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that - * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that hasu + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has * not actually been put in the {@code BloomFilter}. * * The implementation is largely based on the {@code BloomFilter} class from Guava. diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index f121b1cd745b8..a6b1f7a16d605 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -66,7 +66,7 @@ public static boolean arrayEquals( i += 1; } } - // for architectures that suport unaligned accesses, chew it up 8 bytes at a time + // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (unaligned || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) { while (i <= length - 8) { if (Platform.getLong(leftBase, leftOffset + i) != diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1dbb3789b0c9d..31f3cb9dfa0ae 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2276,7 +2276,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Clean a closure to make it ready to serialized and send to tasks + * Clean a closure to make it ready to be serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) * If checkSerializable is set, clean will also proactively * check to see if f is serializable and throw a SparkException diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index d9ead0071d3bf..1cfd30df49091 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -61,7 +61,7 @@ private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { /** * Keep track of the existing stages when the job was submitted, and those that were - * completed during the job's execution. This allows a more accurate acounting of how + * completed during the job's execution. This allows a more accurate accounting of how * many tasks were skipped for the job. 
*/ private[spark] class JobDataWrapper( diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index cd0ed5b036bf9..52cd5378bc715 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -356,7 +356,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { generatedFiles } - /** Delete all the generated rolledover files */ + /** Delete all the generated rolled over files */ def cleanup() { testFile.getParentFile.listFiles.filter { file => file.getName.startsWith(testFile.getName) diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py index acc9aeabbb9fb..7df7b325cabe0 100755 --- a/dev/github_jira_sync.py +++ b/dev/github_jira_sync.py @@ -43,7 +43,7 @@ # "notification overload" when running for the first time. MIN_COMMENT_PR = int(os.environ.get("MIN_COMMENT_PR", "1496")) -# File used as an opitimization to store maximum previously seen PR +# File used as an optimization to store maximum previously seen PR # Used mostly because accessing ASF JIRA is slow, so we want to avoid checking # the state of JIRA's that are tied to PR's we've already looked at. MAX_FILE = ".github-jira-max" diff --git a/dev/lint-python b/dev/lint-python index 07e2606d45143..df8df037a5f69 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -# Exclude auto-geneated configuration file. +# Exclude auto-generated configuration file. PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" diff --git a/examples/src/main/python/ml/linearsvc.py b/examples/src/main/python/ml/linearsvc.py index 18cbf87a10695..9b79abbf96f88 100644 --- a/examples/src/main/python/ml/linearsvc.py +++ b/examples/src/main/python/ml/linearsvc.py @@ -37,7 +37,7 @@ # Fit the model lsvcModel = lsvc.fit(training) - # Print the coefficients and intercept for linearsSVC + # Print the coefficients and intercept for linear SVC print("Coefficients: " + str(lsvcModel.coefficients)) print("Intercept: " + str(lsvcModel.intercept)) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 9d9e2aaba8079..66b3409c0cd04 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition( * An RDD that reads data from Kafka based on offset ranges across multiple partitions. * Additionally, it allows preferred locations to be set for each topic + partition, so that * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition - * and cached KafkaConsuemrs (see [[CachedKafkaConsumer]] can be used read data efficiently. + * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently. 
* * @param sc the [[SparkContext]] * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 08db4d827e400..75245943c4936 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -201,7 +201,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L verifyTopicDeletionWithRetries(zkUtils, topic, partitions, List(this.server)) } - /** Add new paritions to a Kafka topic */ + /** Add new partitions to a Kafka topic */ def addPartitions(topic: String, partitions: Int): Unit = { AdminUtils.addPartitions(zkUtils, topic, partitions) // wait until metadata is propagated diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java index 87bfe1514e338..b20fad2291262 100644 --- a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java @@ -111,7 +111,7 @@ public String call(ConsumerRecord r) { LocationStrategies.PreferConsistent() ).map(handler); - // just making sure the java user apis work; the scala tests handle logic corner cases + // just making sure the java user APIs work; the scala tests handle logic corner cases long count1 = rdd1.count(); long count2 = rdd2.count(); Assert.assertTrue(count1 > 0); diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2500460bd330b..c60b9896a3473 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -164,7 +164,7 @@ object KinesisUtils { * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from * Kinesis stream. - * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume * the same role. * @param stsExternalId External ID that can be used to validate against the assumed IAM role's * trust policy. @@ -434,7 +434,7 @@ object KinesisUtils { * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from * Kinesis stream. - * @param stsSessionName Name to uniquely identify STS sessions if multiple princples assume + * @param stsSessionName Name to uniquely identify STS sessions if multiple princpals assume * the same role. * @param stsExternalId External ID that can be used to validate against the assumed IAM role's * trust policy. 
diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 3bb7e12385fd8..8b3f427b7750e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -71,7 +71,7 @@ void setChildProc(Process childProc, String loggerName, InputStream logStream) { } /** - * Wait for the child process to exit and update the handle's state if necessary, accoding to + * Wait for the child process to exit and update the handle's state if necessary, according to * the exit code. */ void monitorChild() { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 0130b3e255f0d..095b54c0fe83f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -94,7 +94,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) def setSeed(value: Long): this.type = set(seed, value) /** - * Set the mamixum level of parallelism to evaluate models in parallel. + * Set the maximum level of parallelism to evaluate models in parallel. * Default is 1 for serial evaluation * * @group expertSetParam diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 384599dc0c532..c9b840276f675 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -212,7 +212,7 @@ def readImages(self, path, recursive=False, numPartitions=-1, ImageSchema = _ImageSchema() -# Monkey patch to disallow instantization of this class. +# Monkey patch to disallow instantiation of this class. 
def _disallow_instance(_): raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") _ImageSchema.__init__ = _disallow_instance diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index ab0005d7b53a8..061f653b97b7a 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -95,7 +95,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", "spark.executor.instances" -> "2", - // Sending some senstive information, which we'll make sure gets redacted + // Sending some sensitive information, which we'll make sure gets redacted "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 64ab01ca57403..d18542b188f71 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -294,7 +294,7 @@ public void setNullAt(int ordinal) { assertIndexIsValid(ordinal); BitSetMethods.set(baseObject, baseOffset + 8, ordinal); - /* we assume the corrresponding column was already 0 or + /* we assume the corresponding column was already 0 or will be set to 0 later by the caller side */ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bbe41cf8f15e..20216087b0158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.internal.SQLConf * view resolution, in this way, we are able to get the correct view column ordering and * omit the extra columns that we don't require); * 1.2. Else set the child output attributes to `queryOutput`. - * 2. Map the `queryQutput` to view output by index, if the corresponding attributes don't match, + * 2. Map the `queryOutput` to view output by index, if the corresponding attributes don't match, * try to up cast and alias the attribute in `queryOutput` to the attribute in the view output. * 3. Add a Project over the child, with the new output generated by the previous steps. * If the view output doesn't have the same number of columns neither with the child output, nor diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4af813456b790..64da9bb9cdec1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -51,7 +51,7 @@ trait InvokeLike extends Expression with NonSQLExpression { * * - generate codes for argument. 
* - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments. - * - avoid some of nullabilty checking which are not needed because the expression is not + * - avoid some of nullability checking which are not needed because the expression is not * nullable. * - when needNullCheck == true, short circuit if we found one of arguments is null because * preparing rest of arguments can be skipped in the case. @@ -193,7 +193,8 @@ case class StaticInvoke( * @param targetObject An expression that will return the object to call the method on. * @param functionName The name of the method to call. * @param dataType The expected return type of the function. - * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + * @param arguments An optional list of expressions, whose evaluation will be passed to the + * function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. * @param returnNullable When false, indicating the invoked method will always return diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala index 10479630f3f99..30e3bc9fb5779 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch /** - * Unit test suite for the count-min sketch SQL aggregate funciton [[CountMinSketchAgg]]. + * Unit test suite for the count-min sketch SQL aggregate function [[CountMinSketchAgg]]. */ class CountMinSketchAggSuite extends SparkFunSuite { private val childExpression = BoundReference(0, IntegerType, nullable = true) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java index b5e3e44cd6332..53ffa95ae0f4c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java @@ -41,7 +41,7 @@ public interface MicroBatchWriteSupport extends BaseStreamingSink { * @param queryId A unique string for the writing query. It's possible that there are many writing * queries running at the same time, and the returned {@link DataSourceV2Writer} * can use this id to distinguish itself from others. - * @param epochId The uniquenumeric ID of the batch within this writing query. This is an + * @param epochId The unique numeric ID of the batch within this writing query. This is an * incrementing counter representing a consistent set of data; the same batch may * be started multiple times in failure recovery scenarios, but it will always * contain the same records. 
diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index 594e747a8d3a5..b13850c301490 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -32,7 +32,7 @@ stroke-width: 1px; } -/* Hightlight the SparkPlan node name */ +/* Highlight the SparkPlan node name */ #plan-viz-graph svg text :first-child { font-weight: bold; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index d3874b58bc807..023e127888290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -77,7 +77,7 @@ trait FileFormat { } /** - * Returns whether a file with `path` could be splitted or not. + * Returns whether a file with `path` could be split or not. */ def isSplitable( sparkSession: SparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index b64d71bb4eef2..a585cbed2551b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -150,7 +150,7 @@ private[csv] object CSVInferSchema { if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { TimestampType } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - // We keep this for backwords competibility. + // We keep this for backwards compatibility. TimestampType } else { tryParseBoolean(field, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index d98cf852a1b48..1465346eb802d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -368,7 +368,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The minimum key private var minKey = Long.MaxValue - // The maxinum key + // The maximum key private var maxKey = Long.MinValue // The array to store the key and offset of UnsafeRow in the page. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 217e98a6419e9..4aba76cad367e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -203,7 +203,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already - * loaded. This is class is a modified verion of [[ZippedPartitionsRDD2]]. + * loaded. 
This is class is a modified version of [[ZippedPartitionsRDD2]]. */ class StateStoreAwareZipPartitionsRDD[A: ClassTag, B: ClassTag, V: ClassTag]( sc: SparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2295b8dd5fe36..d8adbe7bee13e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -175,7 +175,7 @@ class SQLAppStatusListener( // Check the execution again for whether the aggregated metrics data has been calculated. // This can happen if the UI is requesting this data, and the onExecutionEnd handler is - // running at the same time. The metrics calculcated for the UI can be innacurate in that + // running at the same time. The metrics calculated for the UI can be innacurate in that // case, since the onExecutionEnd handler will clean up tracked stage metrics. if (exec.metricsValues != null) { exec.metricsValues diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 058c38c8cb8f4..1e076207bc607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -86,7 +86,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable { def bufferEncoder: Encoder[BUF] /** - * Specifies the `Encoder` for the final ouput value type. + * Specifies the `Encoder` for the final output value type. * @since 2.0.0 */ def outputEncoder: Encoder[OUT] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index cedc1dce4a703..0dcb666e2c3e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -152,7 +152,7 @@ class StreamingQueryProgress private[sql]( * @param endOffset The ending offset for data being read. * @param numInputRows The number of records read from this source. * @param inputRowsPerSecond The rate at which data is arriving from this source. - * @param processedRowsPerSecond The rate at which data from this source is being procressed by + * @param processedRowsPerSecond The rate at which data from this source is being processed by * Spark. * @since 2.1.0 */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java index 447a71d284fbb..288f5e7426c05 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -47,7 +47,7 @@ public MyDoubleAvg() { _inputDataType = DataTypes.createStructType(inputFields); // The buffer has two values, bufferSum for storing the current sum and - // bufferCount for storing the number of non-null input values that have been contribuetd + // bufferCount for storing the number of non-null input values that have been contributed // to the current sum. 
List bufferFields = new ArrayList<>(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql index 58866f4b18112..6de22b8b7c3de 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/implicitTypeCasts.sql @@ -32,7 +32,7 @@ SELECT 1.1 - '2.2' FROM t; SELECT 1.1 * '2.2' FROM t; SELECT 4.4 / '2.2' FROM t; --- concatentation +-- concatenation SELECT '$' || cast(1 as smallint) || '$' FROM t; SELECT '$' || 1 || '$' FROM t; SELECT '$' || cast(1 as bigint) || '$' FROM t; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 83018f95aa55d..12eaf63415081 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -92,7 +92,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext test("deriveCompactInterval") { // latestCompactBatchId(4) + 1 <= default(5) - // then use latestestCompactBatchId + 1 === 5 + // then use latestCompactBatchId + 1 === 5 assert(5 === deriveCompactInterval(5, 4)) // First divisor of 10 greater than 4 === 5 assert(5 === deriveCompactInterval(4, 9)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala index bf43de597a7a0..2cb48281b30af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationP import org.apache.spark.sql.types._ -// Note that the package name is intendedly mismatched in order to resemble external data sources +// Note that the package name is intentionally mismatched in order to resemble external data sources // and test the detection for them. class FakeExternalSourceOne extends RelationProvider with DataSourceRegister { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 2a2552211857a..8c4e1fd00b0a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -81,7 +81,7 @@ class FileStreamSinkSuite extends StreamTest { .start(outputDir) try { - // The output is partitoned by "value", so the value will appear in the file path. + // The output is partitioned by "value", so the value will appear in the file path. // This is to test if we handle spaces in the path correctly. 
inputData.addData("hello world") failAfter(streamingTimeout) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b4248b74f50ab..904f9f2ad0b22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -113,7 +113,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with if (thread.isAlive) { thread.interrupt() // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and + // is not interruptible. There is not much point to wait for the thread to terminate, and // we rather let the JVM terminate the thread on exit. fail( s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 9e9894803ce25..11afe1af32809 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -50,7 +50,7 @@ private[hive] object HiveShim { val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + * This function in hive-0.13 become private, but we have to do this to work around hive bug */ private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") diff --git a/sql/hive/src/test/resources/data/conf/hive-log4j.properties b/sql/hive/src/test/resources/data/conf/hive-log4j.properties index 6a042472adb90..83fd03a99bc37 100644 --- a/sql/hive/src/test/resources/data/conf/hive-log4j.properties +++ b/sql/hive/src/test/resources/data/conf/hive-log4j.properties @@ -32,7 +32,7 @@ log4j.threshhold=WARN log4j.appender.DRFA=org.apache.log4j.DailyRollingFileAppender log4j.appender.DRFA.File=${hive.log.dir}/${hive.log.file} -# Rollver at midnight +# Roll over at midnight log4j.appender.DRFA.DatePattern=.yyyy-MM-dd # 30-day backup diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 15d3c7e54b8dd..8da5a5f8193cf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -162,7 +162,7 @@ private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, mappingFunction, batchTime, timeoutThresholdTime, - removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled + removeTimedoutData = doFullScan // remove timed-out data only when full scan is enabled ) Iterator(newRecord) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 3a21cfae5ac2f..89524cd84ff32 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -364,7 +364,7 @@ private[streaming] object OpenHashMapBasedStateMap { } /** - * Internal class to 
represent a marker the demarkate the end of all state data in the + * Internal class to represent a marker that demarcates the end of all state data in the * serialized bytes. */ class LimitMarker(val num: Int) extends Serializable From 1c9f95cb771ac78775a77edd1abfeb2d8ae2a124 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 2 Jan 2018 07:13:27 +0900 Subject: [PATCH 553/712] [SPARK-22530][PYTHON][SQL] Adding Arrow support for ArrayType ## What changes were proposed in this pull request? This change adds `ArrayType` support for working with Arrow in pyspark when creating a DataFrame, calling `toPandas()`, and using vectorized `pandas_udf`. ## How was this patch tested? Added new Python unit tests using Array data. Author: Bryan Cutler Closes #20114 from BryanCutler/arrow-ArrayType-support-SPARK-22530. --- python/pyspark/sql/tests.py | 47 ++++++++++++++++++- python/pyspark/sql/types.py | 4 ++ .../vectorized/ArrowColumnVector.java | 13 ++++- 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1c34c897eecb5..67bdb3d72d93b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3372,6 +3372,31 @@ def test_schema_conversion_roundtrip(self): schema_rt = from_arrow_schema(arrow_schema) self.assertEquals(self.schema, schema_rt) + def test_createDataFrame_with_array_type(self): + import pandas as pd + pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) + df, df_arrow = self._createDataFrame_toggle(pdf) + result = df.collect() + result_arrow = df_arrow.collect() + expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_toPandas_with_array_type(self): + expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] + array_schema = StructType([StructField("a", ArrayType(IntegerType())), + StructField("b", ArrayType(StringType()))]) + df = self.spark.createDataFrame(expected, schema=array_schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): @@ -3651,6 +3676,24 @@ def test_vectorized_udf_datatype_string(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), ([3, 4],)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + + def test_vectorized_udf_null_array(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda 
x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + def test_vectorized_udf_complex(self): from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( @@ -3705,7 +3748,7 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType())) + f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): df.select(f(col('id'))).collect() @@ -4009,7 +4052,7 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, - 'id long, v array', + 'id long, v map', PandasUDFType.GROUP_MAP ) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 02b245713b6ab..146e673ae9756 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1625,6 +1625,8 @@ def to_arrow_type(dt): elif type(dt) == TimestampType: # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') + elif type(dt) == ArrayType: + arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type @@ -1665,6 +1667,8 @@ def from_arrow_type(at): spark_type = DateType() elif types.is_timestamp(at): spark_type = TimestampType() + elif types.is_list(at): + spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 528f66f342dc9..af5673e26a501 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -326,7 +326,8 @@ private abstract static class ArrowVectorAccessor { this.vector = vector; } - final boolean isNullAt(int rowId) { + // TODO: should be final after removing ArrayAccessor workaround + boolean isNullAt(int rowId) { return vector.isNull(rowId); } @@ -589,6 +590,16 @@ private static class ArrayAccessor extends ArrowVectorAccessor { this.accessor = vector; } + @Override + final boolean isNullAt(int rowId) { + // TODO: Workaround if vector has all non-null values, see ARROW-1948 + if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { + return false; + } else { + return super.isNullAt(rowId); + } + } + @Override final int getArrayLength(int rowId) { return accessor.getInnerValueCountAt(rowId); From e734a4b9c23463a7fea61011027a822bc9e11c98 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 2 Jan 2018 07:20:05 +0900 Subject: [PATCH 554/712] [SPARK-21893][SPARK-22142][TESTS][FOLLOWUP] Enables PySpark tests for Flume and Kafka in Jenkins ## What changes were proposed in this pull request? This PR proposes to enable PySpark tests for Flume and Kafka in Jenkins by explicitly setting the environment variables in `modules.py`. 
Seems we are not taking the dependencies into account when calculating environment variables: https://github.com/apache/spark/blob/3a07eff5af601511e97a05e6fea0e3d48f74c4f0/dev/run-tests.py#L554-L561 ## How was this patch tested? Manual tests with Jenkins in https://github.com/apache/spark/pull/20126. **Before** - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/85559/consoleFull ``` [info] Setup the following environment variables for tests: ... ``` **After** - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/85560/consoleFull ``` [info] Setup the following environment variables for tests: ENABLE_KAFKA_0_8_TESTS=1 ENABLE_FLUME_TESTS=1 ... ``` Author: hyukjinkwon Closes #20128 from HyukjinKwon/SPARK-21893. --- dev/sparktestsupport/modules.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 44f990ec8a5ac..f834563da9dda 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -418,6 +418,10 @@ def __hash__(self): source_file_regexes=[ "python/pyspark/streaming" ], + environ={ + "ENABLE_FLUME_TESTS": "1", + "ENABLE_KAFKA_0_8_TESTS": "1" + }, python_test_goals=[ "pyspark.streaming.util", "pyspark.streaming.tests", From e0c090f227e9b64e595b47d4d1f96f8a2fff5bf7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 2 Jan 2018 09:19:18 +0800 Subject: [PATCH 555/712] [SPARK-22932][SQL] Refactor AnalysisContext ## What changes were proposed in this pull request? Add a `reset` function to ensure the state in `AnalysisContext ` is per-query. ## How was this patch tested? The existing test cases Author: gatorsmile Closes #20127 from gatorsmile/refactorAnalysisContext. --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d294d48c0ee7..35b35110e491f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,6 +52,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. 
* @@ -70,6 +71,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -95,6 +98,17 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -176,7 +190,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -600,7 +614,7 @@ class Analyzer( "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + "aroud this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -1269,7 +1283,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. val i = plans.iterator @@ -1392,7 +1406,7 @@ class Analyzer( grouping, Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1450,7 +1464,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] From a6fc300e91273230e7134ac6db95ccb4436c6f8f Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 2 Jan 2018 23:30:38 +0800 Subject: [PATCH 556/712] [SPARK-22897][CORE] Expose stageAttemptId in TaskContext ## What changes were proposed in this pull request? stageAttemptId added in TaskContext and corresponding construction modification ## How was this patch tested? Added a new test in TaskContextSuite, two cases are tested: 1. Normal case without failure 2. Exception case with resubmitted stages Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897) Author: Xianjin YE Closes #20082 from advancedxy/SPARK-22897. 
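For illustration only (this snippet is not part of the patch): a minimal sketch of how task-side code might read the new field, assuming `stageAttemptNumber()` behaves as described above. The SparkContext setup is assumed, not taken from this change.

```
// Illustrative sketch, not part of this patch. Assumes TaskContext.stageAttemptNumber()
// returns 0 for the first attempt of a stage and a higher number after resubmission.
import org.apache.spark.{SparkContext, TaskContext}

val sc = SparkContext.getOrCreate()
val attempts = sc.parallelize(1 to 4, numSlices = 2).mapPartitions { _ =>
  val ctx = TaskContext.get()
  // (stage, stage attempt, task attempt) for the task that computed this partition
  Iterator((ctx.stageId(), ctx.stageAttemptNumber(), ctx.attemptNumber()))
}.collect()
```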
--- .../scala/org/apache/spark/TaskContext.scala | 9 +++++- .../org/apache/spark/TaskContextImpl.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 1 + .../spark/JavaTaskContextCompileCheck.java | 2 ++ .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++-- .../spark/memory/MemoryTestingUtils.scala | 1 + .../spark/scheduler/TaskContextSuite.scala | 29 +++++++++++++++++-- .../spark/storage/BlockInfoManagerSuite.scala | 2 +- project/MimaExcludes.scala | 3 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 1 + .../execution/UnsafeRowSerializerSuite.scala | 2 +- .../SortBasedAggregationStoreSuite.scala | 3 +- 13 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4fa..69739745aa6cf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. + */ + def stageAttemptNumber(): Int + /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb06..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,8 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. 
*/ private[spark] class TaskContextImpl( - val stageId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a06..f536fc2a5f0a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1e..f8e233a05a447 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0a..ced5a06516f75 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. 
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc248..dcf89e4f75acf 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085d..aa9c36c0aaacb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptNumber getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptNumber()).iterator + }.collect() + assert(stageAttemptNumbers.toSet === Set(0)) + + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { + // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptNumber).iterator + }.collect() + + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. 
val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f11..9c0699bc981f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813ea..3b452f35c5ec1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae7998..3e31d22e15c0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d0..6af9f8b77f8d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, 
taskAttemptId = 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9a..a3ae93810aa3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bedf..3fad7dfddadcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() From 247a08939d58405aef39b2a4e7773aa45474ad12 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 3 Jan 2018 21:40:51 +0800 Subject: [PATCH 557/712] [SPARK-22938] Assert that SQLConf.get is accessed only on the driver. ## What changes were proposed in this pull request? Assert if code tries to access SQLConf.get on executor. This can lead to hard to detect bugs, where the executor will read fallbackConf, falling back to default config values, ignoring potentially changed non-default configs. If a config is to be passed to executor code, it needs to be read on the driver, and passed explicitly. ## How was this patch tested? Check in existing tests. Author: Juliusz Sompolski Closes #20136 from juliuszsompolski/SPARK-22938. 
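For reference, a minimal sketch (not part of this patch; the object name and the config key chosen are only illustrative) of the pattern the new assertion encourages: resolve SQL configs on the driver through the public conf API and capture the plain value in the closure, rather than calling SQLConf.get from executor-side code, which would silently read the fallback conf and its default values.

```
import org.apache.spark.sql.SparkSession

object DriverSideConfCapture {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("driver-side-conf").getOrCreate()

    // Driver side: read the config once; this sees the session's actual
    // (possibly non-default) value.
    val shufflePartitions = spark.conf.get("spark.sql.shuffle.partitions").toInt

    // Executor side: only the captured Int travels with the closure;
    // no SQLConf access happens in task code.
    val histogram = spark.sparkContext.parallelize(1 to 100, 4)
      .map(i => i % shufflePartitions)
      .countByValue()

    println(histogram)
    spark.stop()
  }
}
```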
--- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f77c54a7af57..80cdc61484c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -1087,6 +1089,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) From 1a87a1609c4d2c9027a2cf669ea3337b89f61fb6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 3 Jan 2018 22:09:30 +0800 Subject: [PATCH 558/712] [SPARK-22934][SQL] Make optional clauses order insensitive for CREATE TABLE SQL statement ## What changes were proposed in this pull request? Currently, our CREATE TABLE syntax require the EXACT order of clauses. It is pretty hard to remember the exact order. Thus, this PR is to make optional clauses order insensitive for `CREATE TABLE` SQL statement. ``` CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1 col_type1 [COMMENT col_comment1], ...)] USING datasource [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` The same idea is also applicable to Create Hive Table. 
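Before turning to the Hive serde syntax below, a made-up illustration of the data source form: after this change the two statements in the sketch, which differ only in the order of the optional clauses, are both accepted and yield equivalent table definitions (apart from the table name; the names, options and properties are invented for the example).

```
import org.apache.spark.sql.SparkSession

object ClauseOrderExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("clause-order").getOrCreate()

    // Optional clauses in the order the old grammar required.
    spark.sql(
      """CREATE TABLE IF NOT EXISTS demo_tbl1(a INT, b STRING)
        |USING parquet
        |OPTIONS ('compression' = 'snappy')
        |PARTITIONED BY (b)
        |COMMENT 'demo table'
        |TBLPROPERTIES ('owner' = 'example')
      """.stripMargin)

    // Same clauses reordered; previously a parse error, now equivalent.
    spark.sql(
      """CREATE TABLE IF NOT EXISTS demo_tbl2(a INT, b STRING)
        |USING parquet
        |TBLPROPERTIES ('owner' = 'example')
        |COMMENT 'demo table'
        |PARTITIONED BY (b)
        |OPTIONS ('compression' = 'snappy')
      """.stripMargin)

    spark.stop()
  }
}
```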
``` CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1[:] col_type1 [COMMENT col_comment1], ...)] [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20133 from gatorsmile/createDataSourceTableDDL. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 24 +- .../sql/catalyst/parser/ParserUtils.scala | 9 + .../spark/sql/execution/SparkSqlParser.scala | 81 +++++-- .../execution/command/DDLParserSuite.scala | 220 ++++++++++++++---- .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 13 +- .../sql/hive/execution/SQLQuerySuite.scala | 124 +++++----- 7 files changed, 335 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d55..6daf01d98426c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? - (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? 
#createTableLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e6..89347f4b1f7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. */ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972c..d3cfd2a1ffbf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... 
USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b602..2b1aea08b1223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + 
assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - 
"""CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9cb..591510c1d8283 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1971,8 +1971,8 @@ abstract class DDLSuite extends QueryTest with 
SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f2e0c695ca38b..65be244418670 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -875,12 +875,13 @@ class HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 07ae3ae945848..47adc77a52d51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,51 +461,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. 
- sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } @@ -539,30 +543,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", 
Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } From a66fe36cee9363b01ee70e469f1c968f633c5713 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Jan 2018 22:18:13 +0800 Subject: [PATCH 559/712] [SPARK-20236][SQL] dynamic partition overwrite ## What changes were proposed in this pull request? When overwriting a partitioned table with dynamic partition columns, the behavior is different between data source and hive tables. data source table: delete all partition directories that match the static partition values provided in the insert statement. hive table: only delete partition directories which have data written into it This PR adds a new config to make users be able to choose hive's behavior. ## How was this patch tested? new tests Author: Wenchen Fan Closes #18714 from cloud-fan/overwrite-partition. --- .../internal/io/FileCommitProtocol.scala | 25 ++++-- .../io/HadoopMapReduceCommitProtocol.scala | 75 ++++++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 21 +++++ .../InsertIntoHadoopFsRelationCommand.scala | 20 ++++- .../SQLHadoopMapReduceCommitProtocol.scala | 10 ++- .../spark/sql/sources/InsertSuite.scala | 78 +++++++++++++++++++ 6 files changed, 200 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 50f51e1af4530..6d0059b6a0272 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -28,8 +28,9 @@ import org.apache.spark.util.Utils * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with 2 arguments: - * (jobId: String, path: String) + * 2. Implementations should have a constructor with 2 or 3 arguments: + * (jobId: String, path: String) or + * (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) * 3. A committer should not be reused across multiple Spark jobs. 
* * The proper call sequence is: @@ -139,10 +140,22 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. */ - def instantiate(className: String, jobId: String, outputPath: String) - : FileCommitProtocol = { + def instantiate( + className: String, + jobId: String, + outputPath: String, + dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) + // First try the constructor with arguments (jobId: String, outputPath: String, + // dynamicPartitionOverwrite: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 95c99d29c3a9c..6d20ef1f98a3c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -39,8 +39,19 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop + * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime + * dynamically, i.e., we first write files under a staging + * directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, + * we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move + * files from staging directory to the corresponding partition + * directories under destination path. */ -class HadoopMapReduceCommitProtocol(jobId: String, path: String) +class HadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ @@ -67,9 +78,17 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) @transient private var addedAbsPathFiles: mutable.Map[String, String] = null /** - * The staging directory for all files committed with absolute output paths. + * Tracks partitions with default path that have new files written into them by this task, + * e.g. a=1/b=2. Files under these partitions will be saved into staging directory and moved to + * destination directory at the end, if `dynamicPartitionOverwrite` is true. */ - private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + @transient private var partitionPaths: mutable.Set[String] = null + + /** + * The staging directory of this write job. Spark uses it to deal with files with absolute output + * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. 
+ */ + private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() @@ -85,11 +104,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { val filename = getFilename(taskContext, ext) - val stagingDir: String = committer match { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => - Option(f.getWorkPath).map(_.toString).getOrElse(path) - case _ => path + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) } dir.map { d => @@ -106,8 +130,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) // Include a UUID here to prevent file collisions for one task writing to different dirs. // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path( - absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString addedAbsPathFiles(tmpOutputPath) = absOutputPath tmpOutputPath @@ -141,23 +164,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) - val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) - .foldLeft(Map[String, String]())(_ ++ _) - logDebug(s"Committing files staged for absolute locations $filesToMove") + if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + val (allAbsPathFiles, allPartitionPaths) = + taskCommits.map(_.obj.asInstanceOf[(Map[String, String], Set[String])]).unzip + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + + val filesToMove = allAbsPathFiles.foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + if (dynamicPartitionOverwrite) { + val absPartitionPaths = filesToMove.values.map(new Path(_).getParent).toSet + logDebug(s"Clean up absolute partition directories for overwriting: $absPartitionPaths") + absPartitionPaths.foreach(fs.delete(_, true)) + } for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) } - fs.delete(absPathStagingDir, true) + + if (dynamicPartitionOverwrite) { + val partitionPaths = allPartitionPaths.foldLeft(Set[String]())(_ ++ _) + logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") + for (part <- partitionPaths) { + val finalPartPath = new Path(path, part) + fs.delete(finalPartPath, true) + fs.rename(new Path(stagingDir, part), finalPartPath) + } + } + + fs.delete(stagingDir, true) } } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(stagingDir, true) } } @@ -165,13 +207,14 @@ class HadoopMapReduceCommitProtocol(jobId: String, 
path: String) committer = setupCommitter(taskContext) committer.setupTask(taskContext) addedAbsPathFiles = mutable.Map[String, String]() + partitionPaths = mutable.Set[String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - new TaskCommitMessage(addedAbsPathFiles.toMap) + new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet) } override def abortTask(taskContext: TaskAttemptContext): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80cdc61484c0f..5d6edf6b8abec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1068,6 +1068,24 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1394,6 +1412,9 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ad24e280d942a..dd7ef0d15c140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -89,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. 
- val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -103,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -126,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b1..39c594a9bc618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. 
*/ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f45946..fef01c860db6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + 
""" + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } From 9a2b65a3c0c36316aae0a53aa0f61c5044c2ceff Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Wed, 3 Jan 2018 11:31:32 -0600 Subject: [PATCH 560/712] [SPARK-22896] Improvement in String interpolation ## What changes were proposed in this pull request? * String interpolation in ml pipeline example has been corrected as per scala standard. ## How was this patch tested? * manually tested. Author: chetkhatri Closes #20070 from chetkhatri/mllib-chetan-contrib. --- .../spark/examples/ml/JavaQuantileDiscretizerExample.java | 2 +- .../apache/spark/examples/SimpleSkewedGroupByTest.scala | 4 ---- .../org/apache/spark/examples/graphx/Analytics.scala | 6 ++++-- .../org/apache/spark/examples/graphx/SynthBenchmark.scala | 6 +++--- .../apache/spark/examples/ml/ChiSquareTestExample.scala | 6 +++--- .../org/apache/spark/examples/ml/CorrelationExample.scala | 4 ++-- .../org/apache/spark/examples/ml/DataFrameExample.scala | 4 ++-- .../examples/ml/DecisionTreeClassificationExample.scala | 4 ++-- .../spark/examples/ml/DecisionTreeRegressionExample.scala | 4 ++-- .../apache/spark/examples/ml/DeveloperApiExample.scala | 6 +++--- .../examples/ml/EstimatorTransformerParamExample.scala | 6 +++--- .../ml/GradientBoostedTreeClassifierExample.scala | 4 ++-- .../examples/ml/GradientBoostedTreeRegressorExample.scala | 4 ++-- ...ulticlassLogisticRegressionWithElasticNetExample.scala | 2 +- .../ml/MultilayerPerceptronClassifierExample.scala | 2 +- .../org/apache/spark/examples/ml/NaiveBayesExample.scala | 2 +- .../spark/examples/ml/QuantileDiscretizerExample.scala | 4 ++-- .../spark/examples/ml/RandomForestClassifierExample.scala | 4 ++-- .../spark/examples/ml/RandomForestRegressorExample.scala | 4 ++-- .../apache/spark/examples/ml/VectorIndexerExample.scala | 4 ++-- .../spark/examples/mllib/AssociationRulesExample.scala | 6 +++--- .../mllib/BinaryClassificationMetricsExample.scala | 4 ++-- .../mllib/DecisionTreeClassificationExample.scala | 4 ++-- .../examples/mllib/DecisionTreeRegressionExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/FPGrowthExample.scala | 2 +- .../mllib/GradientBoostingClassificationExample.scala | 4 ++-- .../mllib/GradientBoostingRegressionExample.scala | 4 ++-- .../spark/examples/mllib/HypothesisTestingExample.scala | 2 +- .../spark/examples/mllib/IsotonicRegressionExample.scala | 2 +- .../org/apache/spark/examples/mllib/KMeansExample.scala | 2 +- .../org/apache/spark/examples/mllib/LBFGSExample.scala | 2 +- .../examples/mllib/LatentDirichletAllocationExample.scala | 8 
+++++--- .../examples/mllib/LinearRegressionWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/PCAExample.scala | 4 ++-- .../spark/examples/mllib/PMMLModelExportExample.scala | 2 +- .../apache/spark/examples/mllib/PrefixSpanExample.scala | 4 ++-- .../mllib/RandomForestClassificationExample.scala | 4 ++-- .../examples/mllib/RandomForestRegressionExample.scala | 4 ++-- .../spark/examples/mllib/RecommendationExample.scala | 2 +- .../apache/spark/examples/mllib/SVMWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/SimpleFPGrowth.scala | 8 +++----- .../spark/examples/mllib/StratifiedSamplingExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/TallSkinnyPCA.scala | 2 +- .../org/apache/spark/examples/mllib/TallSkinnySVD.scala | 2 +- .../apache/spark/examples/streaming/CustomReceiver.scala | 6 +++--- .../apache/spark/examples/streaming/RawNetworkGrep.scala | 2 +- .../examples/streaming/RecoverableNetworkWordCount.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../examples/streaming/clickstream/PageViewStream.scala | 4 ++-- 49 files changed, 94 insertions(+), 96 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index dd20cac621102..43cc30c1a899b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setNumBuckets(3); Dataset result = discretizer.fit(df).transform(df); - result.show(); + result.show(false); // $example off$ spark.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index e64dcbd182d94..2332a661f26a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -60,10 +60,6 @@ object SimpleSkewedGroupByTest { pairs1.count println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") - // Print how many keys each reducer got (for debugging) - // println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 92936bd30dbc0..815404d1218b7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -145,9 +145,11 @@ object Analytics extends Logging { // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) - println("Triangles: " + triangles.vertices.map { + val triangleTypes = triangles.vertices.map { case (vid, data) => data.toLong - }.reduce(_ + _) / 3) + }.reduce(_ + _) / 3 + + println(s"Triangles: ${triangleTypes}") sc.stop() case _ => diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 6d2228c8742aa..57b2edf992208 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -52,7 +52,7 @@ object SynthBenchmark { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } @@ -76,7 +76,7 @@ object SynthBenchmark { case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v case ("seed", v) => seed = v.toInt - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } val conf = new SparkConf() @@ -86,7 +86,7 @@ object SynthBenchmark { val sc = new SparkContext(conf) // Create the graph - println(s"Creating graph...") + println("Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala index dcee1e427ce58..5146fd0316467 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala @@ -52,9 +52,9 @@ object ChiSquareTestExample { val df = data.toDF("label", "features") val chi = ChiSquareTest.test(df, "features", "label").head - println("pValues = " + chi.getAs[Vector](0)) - println("degreesOfFreedom = " + chi.getSeq[Int](1).mkString("[", ",", "]")) - println("statistics = " + chi.getAs[Vector](2)) + println(s"pValues = ${chi.getAs[Vector](0)}") + println(s"degreesOfFreedom ${chi.getSeq[Int](1).mkString("[", ",", "]")}") + println(s"statistics ${chi.getAs[Vector](2)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala index 3f57dc342eb00..d7f1fc8ed74d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala @@ -51,10 +51,10 @@ object CorrelationExample { val df = data.map(Tuple1.apply).toDF("features") val Row(coeff1: Matrix) = Correlation.corr(df, "features").head - println("Pearson correlation matrix:\n" + coeff1.toString) + println(s"Pearson correlation matrix:\n $coeff1") val Row(coeff2: Matrix) = Correlation.corr(df, "features", "spearman").head - println("Spearman correlation matrix:\n" + coeff2.toString) + println(s"Spearman correlation matrix:\n $coeff2") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0658bddf16961..ee4469faab3a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -47,7 +47,7 @@ object DataFrameExample { val parser = new OptionParser[Params]("DataFrameExample") { head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataframe") + .text("input path to dataframe") .action((x, c) => c.copy(input = 
x)) checkConfig { params => success @@ -93,7 +93,7 @@ object DataFrameExample { // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") val newDF = spark.read.parquet(outputDir) - println(s"Schema from Parquet:") + println("Schema from Parquet:") newDF.printSchema() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index bc6d3275933ea..276cedab11abc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -83,10 +83,10 @@ object DecisionTreeClassificationExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] - println("Learned classification tree model:\n" + treeModel.toDebugString) + println(s"Learned classification tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ee61200ad1d0c..aaaecaea47081 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -73,10 +73,10 @@ object DecisionTreeRegressionExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] - println("Learned regression tree model:\n" + treeModel.toDebugString) + println(s"Learned regression tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index d94d837d10e96..2dc11b07d88ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -53,7 +53,7 @@ object DeveloperApiExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. - println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"MyLogisticRegression parameters:\n ${lr.explainParams()}") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -169,10 +169,10 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + // Number of classes the label can take. 2 indicates binary classification. override val numClasses: Int = 2 - /** Number of features the model was trained on. */ + // Number of features the model was trained on. 
override val numFeatures: Int = coefficients.size /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index f18d86e1a6921..e5d91f132a3f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -46,7 +46,7 @@ object EstimatorTransformerParamExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. - println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -58,7 +58,7 @@ object EstimatorTransformerParamExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}") // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -73,7 +73,7 @@ object EstimatorTransformerParamExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}") // Prepare test data. 
val test = spark.createDataFrame(Seq( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 3656773c8b817..ef78c0a1145ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -86,10 +86,10 @@ object GradientBoostedTreeClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${1.0 - accuracy}") val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] - println("Learned classification GBT model:\n" + gbtModel.toDebugString) + println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index e53aab7f326d3..3feb2343f6a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -73,10 +73,10 @@ object GradientBoostedTreeRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] - println("Learned regression GBT model:\n" + gbtModel.toDebugString) + println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 42f0ace7a353d..3e61dbe628c20 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -48,7 +48,7 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") - println(s"Intercepts: ${lrModel.interceptVector}") + println(s"Intercepts: \n${lrModel.interceptVector}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 6fce82d294f8d..646f46a925062 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,7 +66,7 @@ object MultilayerPerceptronClassifierExample { val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") - println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels)) + println(s"Test set accuracy = 
${evaluator.evaluate(predictionAndLabels)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index bd9fcc420a66c..50c70c626b128 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -52,7 +52,7 @@ object NaiveBayesExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test set accuracy = " + accuracy) + println(s"Test set accuracy = $accuracy") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index aedb9e7d3bb70..0fe16fb6dfa9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -36,7 +36,7 @@ object QuantileDiscretizerExample { // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. // Note this is not necessary for normal use cases - .repartition(1) + .repartition(1) // $example on$ val discretizer = new QuantileDiscretizer() @@ -45,7 +45,7 @@ object QuantileDiscretizerExample { .setNumBuckets(3) val result = discretizer.fit(df).transform(df) - result.show() + result.show(false) // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 5eafda8ce4285..6265f83902528 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -85,10 +85,10 @@ object RandomForestClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] - println("Learned classification forest model:\n" + rfModel.toDebugString) + println(s"Learned classification forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 9a0a001c26ef5..2679fcb353a8a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -72,10 +72,10 @@ object RandomForestRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] - println("Learned regression forest model:\n" + rfModel.toDebugString) + println(s"Learned regression forest model:\n 
${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index afa761aee0b98..96bb8ea2338af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -41,8 +41,8 @@ object VectorIndexerExample { val indexerModel = indexer.fit(data) val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) + println(s"Chose ${categoricalFeatures.size} " + + s"categorical features: ${categoricalFeatures.mkString(", ")}") // Create new column "indexed" with categorical values transformed to indices val indexedData = indexerModel.transform(data) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ff44de56839e5..a07535bb5a38d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -42,9 +42,8 @@ object AssociationRulesExample { val results = ar.run(freqItemsets) results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) + println(s"[${rule.antecedent.mkString(",")}=>${rule.consequent.mkString(",")} ]" + + s" ${rule.confidence}") } // $example off$ @@ -53,3 +52,4 @@ object AssociationRulesExample { } // scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index b9263ac6fcff6..c6312d71cc912 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -86,7 +86,7 @@ object BinaryClassificationMetricsExample { // AUPRC val auPRC = metrics.areaUnderPR - println("Area under precision-recall curve = " + auPRC) + println(s"Area under precision-recall curve = $auPRC") // Compute thresholds used in ROC and PR curves val thresholds = precision.map(_._1) @@ -96,7 +96,7 @@ object BinaryClassificationMetricsExample { // AUROC val auROC = metrics.areaUnderROC - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index b50b4592777ce..c2f89b72c9a2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -55,8 +55,8 @@ object DecisionTreeClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification tree model:\n" + model.toDebugString) + println(s"Test Error = 
$testErr") + println(s"Learned classification tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index 2af45afae3d5b..1ecf6426e1f95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -54,8 +54,8 @@ object DecisionTreeRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression tree model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 6435abc127752..f724ee1030f04 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -74,7 +74,7 @@ object FPGrowthExample { println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 00bb3348d2a36..3c56e1941aeca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -54,8 +54,8 @@ object GradientBoostingClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification GBT model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index d8c263460839b..c288bf29bf255 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -53,8 +53,8 @@ object GradientBoostingRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression GBT model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + 
println(s"Learned regression GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 0d391a3637c07..add1719739539 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -68,7 +68,7 @@ object HypothesisTestingExample { // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) featureTestResults.zipWithIndex.foreach { case (k, v) => - println("Column " + (v + 1).toString + ":") + println(s"Column ${(v + 1)} :") println(k) } // summary of the test // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 4aee951f5b04c..a10d6f0dda880 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -56,7 +56,7 @@ object IsotonicRegressionExample { // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() - println("Mean Squared Error = " + meanSquaredError) + println(s"Mean Squared Error = $meanSquaredError") // Save and load model model.save(sc, "target/tmp/myIsotonicRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index c4d71d862f375..b0a6f1671a898 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -43,7 +43,7 @@ object KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) - println("Within Set Sum of Squared Errors = " + WSSSE) + println(s"Within Set Sum of Squared Errors = $WSSSE") // Save and load model clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index fedcefa098381..123782fa6b9cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -82,7 +82,7 @@ object LBFGSExample { println("Loss of each step in training process") loss.foreach(println) - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala index f2c8ec01439f1..d25962c5500ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -42,11 +42,13 @@ object 
LatentDirichletAllocationExample { val ldaModel = new LDA().setK(3).run(corpus) // Output topics. Each is a distribution over words (matching word count vectors) - println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):") val topics = ldaModel.topicsMatrix for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + print(s"Topic $topic :") + for (word <- Range(0, ldaModel.vocabSize)) { + print(s"${topics(word, topic)}") + } println() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index d399618094487..449b725d1d173 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -52,7 +52,7 @@ object LinearRegressionWithSGDExample { (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println("training Mean Squared Error = " + MSE) + println(s"training Mean Squared Error $MSE") // Save and load model model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index eb36697d94ba1..eff2393cc3abe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -65,8 +65,8 @@ object PCAExample { val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - println("Mean Squared Error = " + MSE) - println("PCA Mean Squared Error = " + MSE_pca) + println(s"Mean Squared Error = $MSE") + println(s"PCA Mean Squared Error = $MSE_pca") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb11..96deafd469bc7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println(s"PMML Model:\n ${clusters.toPMML}") // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 69c72c4336576..8b789277774af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -42,8 +42,8 @@ object PrefixSpanExample { val model = prefixSpan.run(sequences) model.freqSequences.collect().foreach { freqSequence => println( - freqSequence.sequence.map(_.mkString("[", ", ", 
"]")).mkString("[", ", ", "]") + - ", " + freqSequence.freq) + s"${freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}," + + s" ${freqSequence.freq}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index f1ebdf1a733ed..246e71de25615 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -55,8 +55,8 @@ object RandomForestClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification forest model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index 11d612e651b4b..770e30276bc30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -55,8 +55,8 @@ object RandomForestRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression forest model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 6df742d737e70..0bb2b8c8c2b43 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -56,7 +56,7 @@ object RecommendationExample { val err = (r1 - r2) err * err }.mean() - println("Mean Squared Error = " + MSE) + println(s"Mean Squared Error = $MSE") // Save and load model model.save(sc, "target/tmp/myCollaborativeFilter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala index b73fe9b2b3faa..285e2ce512639 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -57,7 +57,7 @@ object SVMWithSGDExample { val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // Save and load model model.save(sc, "target/tmp/scalaSVMWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala 
b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b5c3033bcba09..694c3bb18b045 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -42,15 +42,13 @@ object SimpleFPGrowth { val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")},${itemset.freq}") } val minConfidence = 0.8 model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) + println(s"${rule.antecedent.mkString("[", ",", "]")}=> " + + s"${rule.consequent .mkString("[", ",", "]")},${rule.confidence}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala index 16b074ef60699..3d41bef0af88c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -41,10 +41,10 @@ object StratifiedSamplingExample { val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) // $example off$ - println("approxSample size is " + approxSample.collect().size.toString) + println(s"approxSample size is ${approxSample.collect().size}") approxSample.collect().foreach(println) - println("exactSample its size is " + exactSample.collect().size.toString) + println(s"exactSample its size is ${exactSample.collect().size}") exactSample.collect().foreach(println) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 03bc675299c5a..071d341b81614 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -54,7 +54,7 @@ object TallSkinnyPCA { // Compute principal components. val pc = mat.computePrincipalComponents(mat.numCols().toInt) - println("Principal components are:\n" + pc) + println(s"Principal components are:\n $pc") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 067e49b9599e7..8ae6de16d80e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -54,7 +54,7 @@ object TallSkinnySVD { // Compute SVD. 
val svd = mat.computeSVD(mat.numCols().toInt) - println("Singular values are " + svd.s) + println(s"Singular values are ${svd.s}") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 43044d01b1204..25c7bf2871972 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -82,9 +82,9 @@ class CustomReceiver(host: String, port: Int) var socket: Socket = null var userInput: String = null try { - logInfo("Connecting to " + host + ":" + port) + logInfo(s"Connecting to $host : $port") socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) + logInfo(s"Connected to $host : $port") val reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() @@ -98,7 +98,7 @@ class CustomReceiver(host: String, port: Int) restart("Trying to connect again") } catch { case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) + restart(s"Error connecting to $host : $port", e) case t: Throwable => restart("Error receiving data", t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 5322929d177b4..437ccf0898d7c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -54,7 +54,7 @@ object RawNetworkGrep { ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) + println(s"Grep count: ${r.collect().mkString}")) ssc.start() ssc.awaitTermination() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 49c0427321133..f018f3a26d2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -130,10 +130,10 @@ object RecoverableNetworkWordCount { true } }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts + val output = s"Counts at time $time $counts" println(output) - println("Dropped " + droppedWordsCounter.value + " word(s) totally") - println("Appending to " + outputFile.getAbsolutePath) + println(s"Dropped ${droppedWordsCounter.value} word(s) totally") + println(s"Appending to ${outputFile.getAbsolutePath}") Files.append(output + "\n", outputFile, Charset.defaultCharset()) } ssc @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}") System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala 
b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 0ddd065f0db2b..2108bc63edea2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -90,13 +90,13 @@ object PageViewGenerator { val viewsPerSecond = args(1).toFloat val sleepDelayMs = (1000.0 / viewsPerSecond).toInt val listener = new ServerSocket(port) - println("Listening on port: " + port) + println(s"Listening on port: $port") while (true) { val socket = listener.accept() new Thread() { override def run(): Unit = { - println("Got client connected from: " + socket.getInetAddress) + println(s"Got client connected from: ${socket.getInetAddress}") val out = new PrintWriter(socket.getOutputStream(), true) while (true) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 1ba093f57b32c..b8e7c7e9e9152 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -104,8 +104,8 @@ object PageViewStream { .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) + .foreach(u => println(s"Saw user $u at time $time"))) + case _ => println(s"Invalid metric entered: $metric") } ssc.start() From b297029130735316e1ac1144dee44761a12bfba7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 07:28:53 +0800 Subject: [PATCH 561/712] [SPARK-20960][SQL] make ColumnVector public ## What changes were proposed in this pull request? move `ColumnVector` and related classes to `org.apache.spark.sql.vectorized`, and improve the document. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20116 from cloud-fan/column-vector. 
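
For readers following this series, the user-visible effect of the move is that `ColumnVector` and `ColumnarBatch` are now imported from `org.apache.spark.sql.vectorized` rather than the internal `org.apache.spark.sql.execution.vectorized` package. Below is a minimal, hypothetical sketch (not part of this patch) of consuming a batch through the public read-only API; the `sumFirstIntColumn` helper and the assumption that column 0 holds 32-bit integers are illustrative only.

```scala
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// Illustrative helper (not from the patch): sums the first column of a batch,
// assuming that column stores 32-bit integers. The batch itself would come
// from a columnar source such as the vectorized Parquet reader.
def sumFirstIntColumn(batch: ColumnarBatch): Long = {
  val col: ColumnVector = batch.column(0)
  var sum = 0L
  var i = 0
  while (i < batch.numRows()) {
    if (!col.isNullAt(i)) sum += col.getInt(i)  // rowId is batch-local and 0-based
    i += 1
  }
  sum
}
```

Code that previously reached into `execution.vectorized` for these classes generally only needs its imports updated; the read accessors shown above (`column`, `numRows`, `isNullAt`, `getInt`, and friends) keep their existing signatures after the move.
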
--- .../VectorizedParquetRecordReader.java | 7 ++- .../vectorized/ColumnVectorUtils.java | 2 + .../vectorized/MutableColumnarRow.java | 4 ++ .../vectorized/WritableColumnVector.java | 7 ++- .../vectorized/ArrowColumnVector.java | 62 +------------------ .../vectorized/ColumnVector.java | 31 ++++++---- .../vectorized/ColumnarArray.java | 7 +-- .../vectorized/ColumnarBatch.java | 34 +++------- .../vectorized/ColumnarRow.java | 7 +-- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 +- .../VectorizedHashMapGenerator.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 1 + .../execution/datasources/FileScanRDD.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../execution/arrow/ArrowWriterSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 1 + .../vectorized/ColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 20 files changed, 63 insertions(+), 125 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ArrowColumnVector.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnVector.java (79%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarArray.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarBatch.java (73%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarRow.java (96%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e85d411f..cd745b1f0e4e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -248,7 +248,10 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. 
*/ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e5..b5cbe8e2839ba 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe9..70057a9def6c0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e12..d2ae32b06f83b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index af5673e26a501..708333213f3f1 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; @@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector { private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } + ensureAccessible(index, 1); } private void ensureAccessible(int index, int count) { @@ -64,20 +60,12 @@ public void close() { accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { ensureAccessible(rowId); @@ -94,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) { return array; } - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { ensureAccessible(rowId); @@ -114,10 +98,6 @@ public byte[] getBytes(int rowId, int count) { return array; } - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { ensureAccessible(rowId); @@ -134,10 +114,6 @@ public short[] getShorts(int rowId, int count) { return array; } - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { ensureAccessible(rowId); @@ -154,10 +130,6 @@ public int[] getInts(int rowId, int count) { return array; } - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { ensureAccessible(rowId); @@ -174,10 +146,6 @@ public long[] getLongs(int rowId, int count) { return array; } - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { ensureAccessible(rowId); @@ -194,10 +162,6 @@ public float[] getFloats(int rowId, int count) { return array; } - // - // APIs dealing with doubles - // - @Override public double getDouble(int rowId) { ensureAccessible(rowId); @@ -214,10 +178,6 @@ public double[] getDoubles(int rowId, int count) { return array; } - // - // APIs dealing with Arrays - // - @Override public int getArrayLength(int rowId) { ensureAccessible(rowId); @@ -230,45 +190,27 @@ public int getArrayOffset(int rowId) { return accessor.getArrayOffset(rowId); } - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { ensureAccessible(rowId); return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { ensureAccessible(rowId); return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override public ArrowColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. 
- */ @Override public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index dc7c1269bedd9..d1196e1299fee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; @@ -22,24 +22,31 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. * - * Maps are just a special case of a two field struct. + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. */ public abstract class ColumnVector implements AutoCloseable { /** - * Returns the data type of this column. + * Returns the data type of this column vector. 
*/ public final DataType dataType() { return type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec2..0d89a52e7a4fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -23,8 +23,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from @@ -33,7 +32,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa679726..9ae1c6d9993f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,26 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; import org.apache.spark.sql.types.StructType; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. + * This class wraps multiple ColumnVectors as a row-wise table. 
It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ public final class ColumnarBatch { public static final int DEFAULT_BATCH_SIZE = 4 * 1024; @@ -57,7 +49,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,19 +79,7 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 8bb33ed5b78c0..3c6656dec77cd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -24,8 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ public final class ColumnarRow extends InternalRow { // The data for this row. @@ -34,7 +33,7 @@ public final class ColumnarRow extends InternalRow { private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292ba..5617046e1396e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * [[ColumnarBatch]]es. 
*/ private[sql] trait ColumnarBatchScan extends CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9a6f1c6dfa6a9..ce3c68810f3b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d63..0cf9b53ce1d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc412430263..bcd1aa0890ba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b12850..933b9753faa61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import 
org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f2..835ce98462477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed3535654..dc5ba96e69aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 508c116aae92e..c42bc60a59d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a655..7304803a092c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class 
ArrowColumnVectorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f6..944240f3bade5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d0..675f06b31b970 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -918,10 +919,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) From 7d045c5f00e2c7c67011830e2169a4e130c3ace8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 13:14:52 +0800 Subject: [PATCH 562/712] [SPARK-22944][SQL] improve FoldablePropagation ## What changes were proposed in this pull request? `FoldablePropagation` is a little tricky as it needs to handle attributes that are miss-derived from children, e.g. outer join outputs. This rule does a kind of stop-able tree transform, to skip to apply this rule when hit a node which may have miss-derived attributes. Logically we should be able to apply this rule above the unsupported nodes, by just treating the unsupported nodes as leaf nodes. This PR improves this rule to not stop the tree transformation, but reduce the foldable expressions that we want to propagate. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20139 from cloud-fan/foldable. 
--- .../sql/catalyst/optimizer/expressions.scala | 65 +++++++++++-------- .../optimizer/FoldablePropagationSuite.scala | 23 ++++++- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7d830bbb7dc32..1c0b7bd806801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. 
- case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a8..c28844642aed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } From df95a908baf78800556636a76d58bba9b3dd943f Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 Jan 2018 21:43:14 -0800 Subject: [PATCH 563/712] [SPARK-22933][SPARKR] R Structured Streaming API for withWatermark, trigger, partitionBy ## What changes were proposed in this pull request? R Structured Streaming API for withWatermark, trigger, partitionBy ## How was this patch tested? 
manual, unit tests Author: Felix Cheung Closes #20129 from felixcheung/rwater. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 96 +++++++++++++++- R/pkg/R/SQLContext.R | 4 +- R/pkg/R/generics.R | 6 + R/pkg/tests/fulltests/test_streaming.R | 107 ++++++++++++++++++ python/pyspark/sql/streaming.py | 4 + .../sql/execution/streaming/Triggers.scala | 2 +- 7 files changed, 214 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3219c6f0cc47b..c51eb0f39c4b1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -179,6 +179,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fe238f6dd4eb0..9956f7eda91e6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3661,7 +3661,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3707,7 +3708,17 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. #' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible, this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} @@ -3725,7 +3736,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp" +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3737,7 +3749,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. 
It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3748,12 +3761,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...) write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3967,3 +4011,47 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. 
+#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb8..9d0a2d5e074e4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -727,7 +727,9 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5369c32544e5e..e0dde3339fabc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,12 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +#' @export +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path = NULL, ...) 
{ standardGeneric("write.df") }) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f517..a354d50c6b54e 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime),][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", 
queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fb228f99ba7ab..24ae3776a217b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -793,6 +793,10 @@ def trigger(self, processingTime=None, once=None): .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c08..19e3e55cb2829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental From 9fa703e89318922393bae03c0db4575f4f4b4c56 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 4 Jan 2018 19:10:10 +0800 Subject: [PATCH 564/712] [SPARK-22950][SQL] Handle ChildFirstURLClassLoader's parent ## What changes were proposed in this pull request? ChildFirstClassLoader's parent is set to null, so we can't get jars from its parent. This will cause ClassNotFoundException during HiveClient initialization with builtin hive jars, where we may should use spark context loader instead. ## How was this patch tested? add new ut cc cloud-fan gatorsmile Author: Kent Yao Closes #20145 from yaooqinn/SPARK-22950. 
--- .../org/apache/spark/sql/hive/HiveUtils.scala | 4 +++- .../spark/sql/hive/HiveUtilsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd1..c7717d70c996f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a68440..8697d47e89e89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,19 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } } From d5861aba9d80ca15ad3f22793b79822e470d6913 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 19:17:22 +0800 Subject: [PATCH 565/712] [SPARK-22945][SQL] add java UDF APIs in the functions object ## What changes were proposed in this pull request? Currently Scala users can use UDF like ``` val foo = udf((i: Int) => Math.random() + i).asNondeterministic df.select(foo('a)) ``` Python users can also do it with similar APIs. However Java users can't do it, we should add Java UDF APIs in the functions object. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20141 from cloud-fan/udf. 
--- .../apache/spark/sql/UDFRegistration.scala | 90 ++--- .../sql/expressions/UserDefinedFunction.scala | 1 + .../org/apache/spark/sql/functions.scala | 313 ++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 11 + .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +- 5 files changed, 315 insertions(+), 112 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc2468a721e41..f94baef39dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. - * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). 
| * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). 
* @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f830520..40a058d2cadd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 530a525a01dec..0d11682d80a3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3254,42 +3254,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - 
////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. - * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). 
+ * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3302,10 +3326,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3318,10 +3342,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3334,10 +3358,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3350,10 +3374,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3366,10 +3390,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. 
To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3382,10 +3406,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3398,10 +3422,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3414,10 +3438,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3430,10 +3454,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3446,10 +3470,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
* * @group udf_funcs * @since 1.3.0 @@ -3461,13 +3485,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84b..4f8a31f185724 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,14 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7a..db37be68e42e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): 
Double = Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { From 5aadbc929cb194e06dbd3bab054a161569289af5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Jan 2018 21:07:31 +0800 Subject: [PATCH 566/712] [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction ## What changes were proposed in this pull request? ```Python import random from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType, StringType random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() spark.catalog.registerFunction("random_udf", random_udf, StringType()) spark.sql("SELECT random_udf()").collect() ``` We will get the following error. ``` Py4JError: An error occurred while calling o29.__getnewargs__. Trace: py4j.Py4JException: Method __getnewargs__([]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:214) at java.lang.Thread.run(Thread.java:745) ``` This PR is to support it. ## How was this patch tested? WIP Author: gatorsmile Closes #20137 from gatorsmile/registerFunction. --- python/pyspark/sql/catalog.py | 27 +++++++++++++++---- python/pyspark/sql/context.py | 16 +++++++++--- python/pyspark/sql/tests.py | 49 +++++++++++++++++++++++++---------- python/pyspark/sql/udf.py | 21 ++++++++++----- 4 files changed, 84 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0c..156603128d063 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. 
:param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) + + # This is to check whether the input function is a wrapped/native UserDefinedFunction + if hasattr(f, 'asNondeterministic'): + udf = UserDefinedFunction(f.func, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=f.deterministic) + else: + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef3..b8d86cc098e94 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. 
:param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67bdb3d72d93b..6dc767f9ec46e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -378,6 +378,41 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + twoargs = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) + self.assertEqual(twoargs.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -435,15 +470,6 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) - def test_nondeterministic_udf(self): - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - df = 
self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -567,7 +593,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -575,7 +600,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -1299,7 +1322,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 54b5a8656e1c8..5e75eb6545333 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): ) # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +68,10 @@ class UserDefinedFunction(object): .. 
versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,7 +95,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType - self._deterministic = True + self.deterministic = deterministic @property def returnType(self): @@ -130,7 +133,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType, self._deterministic) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -138,6 +141,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -162,7 +168,8 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - wrapper.asNondeterministic = self.asNondeterministic + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() return wrapper @@ -172,5 +179,5 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ - self._deterministic = False + self.deterministic = False return self From 6f68316e98fad72b171df422566e1fc9a7bbfcde Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 4 Jan 2018 21:15:10 +0800 Subject: [PATCH 567/712] [SPARK-22771][SQL] Add a missing return statement in Concat.checkInputDataTypes ## What changes were proposed in this pull request? This pr is a follow-up to fix a bug left in #19977. ## How was this patch tested? Added tests in `StringExpressionsSuite`. Author: Takeshi Yamamuro Closes #20149 from maropu/SPARK-22771-FOLLOWUP. 
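To see why the original check silently succeeded, note that in Scala an `if` without an `else` has type `Unit`, so a result constructed inside the branch is computed and then discarded unless it is explicitly returned. A minimal, self-contained sketch of the pattern (toy result types, not Spark's `TypeCheckResult`):

```scala
object MissingReturnSketch {
  sealed trait CheckResult
  case object Success extends CheckResult
  case class Failure(msg: String) extends CheckResult

  // Buggy shape: the Failure value is built inside the `if`, then dropped,
  // because the last expression in the method body is still `Success`.
  def checkBuggy(types: Seq[String]): CheckResult = {
    if (types.exists(t => t != "string" && t != "binary")) {
      Failure(s"unexpected input types: ${types.mkString(", ")}")
    }
    Success
  }

  // Fixed shape: the explicit `return` propagates the failure to the caller.
  def checkFixed(types: Seq[String]): CheckResult = {
    if (types.exists(t => t != "string" && t != "binary")) {
      return Failure(s"unexpected input types: ${types.mkString(", ")}")
    }
    Success
  }

  def main(args: Array[String]): Unit = {
    println(checkBuggy(Seq("int", "string")))  // Success -- the bug
    println(checkFixed(Seq("int", "string")))  // Failure(...) -- the fix
  }
}
```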
--- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0da55a4a961b..41dc762154a4c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -58,7 +58,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - TypeCheckResult.TypeCheckFailure( + return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have StringType or BinaryType, but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e27..97ddbeba2c5ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From 93f92c0ed7442a4382e97254307309977ff676f8 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 4 Jan 2018 11:39:42 -0800 Subject: [PATCH 568/712] [SPARK-21475][CORE][2ND ATTEMPT] Change to use NIO's Files API for external shuffle service ## What changes were proposed in this pull request? This PR is the second attempt of #18684 , NIO's Files API doesn't override `skip` method for `InputStream`, so it will bring in performance issue (mentioned in #20119). But using `FileInputStream`/`FileOutputStream` will also bring in memory issue (https://dzone.com/articles/fileinputstream-fileoutputstream-considered-harmful), which is severe for long running external shuffle service. So here in this proposal, only fixing the external shuffle service related code. ## How was this patch tested? Existing tests. Author: jerryshao Closes #20144 from jerryshao/SPARK-21475-v2. 
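The trade-off described above comes down to which factory opens the stream: `java.io.FileInputStream` registers a per-instance finalizer (costly for a long-running external shuffle service) but implements `skip` as a native seek, while streams obtained from `java.nio.file.Files` avoid the finalizer and fall back to read-and-discard skipping. A hedged sketch of the two styles, with illustrative method names rather than Spark APIs:

```scala
import java.io.{DataInputStream, File, FileInputStream}
import java.nio.channels.FileChannel
import java.nio.file.{Files, StandardOpenOption}

object ShuffleFileAccessSketch {
  // NIO-style open used on the shuffle-service path: no per-stream finalizer.
  def readIndexNio(indexFile: File, buf: Array[Byte]): Unit = {
    val dis = new DataInputStream(Files.newInputStream(indexFile.toPath))
    try dis.readFully(buf) finally dis.close()
  }

  // Legacy open kept on other read paths: FileInputStream.skip is a native
  // seek, which matters for streams that skip over large regions.
  def readIndexLegacy(indexFile: File, buf: Array[Byte]): Unit = {
    val dis = new DataInputStream(new FileInputStream(indexFile))
    try dis.readFully(buf) finally dis.close()
  }

  // Channel-based open used when handing a file segment to the network layer.
  def openChannel(file: File): FileChannel =
    FileChannel.open(file.toPath, StandardOpenOption.READ)
}
```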
--- .../apache/spark/network/buffer/FileSegmentManagedBuffer.java | 3 ++- .../apache/spark/network/shuffle/ShuffleIndexInformation.java | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..8b8f9892847c3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index eacf485344b76..386738ece51a6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -39,7 +39,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { From d2cddc88eac32f26b18ec26bb59e85c6f09a8c88 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:19:00 -0600 Subject: [PATCH 569/712] [SPARK-22850][CORE] Ensure queued events are delivered to all event queues. The code in LiveListenerBus was queueing events before start in the queues themselves; so in situations like the following: bus.post(someEvent) bus.addToEventLogQueue(listener) bus.start() "someEvent" would not be delivered to "listener" if that was the first listener in the queue, because the queue wouldn't exist when the event was posted. This change buffers the events before starting the bus in the bus itself, so that they can be delivered to all registered queues when the bus is started. Also tweaked the unit tests to cover the behavior above. Author: Marcelo Vanzin Closes #20039 from vanzin/SPARK-22850. 
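In other words, events posted before `start()` are now buffered by the bus itself and replayed into every queue that exists at start time, instead of being enqueued only into whatever queues happened to exist at post time. A stripped-down sketch of that buffer-then-replay pattern (toy names, no metrics, queues, or dispatch threads, unlike the real `LiveListenerBus`):

```scala
import scala.collection.mutable.ListBuffer

class TinyListenerBus[E] {
  private val listeners = ListBuffer.empty[E => Unit]
  private val queuedEvents = ListBuffer.empty[E]  // buffer owned by the bus itself
  private var started = false

  def addListener(l: E => Unit): Unit = synchronized { listeners += l }

  def post(event: E): Unit = synchronized {
    if (started) listeners.foreach(_(event))
    else queuedEvents += event  // buffered until start(), not dropped
  }

  def start(): Unit = synchronized {
    started = true
    // Replay pre-start events to every listener registered so far.
    queuedEvents.foreach(e => listeners.foreach(_(e)))
    queuedEvents.clear()
  }
}

object TinyListenerBusDemo {
  def main(args: Array[String]): Unit = {
    val bus = new TinyListenerBus[String]
    bus.post("someEvent")                     // posted before any listener exists
    bus.addListener(e => println(s"got $e"))  // registered after the post
    bus.start()                               // prints "got someEvent"
  }
}
```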
--- .../spark/scheduler/LiveListenerBus.scala | 45 ++++++++++++++++--- .../spark/scheduler/SparkListenerSuite.scala | 21 +++++---- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 23121402b1025..ba6387a8f08ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -62,6 +62,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + // Visible for testing. + @volatile private[scheduler] var queuedEvents = new mutable.ListBuffer[SparkListenerEvent]() + /** Add a listener to queue shared by all non-internal listeners. */ def addToSharedQueue(listener: SparkListenerInterface): Unit = { addToQueue(listener, SHARED_QUEUE) @@ -125,13 +128,39 @@ private[spark] class LiveListenerBus(conf: SparkConf) { /** Post an event to all queues. */ def post(event: SparkListenerEvent): Unit = { - if (!stopped.get()) { - metrics.numEventsPosted.inc() - val it = queues.iterator() - while (it.hasNext()) { - it.next().post(event) + if (stopped.get()) { + return + } + + metrics.numEventsPosted.inc() + + // If the event buffer is null, it means the bus has been started and we can avoid + // synchronization and post events directly to the queues. This should be the most + // common case during the life of the bus. + if (queuedEvents == null) { + postToQueues(event) + return + } + + // Otherwise, need to synchronize to check whether the bus is started, to make sure the thread + // calling start() picks up the new event. + synchronized { + if (!started.get()) { + queuedEvents += event + return } } + + // If the bus was already started when the check above was made, just post directly to the + // queues. + postToQueues(event) + } + + private def postToQueues(event: SparkListenerEvent): Unit = { + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } } /** @@ -149,7 +178,11 @@ private[spark] class LiveListenerBus(conf: SparkConf) { } this.sparkContext = sc - queues.asScala.foreach(_.start(sc)) + queues.asScala.foreach { q => + q.start(sc) + queuedEvents.foreach(q.post) + } + queuedEvents = null metricsSystem.registerSource(metrics) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1beb36afa95f0..da6ecb82c7e42 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -48,7 +48,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount } - private def queueSize(bus: LiveListenerBus): Int = { + private def sharedQueueSize(bus: LiveListenerBus): Int = { bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() .asInstanceOf[Int] } @@ -73,12 +73,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addToSharedQueue(counter) // Metrics are initially empty. 
assert(bus.metrics.numEventsPosted.getCount === 0) assert(numDroppedEvents(bus) === 0) - assert(queueSize(bus) === 0) + assert(bus.queuedEvents.size === 0) assert(eventProcessingTimeCount(bus) === 0) // Post five events: @@ -87,7 +86,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. assert(bus.metrics.numEventsPosted.getCount === 5) - assert(queueSize(bus) === 5) + assert(bus.queuedEvents.size === 5) + + // Add the counter to the bus after messages have been queued for later delivery. + bus.addToSharedQueue(counter) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -95,9 +97,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(eventProcessingTimeCount(bus) === 5) + // After the bus is started, there should be no more queued events. + assert(bus.queuedEvents === null) + // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -188,18 +193,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: From 95f9659abe8845f9f3f42fd7ababd79e55c52489 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 15:00:09 -0800 Subject: [PATCH 570/712] [SPARK-22948][K8S] Move SparkPodInitContainer to correct package. Author: Marcelo Vanzin Closes #20156 from vanzin/SPARK-22948. 
--- dev/sparktestsupport/modules.py | 2 +- .../spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala | 2 +- .../deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala | 2 +- .../docker/src/main/dockerfiles/init-container/Dockerfile | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala (99%) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala (98%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index f834563da9dda..7164180a6a7b0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -539,7 +539,7 @@ def __hash__(self): kubernetes = Module( name="kubernetes", dependencies=[], - source_file_regexes=["resource-managers/kubernetes/core"], + source_file_regexes=["resource-managers/kubernetes"], build_profile_flags=["-Pkubernetes"], sbt_test_goals=["kubernetes/test"] ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala similarity index 99% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala index 4a4b628aedbbf..c0f08786b76a1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.concurrent.TimeUnit diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala similarity index 98% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala index 6c557ec4a7c9a..e0f29ecd0fb53 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.UUID diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 055493188fcb7..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -21,4 +21,4 @@ FROM spark-base # command should be invoked from the top level directory of the Spark distribution. E.g.: # docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] From e288fc87a027ec1e1a21401d1f151df20dbfecf3 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 15:35:20 -0800 Subject: [PATCH 571/712] [SPARK-22953][K8S] Avoids adding duplicated secret volumes when init-container is used ## What changes were proposed in this pull request? User-specified secrets are mounted into both the main container and init-container (when it is used) in a Spark driver/executor pod, using the `MountSecretsBootstrap`. Because `MountSecretsBootstrap` always adds new secret volumes for the secrets to the pod, the same secret volumes get added twice, one when mounting the secrets to the main container, and the other when mounting the secrets to the init-container. This PR fixes the issue by separating `MountSecretsBootstrap.mountSecrets` out into two methods: `addSecretVolumes` for adding secret volumes to a pod and `mountSecrets` for mounting secret volumes to a container, respectively. `addSecretVolumes` is only called once for each pod, whereas `mountSecrets` is called individually for the main container and the init-container (if it is used). Ref: https://github.com/apache-spark-on-k8s/spark/issues/594. ## How was this patch tested? Unit tested and manually tested. vanzin This replaces https://github.com/apache/spark/pull/20148. hex108 foxish kimoonkim Author: Yinan Li Closes #20159 from liyinan926/master. 
--- .../deploy/k8s/MountSecretsBootstrap.scala | 30 ++++++++++++------- .../k8s/submit/DriverConfigOrchestrator.scala | 16 +++++----- .../steps/BasicDriverConfigurationStep.scala | 2 +- .../submit/steps/DriverMountSecretsStep.scala | 4 +-- .../InitContainerMountSecretsStep.scala | 11 +++---- .../cluster/k8s/ExecutorPodFactory.scala | 6 ++-- .../k8s/{submit => }/SecretVolumeUtils.scala | 18 +++++------ .../BasicDriverConfigurationStepSuite.scala | 4 +-- .../steps/DriverMountSecretsStepSuite.scala | 4 +-- .../InitContainerMountSecretsStepSuite.scala | 7 +---- .../cluster/k8s/ExecutorPodFactorySuite.scala | 14 +++++---- 11 files changed, 61 insertions(+), 55 deletions(-) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit => }/SecretVolumeUtils.scala (71%) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala index 8286546ce0641..c35e7db51d407 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -24,26 +24,36 @@ import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBui private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { /** - * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. * * @param pod the pod into which the secret volumes are being added. - * @param container the container into which the secret volumes are being mounted. - * @return the updated pod and container with the secrets mounted. + * @return the updated pod with the secret volumes added. */ - def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + def addSecretVolumes(pod: Pod): Pod = { var podBuilder = new PodBuilder(pod) secretNamesToMountPaths.keys.foreach { name => podBuilder = podBuilder .editOrNewSpec() .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() .endSpec() } + podBuilder.build() + } + + /** + * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the + * given container. + * + * @param container the container into which the secret volumes are being mounted. + * @return the updated container with the secrets mounted. 
+ */ + def mountSecrets(container: Container): Container = { var containerBuilder = new ContainerBuilder(container) secretNamesToMountPaths.foreach { case (name, path) => containerBuilder = containerBuilder @@ -53,7 +63,7 @@ private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, .endVolumeMount() } - (podBuilder.build(), containerBuilder.build()) + containerBuilder.build() } private def secretVolumeName(secretName: String): String = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 00c9c4ee49177..c9cc300d65569 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -127,6 +127,12 @@ private[spark] class DriverConfigOrchestrator( Nil } + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { val orchestrator = new InitContainerConfigOrchestrator( sparkJars, @@ -147,19 +153,13 @@ private[spark] class DriverConfigOrchestrator( Nil } - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - initContainerBootstrapStep ++ - mountSecretsStep + mountSecretsStep ++ + initContainerBootstrapStep } private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b7a69a7dfd472..eca46b84c6066 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -119,7 +119,7 @@ private[spark] class BasicDriverConfigurationStep( .endEnv() .addNewEnv() .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .withValue(appArgs.mkString(" ")) .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala index f872e0f4b65d1..91e9a9f211335 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -28,8 +28,8 @@ private[spark] class DriverMountSecretsStep( bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - 
val (pod, container) = bootstrap.mountSecrets( - driverSpec.driverPod, driverSpec.driverContainer) + val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) + val container = bootstrap.mountSecrets(driverSpec.driverContainer) driverSpec.copy( driverPod = pod, driverContainer = container diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala index c0e7bb20cce8c..0daa7b95e8aae 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -28,12 +28,9 @@ private[spark] class InitContainerMountSecretsStep( bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - val (driverPod, initContainer) = bootstrap.mountSecrets( - spec.driverPod, - spec.initContainer) - spec.copy( - driverPod = driverPod, - initContainer = initContainer - ) + // Mount the secret volumes given that the volumes have already been added to the driver pod + // when mounting the secrets into the main driver container. + val initContainer = bootstrap.mountSecrets(spec.initContainer) + spec.copy(initContainer = initContainer) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ba5d891f4c77e..066d7e9f70ca5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -214,7 +214,7 @@ private[spark] class ExecutorPodFactory( val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = mountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(executorPod, containerWithLimitCores) + (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) val (bootstrappedPod, bootstrappedContainer) = @@ -227,7 +227,9 @@ private[spark] class ExecutorPodFactory( val (pod, mayBeSecretsMountedInitContainer) = initContainerMountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + // Mount the secret volumes given that the volumes have already been added to the + // executor pod when mounting the secrets into the main executor container. 
+ (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) val bootstrappedPod = KubernetesUtils.appendInitContainer( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 71% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 8388c16ded268..16780584a674a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ @@ -22,15 +22,15 @@ import io.fabric8.kubernetes.api.model.{Container, Pod} private[spark] object SecretVolumeUtils { - def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { - driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def containerHasVolume( - driverContainer: Container, - volumeName: String, - mountPath: String): Boolean = { - driverContainer.getVolumeMounts.asScala.exists(volumeMount => - volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e864c6a16eeb1..8ee629ac8ddc1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -33,7 +33,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" @@ -82,7 +82,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" 
\"arg 3\"") + assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala index 9ec0cb55de5aa..960d0bda1d011 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DriverMountSecretsStepSuite extends SparkFunSuite { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala index eab4e17659456..7ac0bde80dfe6 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps.initcontainer import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} class InitContainerMountSecretsStepSuite extends SparkFunSuite { @@ -44,12 +43,8 @@ class InitContainerMountSecretsStepSuite extends SparkFunSuite { val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( baseInitContainerSpec) - - val podWithSecretsMounted = configuredInitContainerSpec.driverPod val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => assert(SecretVolumeUtils.containerHasVolume( initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7121a802c69c1..884da8aabd880 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -25,7 +25,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -165,17 +165,19 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val factory = new ExecutorPodFactory( conf, - None, + Some(secretsBootstrap), Some(initContainerBootstrap), Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getVolumes.size() === 1) + assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) assert(executor.getSpec.getInitContainers.size() === 1) - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) checkOwnerReferences(executor, driverPodUid) } From 0428368c2c5e135f99f62be20877bbbda43be310 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:34:56 -0800 Subject: [PATCH 572/712] [SPARK-22960][K8S] Make build-push-docker-images.sh more dev-friendly. - Make it possible to build images from a git clone. - Make it easy to use minikube to test things. Also fixed what seemed like a bug: the base image wasn't getting the tag provided in the command line. Adding the tag allows users to use multiple Spark builds in the same kubernetes cluster. Tested by deploying images on minikube and running spark-submit from a dev environment; also by building the images with different tags and verifying "docker images" in minikube. Author: Marcelo Vanzin Closes #20154 from vanzin/SPARK-22960. --- docs/running-on-kubernetes.md | 9 +- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 3 +- .../main/dockerfiles/spark-base/Dockerfile | 7 +- sbin/build-push-docker-images.sh | 120 +++++++++++++++--- 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e491329136a3c..2d69f636472ae 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -16,6 +16,9 @@ Kubernetes scheduler that has been added to Spark. you may setup a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. + * Be aware that the default minikube configuration is not enough for running Spark applications. + We recommend 3 CPUs and 4g of memory to be able to start a simple Spark application with a single + executor. 
* You must have appropriate permissions to list, create, edit and delete [pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources by running `kubectl auth can-i pods`. @@ -197,7 +200,7 @@ kubectl port-forward 4040:4040 Then, the Spark driver UI can be accessed on `http://localhost:4040`. -### Debugging +### Debugging There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the connection is refused for a different reason, the submission logic should indicate the error encountered. However, if there @@ -215,8 +218,8 @@ If the pod has encountered a runtime error, the status can be probed further usi kubectl logs ``` -Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark -application, includling all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of +Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark +application, including all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of the Spark application. ## Kubernetes Features diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 45fbcd9cd0deb..ff5289e10c21e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 0f806cf7e148e..3eabb42d4d852 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 047056ab2633b..e0a249e0ac71f 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. 
E.g.: diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 222e777db3a82..da1d6b9e161cc 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -17,6 +17,9 @@ FROM openjdk:8-alpine +ARG spark_jars +ARG img_path + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -34,11 +37,11 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index b3137598692d8..bb8806dd33f37 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -19,29 +19,94 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. -declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ - [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) +function error { + echo "$@" 1>&2 + exit 1 +} + +# Detect whether this is a git clone or a Spark distribution and adjust paths +# accordingly. +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ -f "$SPARK_HOME/RELEASE" ]; then + IMG_PATH="kubernetes/dockerfiles" + SPARK_JARS="jars" +else + IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" + SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." +fi + +declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ + [spark-executor]="$IMG_PATH/executor/Dockerfile" \ + [spark-init]="$IMG_PATH/init-container/Dockerfile" ) + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + local base_image="$(image_ref spark-base 0)" + docker build --build-arg "spark_jars=$SPARK_JARS" \ + --build-arg "img_path=$IMG_PATH" \ + -t "$base_image" \ + -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . done } - function push { for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} + docker push "$(image_ref $image)" done } function usage { - echo "This script must be run from a runnable distribution of Apache Spark." 
- echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" + cat </dev/null; then + error "Cannot find minikube." + fi + eval $(minikube docker-env) + ;; esac done -if [ -z "$REPO" ] || [ -z "$TAG" ]; then +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi + exit 1 + ;; +esac From df7fc3ef3899cadd252d2837092bebe3442d6523 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 Jan 2018 10:16:34 +0800 Subject: [PATCH 573/712] [SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt ## What changes were proposed in this pull request? 32bit Int was used for row rank. That overflowed in a dataframe with more than 2B rows. ## How was this patch tested? Added test, but ignored, as it takes 4 minutes. Author: Juliusz Sompolski Closes #20152 from juliuszsompolski/SPARK-22957. --- .../aggregate/ApproximatePercentile.scala | 12 ++++++------ .../spark/sql/catalyst/util/QuantileSummaries.scala | 8 ++++---- .../org/apache/spark/sql/DataFrameStatSuite.scala | 8 ++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 149ac265e6ed5..a45854a3b5146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -296,8 +296,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +312,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,8 +330,8 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6af..b013add9c9778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -105,7 +105,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * 
currentCount).toLong } val tuple = Stats(currentSample, 1, delta) @@ -192,10 +192,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +235,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2e..5169d2b5fc6b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. + ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() From 52fc5c17d9d784b846149771b398e741621c0b5c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 5 Jan 2018 14:02:21 +0800 Subject: [PATCH 574/712] [SPARK-22825][SQL] Fix incorrect results of Casting Array to String ## What changes were proposed in this pull request? This pr fixed the issue when casting arrays into strings; ``` scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids)) scala> df.write.saveAsTable("t") scala> sql("SELECT cast(ids as String) FROM t").show(false) +------------------------------------------------------------------+ |ids | +------------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df| +------------------------------------------------------------------+ ``` This pr modified the result into; ``` +------------------------------+ |ids | +------------------------------+ |[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]| +------------------------------+ ``` ## How was this patch tested? Added tests in `CastSuite` and `SQLQuerySuite`. Author: Takeshi Yamamuro Closes #20024 from maropu/SPARK-22825. 
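For reference, the element-joining rule that the new code path implements can be sketched in plain Scala (a standalone illustration only; the real implementation in `Cast.castToString` operates on `ArrayData`/`UTF8String` and has a codegen counterpart):

```
// Null elements contribute no text after the comma, so Array("ab", null, "c")
// renders as "[ab,, c]", matching the expectations added to CastSuite.
def arrayToString(arr: Seq[Any]): String = {
  val sb = new StringBuilder("[")
  arr.zipWithIndex.foreach { case (e, i) =>
    if (i > 0) sb.append(",")
    if (e != null) {
      if (i > 0) sb.append(" ")
      sb.append(e.toString)
    }
  }
  sb.append("]").toString
}

arrayToString(Seq(1, 2, 3, 4, 5))    // "[1, 2, 3, 4, 5]"
arrayToString(Seq("ab", null, "c"))  // "[ab,, c]"
```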
--- .../codegen/UTF8StringBuilder.java | 78 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 68 ++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 25 ++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 - 4 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 0000000000000..f0f66bae245fd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. + */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new byte[16]; + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? 
length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + } + } + + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer, 0, totalSize()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 274d8813f16db..d4fc5e0f168a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) 
=> + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a1..e3ed7171defd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade55..96bf65fce9c4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From cf0aa65576acbe0209c67f04c029058fd73555c1 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 4 Jan 2018 22:45:15 -0800 Subject: [PATCH 575/712] [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit ## What changes were proposed in this pull request? Avoid holding all models in memory for `TrainValidationSplit`. 
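A minimal sketch of the pattern being adopted (names and signatures here are illustrative, not the actual `TrainValidationSplit` code): fit and evaluate each candidate inside the same `Future`, so only the metric escapes and each fitted model becomes eligible for garbage collection as soon as its metric is computed, instead of all models being retained until every fit finishes.

```
import scala.concurrent.{ExecutionContext, Future}

// One Future per candidate parameter setting; the model is local to the Future body.
def evaluateCandidates[P, M](
    params: Seq[P],
    fit: P => M,
    evaluate: M => Double)(implicit ec: ExecutionContext): Seq[Future[Double]] =
  params.map { p =>
    Future {
      val model = fit(p)   // fitted model stays inside this closure
      evaluate(model)      // only the Double metric is retained
    }
  }
```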
## How was this patch tested? Existing tests. Author: Bago Amirbekian Closes #20143 from MrBago/trainValidMemoryFix. --- .../spark/ml/tuning/CrossValidator.scala | 4 +++- .../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 095b54c0fe83f..a0b507d2e718c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } (executionContext) } - // Wait for metrics to be calculated before unpersisting validation dataset + // Wait for metrics to be calculated val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + + // Unpersist training & validation set once all metrics have been produced trainingDataset.unpersist() validationDataset.unpersist() foldMetrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd18475475..8826ef3271bc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { subModels.get(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Wait for all metrics to be calculated val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) - // Unpersist validation set once all metrics have been produced + // Unpersist training & validation set once all metrics have been produced + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") From 6cff7d19f6a905fe425bd6892fe7ca014c0e696b Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 23:23:41 -0800 Subject: [PATCH 576/712] [SPARK-22757][K8S] Enable spark.jars and spark.files in KUBERNETES mode ## What changes were proposed in this pull request? We missed enabling `spark.files` and `spark.jars` in https://github.com/apache/spark/pull/19954. 
The result is that remote dependencies specified through `spark.files` or `spark.jars` are not included in the list of remote dependencies to be downloaded by the init-container. This PR fixes it. ## How was this patch tested? Manual tests. vanzin This replaces https://github.com/apache/spark/pull/20157. foxish Author: Yinan Li Closes #20160 from liyinan926/SPARK-22757. --- .../src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cbe1f2c3e08a1..1e381965c52ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -584,10 +584,11 @@ object SparkSubmit extends CommandLineUtils with Logging { confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, From 51c33bd0d402af9e0284c6cbc0111f926446bfba Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 5 Jan 2018 21:32:39 +0800 Subject: [PATCH 577/712] [SPARK-22961][REGRESSION] Constant columns should generate QueryPlanConstraints ## What changes were proposed in this pull request? #19201 introduced the following regression: given something like `df.withColumn("c", lit(2))`, we're no longer picking up `c === 2` as a constraint and infer filters from it when joins are involved, which may lead to noticeable performance degradation. This patch re-enables this optimization by picking up Aliases of Literals in Projection lists as constraints and making sure they're not treated as aliased columns. ## How was this patch tested? Unit test was added. Author: Adrian Ionescu Closes #20155 from adrian-ionescu/constant_constraints. 
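As an illustration of the scenario this fixes (the DataFrames and column names below are made up; assumes a `spark` session, e.g. in spark-shell):

```
import org.apache.spark.sql.functions.lit

// With this patch, the alias `c` again carries the constraint c = 2, so the
// optimizer can infer a matching filter (plus IsNotNull) on the other join side.
val df = spark.range(100).toDF("id").withColumn("c", lit(2))
val other = spark.range(100).toDF("x")
val joined = df.join(other, df("c") === other("x"))
joined.explain(true)  // the optimized plan should show an inferred filter on `x`
```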
--- .../sql/catalyst/plans/logical/LogicalPlan.scala | 2 ++ .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5e..ff2a0ec588567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -247,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38dea..9c0a30a47f839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan => // we may avoid producing recursive constraints. private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { - case a: Alias => (a.toAttribute, a.child) + case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec72..a0708bf7eee9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } } From c0b7424ecacb56d3e7a18acc11ba3d5e7be57c43 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 5 Jan 2018 09:58:28 -0800 Subject: [PATCH 578/712] [SPARK-22940][SQL] HiveExternalCatalogVersionsSuite should succeed on platforms that don't have wget ## What changes were proposed in this pull request? Modified HiveExternalCatalogVersionsSuite.scala to use Utils.doFetchFile to download different versions of Spark binaries rather than launching wget as an external process. On platforms that don't have wget installed, this suite fails with an error. cloud-fan : would you like to check this change? ## How was this patch tested? 1) test-only of HiveExternalCatalogVersionsSuite on several platforms. Tested bad mirror, read timeout, and redirects. 2) ./dev/run-tests Author: Bruce Robbins Closes #20147 from bersprockets/SPARK-22940-alt. --- .../HiveExternalCatalogVersionsSuite.scala | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a6761..ae4aeb7b4ce4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. 
If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. + conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll() From 930b90a84871e2504b57ed50efa7b8bb52d3ba44 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 5 Jan 2018 11:51:25 -0800 Subject: [PATCH 579/712] [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator ## What changes were proposed in this pull request? Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: https://github.com/apache/spark/pull/19527 or read below for what this PR includes: * configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF. * encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions. I also made some small style cleanups based on IntelliJ warnings. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #20132 from jkbradley/viirya-SPARK-13030. 
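For context, a minimal usage sketch of the estimator these cleanups touch (toy data and column names are illustrative; the setters shown correspond to the `inputCols`/`outputCols`, `handleInvalid` and `dropLast` params discussed in this change):

```
import org.apache.spark.ml.feature.OneHotEncoderEstimator

val df = spark.createDataFrame(Seq(
  (0.0, 1.0),
  (1.0, 0.0),
  (2.0, 1.0)
)).toDF("cat1", "cat2")

val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("cat1", "cat2"))
  .setOutputCols(Array("cat1_vec", "cat2_vec"))
  .setHandleInvalid("keep")  // only affects transform(); invalid data during fit() still errors
  .setDropLast(false)

val model = encoder.fit(df)
model.transform(df).show(false)
```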
--- .../ml/feature/OneHotEncoderEstimator.scala | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 074622d41e28d..bd1e3426c8780 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** - * Param for how to handle invalid data. + * Param for how to handle invalid data during transform(). * Options are 'keep' (invalid data presented as an extra categorical feature) or * 'error' (throw an error). + * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data " + + "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error).", + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid def getDropLast: Boolean = $(dropLast) protected def validateAndTransformSchema( - schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val existingFields = schema.fields require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat override def load(path: String): OneHotEncoderEstimator = super.load(path) } +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // Returns the category size for a given index with `dropLast` and `handleInvalid` + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account. 
- private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + private def getConfigedCategorySizes: Array[Int] = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - orgCategorySize + 1 + categorySizes.map(_ + 1) } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. - orgCategorySize - 1 + categorySizes.map(_ - 1) } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - orgCategorySize + categorySizes } } private def encoder: UserDefinedFunction = { - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val dropLast = getDropLast - val handleInvalid = getHandleInvalid - val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index of input. - udf { (label: Double, idx: Int) => - val plainNumCategories = categorySizes(idx) - val size = configedCategorySize(plainNumCategories, idx) - - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label == size && dropLast && !keepInvalid) { - // When `dropLast` is true and `handleInvalid` is not "keep", - // the last category is removed. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (label >= plainNumCategories && keepInvalid) { - // When `handleInvalid` is "keep", encodes invalid data to last category (and removed - // if `dropLast` is true) - if (dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize } else { - Vectors.sparse(size, Array(size - 1), oneValue) + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoderEstimator.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } } - } else if (label < plainNumCategories) { - Vectors.sparse(size, Array(label.toInt), oneValue) + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) } else { - assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) - throw new SparkException(s"Unseen value: $label. 
To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) } } } @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) - val outputColNames = $(outputCols) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] ( * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. if (attrGroup.attributes.nonEmpty) { - val numCategories = configedCategorySize(categorySizes(idx), idx) + val numCategories = configedSizes(idx) require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column ${inputColName}, " + + s"$numCategories categorical values for input column $inputColName, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] ( val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val encodedColumns = (0 until $(inputCols).length).map { idx => + val encodedColumns = $(inputCols).indices.map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) From ea956833017fcbd8ed2288368bfa2e417a2251c5 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Fri, 5 Jan 2018 17:25:28 -0800 Subject: [PATCH 580/712] [SPARK-22914][DEPLOY] Register history.ui.port ## What changes were proposed in this pull request? Register spark.history.ui.port as a known spark conf to be used in substitution expressions even if it's not set explicitly. ## How was this patch tested? Added unit test to demonstrate the issue Author: Gera Shegalov Author: Gera Shegalov Closes #20098 from gerashegalov/gera/register-SHS-port-conf. 
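For reference, a minimal sketch of the substitution this change enables (the property names and values below mirror the new `ApplicationMasterSuite` test added in this patch; a real deployment would normally set the address in spark-defaults.conf instead):

```scala
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.spark.SparkConf

// spark.history.ui.port is deliberately left unset; with this change it is still a
// registered conf, so ${spark.history.ui.port} resolves to its default, 18080.
val sparkConf = new SparkConf()
  .set("spark.yarn.historyServer.address",
    "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}")

val yarnConf = new YarnConfiguration()
yarnConf.set("yarn.resourcemanager.hostname", "rm.host.com")

// Given appId "application_123_1" and attemptId "application_123_1_1", the resolved
// history URL is expected to be:
//   http://rm.host.com:18080/history/application_123_1/application_123_1_1
```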
--- .../spark/deploy/history/HistoryServer.scala | 3 +- .../apache/spark/deploy/history/config.scala | 5 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 17 +++++--- .../deploy/yarn/ApplicationMasterSuite.scala | 43 +++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 75484f5c9f30f..0ec4afad0308c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} @@ -276,7 +277,7 @@ object HistoryServer extends Logging { .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.getInt("spark.history.ui.port", 18080) + val port = conf.get(HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index 22b6d49d8e2a4..efdbf672bb52f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -44,4 +44,9 @@ private[spark] object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") + .doc("Web UI port to bind Spark History Server") + .intConf + .createWithDefault(18080) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d72633..4d5e3bb043671 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -427,11 +427,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -834,6 +831,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => 
SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 0000000000000..695a82f3583e6 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} From e8af7e8aeca15a6107248f358d9514521ffdc6d3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 6 Jan 2018 09:26:03 +0800 Subject: [PATCH 581/712] [SPARK-22937][SQL] SQL elt output binary for binary inputs ## What changes were proposed in this pull request? This pr modified `elt` to output binary for binary inputs. `elt` in the current master always output data as a string. But, in some databases (e.g., MySQL), if all inputs are binary, `elt` also outputs binary (Also, this might be a small surprise). This pr is related to #19977. ## How was this patch tested? Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`. Author: Takeshi Yamamuro Closes #20135 from maropu/SPARK-22937. 
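As a rough illustration of the new behaviour (a sketch only, assuming a running `SparkSession` named `spark` and `spark.sql.function.eltOutputAsString` left at its default of false):

```scala
// All inputs are binary, so `elt` now yields BinaryType instead of StringType.
val binaryDF = spark.sql(
  "SELECT elt(1, encode('a', 'utf-8'), encode('b', 'utf-8')) AS col")
// binaryDF.schema("col").dataType == BinaryType

// Mixed inputs still yield StringType, as before.
val stringDF = spark.sql("SELECT elt(2, 'a', encode('b', 'utf-8')) AS col")
// stringDF.schema("col").dataType == StringType

// Setting the conf to true restores the old always-string behaviour.
spark.conf.set("spark.sql.function.eltOutputAsString", "true")
```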
--- docs/sql-programming-guide.md | 2 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 +++++ .../expressions/stringExpressions.scala | 46 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ .../catalyst/analysis/TypeCoercionSuite.scala | 54 ++++++++ .../inputs/typeCoercion/native/elt.sql | 44 +++++++ .../results/typeCoercion/native/elt.sql.out | 115 ++++++++++++++++++ 7 files changed, 281 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index dc3e384008d27..b50f9360b866c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1783,6 +1783,8 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e9436367c7e2e..e8669c4637d06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -684,6 +685,34 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. 
+ */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 41dc762154a4c..e004bfc6af473 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. 
*/ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -307,27 +319,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -335,7 +347,7 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression]) |${index.code} |final int $indexVal = ${index.value}; |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5d6edf6b8abec..80b8965e084a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1052,6 +1052,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. 
Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3661530cd622b..52a7ebdafd7c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use 
something more than a literal to avoid triggering the simplification rules. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 0000000000000..717616f91db05 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 0000000000000..b62e1b6826045 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 From bf65cd3cda46d5480bfcd13110975c46ca631972 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 5 Jan 2018 17:29:27 -0800 Subject: [PATCH 582/712] [SPARK-22960][K8S] Revert use of ARG base_image in images ## What changes were proposed in this pull request? 
This PR reverts the `ARG base_image` before `FROM` in the images of driver, executor, and init-container, introduced in https://github.com/apache/spark/pull/20154. The reason is Docker versions before 17.06 do not support this use (`ARG` before `FROM`). ## How was this patch tested? Tested manually. vanzin foxish kimoonkim Author: Yinan Li Closes #20170 from liyinan926/master. --- .../docker/src/main/dockerfiles/driver/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/executor/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/init-container/Dockerfile | 3 +-- sbin/build-push-docker-images.sh | 8 ++++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index ff5289e10c21e..45fbcd9cd0deb 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 3eabb42d4d852..0f806cf7e148e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index e0a249e0ac71f..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index bb8806dd33f37..b9532597419a5 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -60,13 +60,13 @@ function image_ref { } function build { - local base_image="$(image_ref spark-base 0)" - docker build --build-arg "spark_jars=$SPARK_JARS" \ + docker build \ + --build-arg "spark_jars=$SPARK_JARS" \ --build-arg "img_path=$IMG_PATH" \ - -t "$base_image" \ + -t spark-base \ -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . + docker build -t "$(image_ref $image)" -f ${path[$image]} . 
done } From f2dd8b923759e8771b0e5f59bfa7ae4ad7e6a339 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 6 Jan 2018 16:11:20 +0800 Subject: [PATCH 583/712] [SPARK-22930][PYTHON][SQL] Improve the description of Vectorized UDFs for non-deterministic cases ## What changes were proposed in this pull request? Add tests for using non deterministic UDFs in aggregate. Update pandas_udf docstring w.r.t to determinism. ## How was this patch tested? test_nondeterministic_udf_in_aggregate Author: Li Jin Closes #20142 from icexelloss/SPARK-22930-pandas-udf-deterministic. --- python/pyspark/sql/functions.py | 12 +++++++- python/pyspark/sql/tests.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4ed562ad48b4..733e32bd825b0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2214,7 +2214,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. 
If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dc767f9ec46e..689736d8e6456 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -386,6 +386,7 @@ def test_udf3(self): self.assertEqual(row[0], 5) def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() @@ -413,6 +414,18 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -3567,6 +3580,18 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def random_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3950,6 +3975,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf, pandas_udf, col + + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.random_udf + + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.random_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): From be9a804f2ef77a5044d3da7d9374976daf59fc16 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sat, 6 Jan 2018 18:07:45 +0800 Subject: [PATCH 584/712] [SPARK-22793][SQL] Memory leak in Spark Thrift Server # What changes were proposed in this pull request? 1. Start HiveThriftServer2. 2. Connect to thriftserver through beeline. 3. Close the beeline. 4. repeat step2 and step 3 for many times. 
we found that many directories under the paths `hive.exec.local.scratchdir` and `hive.exec.scratchdir` are never dropped, even though, as we know, the scratchdir is added to deleteOnExit when it is created. This means that the size of the FileSystem `deleteOnExit` cache keeps increasing until the JVM terminates.

In addition, using `jmap -histo:live [PID]` to print the sizes of objects in the HiveThriftServer2 process, we can see that the counts of `org.apache.spark.sql.hive.client.HiveClientImpl` and `org.apache.hadoop.hive.ql.session.SessionState` keep increasing even though all the beeline connections have been closed, which may cause a memory leak.

# How was this patch tested?
Manual tests.

This PR follows up https://github.com/apache/spark/pull/19989

Author: zuotingbing

Closes #20029 from zuotingbing/SPARK-22793.

---
 .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 92cb4ef11c9e3..dc92ad3b0c1ac 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -42,7 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
    * Create a Hive aware resource loader.
    */
   override protected lazy val resourceLoader: HiveSessionResourceLoader = {
-    val client: HiveClient = externalCatalog.client.newSession()
+    val client: HiveClient = externalCatalog.client
     new HiveSessionResourceLoader(session, client)
   }

From 7b78041423b6ee330def2336dfd1ff9ae8469c59 Mon Sep 17 00:00:00 2001
From: fjh100456
Date: Sat, 6 Jan 2018 18:19:57 +0800
Subject: [PATCH 585/712] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered.

## What changes were proposed in this pull request?

Since Hive 1.1, Hive allows users to set the parquet compression codec via the table-level property parquet.compression. See the JIRA: https://issues.apache.org/jira/browse/HIVE-7858 . We already support orc.compression for ORC, so for external users it is more straightforward to support both. See the Stack Overflow question: https://stackoverflow.com/questions/36941122/spark-sql-ignores-parquet-compression-propertie-specified-in-tblproperties

On the Spark side, the table-level compression option `compression` was added by #11464 in Spark 2.0, so we need to support both table-level options. Users might also use the session-level conf spark.sql.parquet.compression.codec. The priority rule is: if other compression codec configuration is found through Hive or Parquet, the precedence is `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, snappy, gzip, lzo. After this change, the rule for Parquet is consistent with the one for ORC.

Changes:
1. Acquire 'compressionCodecClassName' from `parquet.compression` as well; the precedence order is `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`.
2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". 3.Change `compressionCode` to `compressionCodecClassName`. ## How was this patch tested? Add test. Author: fjh100456 Closes #20076 from fjh100456/ParquetOptionIssue. --- docs/sql-programming-guide.md | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 14 +- .../datasources/parquet/ParquetOptions.scala | 12 +- ...rquetCompressionCodecPrecedenceSuite.scala | 122 ++++++++++++++++++ 4 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b50f9360b866c..3ccaaf4d5b1fa 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,8 +953,10 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80b8965e084a2..7d1217de254a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,11 +325,13 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + + "`parquet.compression` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -366,8 +368,10 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + + "`orc.compress` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de548..ef67ea7d17cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. 
+ val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 0000000000000..ed8fd2b453456 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. 
+ withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", "GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} From 993f21567a1dd33e43ef9a626e0ddfbe46f83f93 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 6 Jan 2018 23:08:26 +0800 Subject: [PATCH 586/712] [SPARK-22901][PYTHON][FOLLOWUP] Adds the doc for asNondeterministic for wrapped UDF function ## What changes were proposed in this pull request? This PR wraps the `asNondeterministic` attribute in the wrapped UDF function to set the docstring properly. ```python from pyspark.sql.functions import udf help(udf(lambda x: x).asNondeterministic) ``` Before: ``` Help on function in module pyspark.sql.udf: lambda (END ``` After: ``` Help on function asNondeterministic in module pyspark.sql.udf: asNondeterministic() Updates UserDefinedFunction to nondeterministic. .. versionadded:: 2.3 (END) ``` ## How was this patch tested? Manually tested and a simple test was added. Author: hyukjinkwon Closes #20173 from HyukjinKwon/SPARK-22901-followup. 
--- python/pyspark/sql/tests.py | 1 + python/pyspark/sql/udf.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 689736d8e6456..122a65b83aef9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -413,6 +413,7 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e75eb6545333..5e80ab9165867 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -169,8 +169,8 @@ def wrapper(*args): wrapper.returnType = self.returnType wrapper.evalType = self.evalType wrapper.deterministic = self.deterministic - wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() - + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper def asNondeterministic(self): From 9a7048b2889bd0fd66e68a0ce3e07e466315a051 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 7 Jan 2018 00:19:21 +0800 Subject: [PATCH 587/712] [HOTFIX] Fix style checking failure ## What changes were proposed in this pull request? This PR is to fix the style checking failure. ## How was this patch tested? N/A Author: gatorsmile Closes #20175 from gatorsmile/stylefix. --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7d1217de254a2..5c61f10bb71ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,10 +325,11 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + - "`parquet.compression` is specified in the table-specific options/properties, the precedence" + - "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + - "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) @@ -368,8 +369,8 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + - "`orc.compress` is specified in the table-specific options/properties, the precedence" + + .doc("Sets the compression codec used when writing ORC files. 
If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf From 18e94149992618a2b4e6f0fd3b3f4594e1745224 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 7 Jan 2018 13:42:01 +0800 Subject: [PATCH 588/712] [SPARK-22973][SQL] Fix incorrect results of Casting Map to String ## What changes were proposed in this pull request? This pr fixed the issue when casting maps into strings; ``` scala> Seq(Map(1 -> "a", 2 -> "b")).toDF("a").write.saveAsTable("t") scala> sql("SELECT cast(a as String) FROM t").show(false) +----------------------------------------------------------------+ |a | +----------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeMapData38bdd75d| +----------------------------------------------------------------+ ``` This pr modified the result into; ``` +----------------+ |a | +----------------+ |[1 -> a, 2 -> b]| +----------------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20166 from maropu/SPARK-22973. --- .../spark/sql/catalyst/expressions/Cast.scala | 89 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 28 ++++++ 2 files changed, 117 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d4fc5e0f168a7..f2de4c8e30bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -228,6 +228,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -654,6 +685,53 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", 
"dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc( + | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -676,6 +754,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case MapType(kt, vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e3ed7171defd8..1445bb8a97d40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -878,4 +878,32 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } } From 71d65a32158a55285be197bec4e41fedc9225b94 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 11:39:45 +0800 Subject: [PATCH 589/712] [SPARK-22985] Fix argument escaping bug in from_utc_timestamp / to_utc_timestamp 
codegen ## What changes were proposed in this pull request? This patch adds additional escaping in `from_utc_timestamp` / `to_utc_timestamp` expression codegen in order to fix a bug where invalid timezones which contain special characters could cause generated code to fail to compile. ## How was this patch tested? New regression tests in `DateExpressionsSuite`. Author: Josh Rosen Closes #20182 from JoshRosen/SPARK-22985-fix-utc-timezone-function-escaping-bugs. --- .../catalyst/expressions/datetimeExpressions.scala | 12 ++++++++---- .../catalyst/expressions/DateExpressionsSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7a674ea7f4d76..424871f2047e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -23,6 +23,8 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -1008,7 +1010,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1017,8 +1019,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") @@ -1185,7 +1188,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1194,8 +1197,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 63f6ceeb21b96..786266a2c13c0 100644 ---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -791,6 +792,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate( + ToUTCTimestamp(Literal(Timestamp.valueOf("2015-07-24 00:00:00")), Literal("\"quote")) :: Nil) } test("from_utc_timestamp") { @@ -811,5 +815,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate(FromUTCTimestamp(Literal(0), Literal("\"quote")) :: Nil) } } From 3e40eb3f1ffac3d2f49459a801e3ce171ed34091 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Mon, 8 Jan 2018 14:32:05 +0900 Subject: [PATCH 590/712] [SPARK-22566][PYTHON] Better error message for `_merge_type` in Pandas to Spark DF conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? It provides a better error message when `spark_session.createDataFrame(pandas_df)` is called with no schema and schema inference fails due to incompatible types. The Pandas column names are propagated down and the error message mentions which column had the merging error. https://issues.apache.org/jira/browse/SPARK-22566 ## How was this patch tested? Manually in the `./bin/pyspark` console, and with new tests: `./python/run-tests` I state that the contribution is my original work and that I license the work to the Apache Spark project under the project’s open source license. Author: Guilherme Berger Closes #19792 from gberger/master. --- python/pyspark/sql/session.py | 17 +++--- python/pyspark/sql/tests.py | 100 ++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 28 +++++++--- 3 files changed, 129 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6e5eec48e8aca..6052fa9e84096 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -325,11 +325,12 @@ def range(self, start, end=None, step=1, numPartitions=None): return DataFrame(jdf, self._wrapped) - def _inferSchemaFromList(self, data): + def _inferSchemaFromList(self, data, names=None): """ Infer schema from list of Row or tuple.
:param data: list of Row or tuple + :param names: list of column names :return: :class:`pyspark.sql.types.StructType` """ if not data: @@ -338,12 +339,12 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, map(_infer_schema, data)) + schema = reduce(_merge_type, (_infer_schema(row, names) for row in data)) if _has_nulltype(schema): raise ValueError("Some of types cannot be determined after inferring") return schema - def _inferSchema(self, rdd, samplingRatio=None): + def _inferSchema(self, rdd, samplingRatio=None, names=None): """ Infer schema from an RDD of Row or tuple. @@ -360,10 +361,10 @@ def _inferSchema(self, rdd, samplingRatio=None): "Use pyspark.sql.Row instead") if samplingRatio is None: - schema = _infer_schema(first) + schema = _infer_schema(first, names=names) if _has_nulltype(schema): for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) + schema = _merge_type(schema, _infer_schema(row, names=names)) if not _has_nulltype(schema): break else: @@ -372,7 +373,7 @@ def _inferSchema(self, rdd, samplingRatio=None): else: if samplingRatio < 0.99: rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) + schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type) return schema def _createFromRDD(self, rdd, schema, samplingRatio): @@ -380,7 +381,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchema(rdd, samplingRatio) + struct = self._inferSchema(rdd, samplingRatio, names=schema) converter = _create_converter(struct) rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): @@ -406,7 +407,7 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data) + struct = self._inferSchemaFromList(data, names=schema) converter = _create_converter(struct) data = map(converter, data) if isinstance(schema, (list, tuple)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 122a65b83aef9..13576ff57001b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,6 +68,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings +from pyspark.sql.types import _merge_type from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -898,6 +899,15 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = 
self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), @@ -918,6 +928,10 @@ def test_infer_nested_schema(self): df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] df = self.spark.createDataFrame(data) @@ -1772,6 +1786,92 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'element in array'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'key of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'value of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'field f1'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): + _merge_type( + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())])) + with 
self.assertRaisesRegexp(TypeError, 'value of map field f1'): + _merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + def test_filter_with_datetime(self): time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) date = time.date() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 146e673ae9756..0dc5823f72a3c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1073,7 +1073,7 @@ def _infer_type(obj): raise TypeError("not supported type: %s" % type(obj)) -def _infer_schema(row): +def _infer_schema(row, names=None): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) @@ -1084,7 +1084,10 @@ def _infer_schema(row): elif hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) else: - names = ['_%d' % i for i in range(1, len(row) + 1)] + if names is None: + names = ['_%d' % i for i in range(1, len(row) + 1)] + elif len(names) < len(row): + names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1)) items = zip(names, row) elif hasattr(row, "__dict__"): # object @@ -1109,19 +1112,27 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b): +def _merge_type(a, b, name=None): + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) + if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b)))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), + name=new_name(f.name))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: @@ -1130,11 +1141,12 @@ def _merge_type(a, b): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) + return ArrayType(_merge_type(a.elementType, b.elementType, + name='element in array %s' % name), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), + return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name), + _merge_type(a.valueType, b.valueType, name='value of map %s' % name), True) else: return a From 8fdeb4b9946bd9be045abb919da2e531708b3bd4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 8 Jan 2018 13:59:08 +0800 Subject: [PATCH 591/712] 
[SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava) ## What changes were proposed in this pull request? It seems we can avoid type dispatch for each value when converting Java objects (from Pyrolite) to Spark's internal data format, because we know the schema ahead of time. I manually performed the benchmark as below: ```scala test("EvaluatePython.fromJava / EvaluatePython.makeFromJava") { val numRows = 1000 * 1000 val numFields = 30 val random = new Random(System.nanoTime()) val types = Array( BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType, DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal, DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2), new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType) val schema = RandomDataGenerator.randomSchema(random, numFields, types) val rows = mutable.ArrayBuffer.empty[Array[Any]] var i = 0 while (i < numRows) { val row = RandomDataGenerator.randomRow(random, schema) rows += row.toSeq.toArray i += 1 } val benchmark = new Benchmark("EvaluatePython.fromJava / EvaluatePython.makeFromJava", numRows) benchmark.addCase("Before - EvaluatePython.fromJava", 3) { _ => var i = 0 while (i < numRows) { EvaluatePython.fromJava(rows(i), schema) i += 1 } } benchmark.addCase("After - EvaluatePython.makeFromJava", 3) { _ => val fromJava = EvaluatePython.makeFromJava(schema) var i = 0 while (i < numRows) { fromJava(rows(i)) i += 1 } } benchmark.run() } ``` ``` EvaluatePython.fromJava / EvaluatePython.makeFromJava: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Before - EvaluatePython.fromJava 1265 / 1346 0.8 1264.8 1.0X After - EvaluatePython.makeFromJava 571 / 649 1.8 570.8 2.2X ``` If the structure is nested, I think the advantage should be larger than this. ## How was this patch tested? Existing tests should cover this. Also, I manually checked if the values from before / after are actually the same via `assert` when performing the benchmarks. Author: hyukjinkwon Closes #20172 from HyukjinKwon/type-dispatch-python-eval.
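To make the pattern concrete, here is a minimal standalone sketch of the idea: derive a converter closure from the schema once, then reuse it for every record, instead of re-matching on (value, dataType) for each value. The toy `DataType` model and the object/method names below are invented for illustration only; they are not the actual Spark or `EvaluatePython` API.

```scala
// Sketch only: simplified stand-in types, not Spark classes.
object SchemaDrivenConversionSketch {
  sealed trait DataType
  case object IntType extends DataType
  case object StringType extends DataType
  case class StructType(fields: Seq[DataType]) extends DataType

  // "Before" shape: every single value pays for a match on its data type.
  def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
    case (null, _) => null
    case (i: Int, IntType) => i
    case (s: String, StringType) => s
    case (arr: Array[_], StructType(fields)) =>
      arr.toSeq.zip(fields).map { case (e, t) => fromJava(e, t) }
    case _ => null // unexpected types fall back to null
  }

  // "After" shape: walk the schema once, build one closure per field, reuse them.
  def makeFromJava(dataType: DataType): Any => Any = dataType match {
    case IntType => {
      case i: Int => i
      case _ => null
    }
    case StringType => {
      case s: String => s
      case _ => null
    }
    case StructType(fields) =>
      val fieldConverters = fields.map(makeFromJava)
      (obj: Any) => obj match {
        case arr: Array[_] if arr.length == fieldConverters.length =>
          arr.toSeq.zip(fieldConverters).map { case (e, convert) => convert(e) }
        case _ => null
      }
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(IntType, StringType))
    val rows: Seq[Array[Any]] = Seq(Array(1, "a"), Array(2, "b"))
    val convert = makeFromJava(schema) // built once, not per record
    rows.map(convert).foreach(println) // prints the two converted rows
  }
}
```

In the actual diff below, `makeFromJava` plays the role of the schema walk: it is called once outside the row loop (per partition in `applySchemaToPythonRDD`, and once per operator in `BatchEvalPythonExec`), and the returned function is applied to each row.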
--- .../org/apache/spark/sql/SparkSession.scala | 5 +- .../python/BatchEvalPythonExec.scala | 7 +- .../sql/execution/python/EvaluatePython.scala | 166 ++++++++++++------ 3 files changed, 118 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 272eb844226d4..734573ba31f71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -742,7 +742,10 @@ class SparkSession private( private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.mapPartitions { iter => + val fromJava = python.EvaluatePython.makeFromJava(schema) + iter.map(r => fromJava(r).asInstanceOf[InternalRow]) + } internalCreateDataFrame(rowRdd, schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 26ee25f633ea4..f4d83e8dc7c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi } else { StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) } + + val fromJava = EvaluatePython.makeFromJava(resultType) + outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => if (udfs.length == 1) { // fast path for single UDF - mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow(0) = fromJava(result) mutableRow } else { - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + fromJava(result).asInstanceOf[InternalRow] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 9bbfa6018ba77..520afad287648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -83,82 +83,134 @@ object EvaluatePython { } /** - * Converts `obj` to the type specified by the data type, or returns null if the type of obj is - * unexpected. Because Python doesn't enforce the type. + * Make a converter that converts `obj` to the type specified by the data type, or returns + * null if the type of obj is unexpected. Because Python doesn't enforce the type. 
*/ - def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (c: Boolean, BooleanType) => c + def makeFromJava(dataType: DataType): Any => Any = dataType match { + case BooleanType => (obj: Any) => nullSafeConvert(obj) { + case b: Boolean => b + } - case (c: Byte, ByteType) => c - case (c: Short, ByteType) => c.toByte - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte + case ByteType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c + case c: Short => c.toByte + case c: Int => c.toByte + case c: Long => c.toByte + } - case (c: Byte, ShortType) => c.toShort - case (c: Short, ShortType) => c - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort + case ShortType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toShort + case c: Short => c + case c: Int => c.toShort + case c: Long => c.toShort + } - case (c: Byte, IntegerType) => c.toInt - case (c: Short, IntegerType) => c.toInt - case (c: Int, IntegerType) => c - case (c: Long, IntegerType) => c.toInt + case IntegerType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toInt + case c: Short => c.toInt + case c: Int => c + case c: Long => c.toInt + } - case (c: Byte, LongType) => c.toLong - case (c: Short, LongType) => c.toLong - case (c: Int, LongType) => c.toLong - case (c: Long, LongType) => c + case LongType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toLong + case c: Short => c.toLong + case c: Int => c.toLong + case c: Long => c + } - case (c: Float, FloatType) => c - case (c: Double, FloatType) => c.toFloat + case FloatType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c + case c: Double => c.toFloat + } - case (c: Float, DoubleType) => c.toDouble - case (c: Double, DoubleType) => c + case DoubleType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c.toDouble + case c: Double => c + } - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) + case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) { + case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale) + } - case (c: Int, DateType) => c + case DateType => (obj: Any) => nullSafeConvert(obj) { + case c: Int => c + } - case (c: Long, TimestampType) => c - // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs - case (c: Int, TimestampType) => c.toLong + case TimestampType => (obj: Any) => nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + } - case (c, StringType) => UTF8String.fromString(c.toString) + case StringType => (obj: Any) => nullSafeConvert(obj) { + case _ => UTF8String.fromString(obj.toString) + } - case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) - case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case BinaryType => (obj: Any) => nullSafeConvert(obj) { + case c: String => c.getBytes(StandardCharsets.UTF_8) + case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + } - case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) + case ArrayType(elementType, _) => + val elementFromJava = makeFromJava(elementType) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) + (obj: Any) => 
nullSafeConvert(obj) { + case c: java.util.List[_] => + new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray) + case c if c.getClass.isArray => + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e))) + } - case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => - ArrayBasedMapData( - javaMap, - (key: Any) => fromJava(key, keyType), - (value: Any) => fromJava(value, valueType)) + case MapType(keyType, valueType, _) => + val keyFromJava = makeFromJava(keyType) + val valueFromJava = makeFromJava(valueType) + + (obj: Any) => nullSafeConvert(obj) { + case javaMap: java.util.Map[_, _] => + ArrayBasedMapData( + javaMap, + (key: Any) => keyFromJava(key), + (value: Any) => valueFromJava(value)) + } - case (c, StructType(fields)) if c.getClass.isArray => - val array = c.asInstanceOf[Array[_]] - if (array.length != fields.length) { - throw new IllegalStateException( - s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." - ) + case StructType(fields) => + val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray + + (obj: Any) => nullSafeConvert(obj) { + case c if c.getClass.isArray => + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + + val row = new GenericInternalRow(fields.length) + var i = 0 + while (i < fields.length) { + row(i) = fieldsFromJava(i)(array(i)) + i += 1 + } + row } - new GenericInternalRow(array.zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) + case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) + + case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty) + } - // all other unexpected type should be null, or we will have runtime exception - // TODO(davies): we could improve this by try to cast the object to expected type - case (c, _) => null + private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, { + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + _: Any => null + }) + } } private val module = "pyspark.sql.types" From 2c73d2a948bdde798aaf0f87c18846281deb05fd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 16:04:03 +0800 Subject: [PATCH 592/712] [SPARK-22983] Don't push filters beneath aggregates with empty grouping expressions ## What changes were proposed in this pull request? The following SQL query should return zero rows, but in Spark it actually returns one row: ``` SELECT 1 from ( SELECT 1 AS z, MIN(a.x) FROM (select 1 as x) a WHERE false ) b where b.z != b.z ``` The problem stems from the `PushDownPredicate` rule: when this rule encounters a filter on top of an Aggregate operator, e.g. `Filter(Agg(...))`, it removes the original filter and adds a new filter onto Aggregate's child, e.g. `Agg(Filter(...))`. 
This is sometimes okay, but the case above is a counterexample: because there is no explicit `GROUP BY`, we are implicitly computing a global aggregate over the entire table so the original filter was not acting like a `HAVING` clause filtering the number of groups: if we push this filter then it fails to actually reduce the cardinality of the Aggregate output, leading to the wrong answer. In 2016 I fixed a similar problem involving invalid pushdowns of data-independent filters (filters which reference no columns of the filtered relation). There was additional discussion after my fix was merged which pointed out that my patch was an incomplete fix (see #15289), but it looks I must have either misunderstood the comment or forgot to follow up on the additional points raised there. This patch fixes the problem by choosing to never push down filters in cases where there are no grouping expressions. Since there are no grouping keys, the only columns are aggregate columns and we can't push filters defined over aggregate results, so this change won't cause us to miss out on any legitimate pushdown opportunities. ## How was this patch tested? New regression tests in `SQLQueryTestSuite` and `FilterPushdownSuite`. Author: Josh Rosen Closes #20180 from JoshRosen/SPARK-22983-dont-push-filters-beneath-aggs-with-empty-grouping-expressions. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../catalyst/optimizer/FilterPushdownSuite.scala | 13 +++++++++++++ .../test/resources/sql-tests/inputs/group-by.sql | 9 +++++++++ .../resources/sql-tests/results/group-by.sql.out | 16 +++++++++++++++- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d4b02c6e7d8a..df0af8264a329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -795,7 +795,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) case filter @ Filter(condition, aggregate: Aggregate) - if aggregate.aggregateExpressions.forall(_.deterministic) => + if aggregate.aggregateExpressions.forall(_.deterministic) + && aggregate.groupingExpressions.nonEmpty => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 85a5e979f6021..82a10254d846d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -809,6 +809,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("aggregate: don't push filters if the aggregate has no grouping expressions") { + val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) + .select('a, 'b) + .groupBy()(count(1)) + .where(false) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = 
originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + test("broadcast hint") { val originalQuery = ResolvedHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 1e1384549a410..c5070b734d521 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -60,3 +60,12 @@ SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; -- Aggregate with empty input and empty GroupBy expressions. SELECT COUNT(1) FROM testData WHERE false; SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; + +-- Aggregate with empty GroupBy expressions and filter on top +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 986bb01c13fe4..c1abc6dff754b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 26 -- !query 0 @@ -227,3 +227,17 @@ SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t struct<1:int> -- !query 24 output 1 + + +-- !query 25 +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z +-- !query 25 schema +struct<1:int> +-- !query 25 output + From eb45b52e826ea9cea48629760db35ef87f91fea0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 8 Jan 2018 19:41:41 +0800 Subject: [PATCH 593/712] [SPARK-21865][SQL] simplify the distribution semantic of Spark SQL ## What changes were proposed in this pull request? **The current shuffle planning logic** 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings are compatible with each other, via the `Partitioning.compatibleWith`. 6. If the check in 5 failed, add a shuffle above each child. 7. try to eliminate the shuffles added in 6, via `Partitioning.guarantees`. This design has a major problem with the definition of "compatible". `Partitioning.compatibleWith` is not well defined, ideally a `Partitioning` can't know if it's compatible with other `Partitioning`, without more information from the operator. For example, `t1 join t2 on t1.a = t2.b`, `HashPartitioning(a, 10)` should be compatible with `HashPartitioning(b, 10)` under this case, but the partitioning itself doesn't know it. As a result, currently `Partitioning.compatibleWith` always return false except for literals, which make it almost useless. This also means, if an operator has distribution requirements for multiple children, Spark always add shuffle nodes to all the children(although some of them can be eliminated). 
However, there is no guarantee that the children's output partitionings are compatible with each other after adding these shuffles, we just assume that the operator will only specify `ClusteredDistribution` for multiple children. I think it's very hard to guarantee children co-partition for all kinds of operators, and we can not even give a clear definition about co-partition between distributions like `ClusteredDistribution(a,b)` and `ClusteredDistribution(c)`. I think we should drop the "compatible" concept in the distribution model, and let the operator achieve the co-partition requirement by special distribution requirements. **Proposed shuffle planning logic after this PR** (The first 4 are same as before) 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings have the same number of partitions. 6. If the check in 5 failed, pick the max number of partitions from children's output partitionings, and add shuffle to child whose number of partitions doesn't equal to the max one. The new distribution model is very simple, we only have one kind of relationship, which is `Partitioning.satisfy`. For multiple children, Spark only guarantees they have the same number of partitions, and it's the operator's responsibility to leverage this guarantee to achieve more complicated requirements. For example, non-broadcast joins can use the newly added `HashPartitionedDistribution` to achieve co-partition. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19080 from cloud-fan/exchange. --- .../plans/physical/partitioning.scala | 286 +++++++----------- .../sql/catalyst/PartitioningSuite.scala | 55 ---- .../spark/sql/execution/SparkPlan.scala | 16 +- .../exchange/EnsureRequirements.scala | 120 +++----- .../joins/ShuffledHashJoinExec.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 81 ++--- 8 files changed, 194 insertions(+), 370 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e57c842ce2a36..0189bd73c56bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * - Intra-partition ordering of data: In this case the distribution describes guarantees made * about how tuples are distributed within a single partition. */ -sealed trait Distribution +sealed trait Distribution { + /** + * The required number of partitions for this distribution. If it's None, then any number of + * partitions is allowed for this distribution. 
+ */ + def requiredNumPartitions: Option[Int] + + /** + * Creates a default partitioning for this distribution, which can satisfy this distribution while + * matching the given number of partitions. + */ + def createPartitioning(numPartitions: Int): Partitioning +} /** * Represents a distribution where no promises are made about co-location of data. */ -case object UnspecifiedDistribution extends Distribution +case object UnspecifiedDistribution extends Distribution { + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.") + } +} /** * Represents a distribution that only has a single partition and all tuples of the dataset * are co-located. */ -case object AllTuples extends Distribution +case object AllTuples extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.") + SinglePartition + } +} /** * Represents data where tuples that share the same values for the `clustering` @@ -51,12 +76,41 @@ case object AllTuples extends Distribution */ case class ClusteredDistribution( clustering: Seq[Expression], - numPartitions: Option[Int] = None) extends Distribution { + requiredNumPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") + HashPartitioning(clustering, numPartitions) + } +} + +/** + * Represents data where tuples have been clustered according to the hash of the given + * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only + * [[HashPartitioning]] can satisfy this distribution. + * + * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the + * number of partitions, this distribution strictly requires which partition the tuple should be in. + */ +case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + HashPartitioning(expressions, numPartitions) + } } /** @@ -73,46 +127,31 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - // TODO: This is not really valid... 
- def clustering: Set[Expression] = ordering.map(_.child).toSet + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + RangePartitioning(ordering, numPartitions) + } } /** * Represents data where tuples are broadcasted to every node. It is quite common that the * entire set of tuples is transformed into different data structure. */ -case class BroadcastDistribution(mode: BroadcastMode) extends Distribution +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, + "The default partitioning of BroadcastDistribution can only have 1 partition.") + BroadcastPartitioning(mode) + } +} /** - * Describes how an operator's output is split across partitions. The `compatibleWith`, - * `guarantees`, and `satisfies` methods describe relationships between child partitionings, - * target partitionings, and [[Distribution]]s. These relations are described more precisely in - * their individual method docs, but at a high level: - * - * - `satisfies` is a relationship between partitionings and distributions. - * - `compatibleWith` is relationships between an operator's child output partitionings. - * - `guarantees` is a relationship between a child's existing output partitioning and a target - * output partitioning. - * - * Diagrammatically: - * - * +--------------+ - * | Distribution | - * +--------------+ - * ^ - * | - * satisfies - * | - * +--------------+ +--------------+ - * | Child | | Target | - * +----| Partitioning |----guarantees--->| Partitioning | - * | +--------------+ +--------------+ - * | ^ - * | | - * | compatibleWith - * | | - * +------------+ - * + * Describes how an operator's output is split across partitions. It has 2 major properties: + * 1. number of partitions. + * 2. if it can satisfy a given distribution. */ sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ @@ -123,113 +162,35 @@ sealed trait Partitioning { * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). - */ - def satisfies(required: Distribution): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] - * guarantees the same partitioning scheme described by `other`. - * - * Compatibility of partitionings is only checked for operators that have multiple children - * and that require a specific child output [[Distribution]], such as joins. - * - * Intuitively, partitionings are compatible if they route the same partitioning key to the same - * partition. For instance, two hash partitionings are only compatible if they produce the same - * number of output partitionings and hash records according to the same hash function and - * same partitioning key schema. - * - * Put another way, two partitionings are compatible with each other if they satisfy all of the - * same distribution guarantees. - */ - def compatibleWith(other: Partitioning): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees - * the same partitioning scheme described by `other`. 
If a `A.guarantees(B)`, then repartitioning - * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance - * optimization to allow the exchange planner to avoid redundant repartitionings. By default, - * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number - * of partitions, same strategy (range or hash), etc). - * - * In order to enable more aggressive optimization, this strict equality check can be relaxed. - * For example, say that the planner needs to repartition all of an operator's children so that - * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children - * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens - * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; - * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` - * [[SinglePartition]]. - * - * The SinglePartition example given above is not particularly interesting; guarantees' real - * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion - * of null-safe partitionings, under which partitionings can specify whether rows whose - * partitioning keys contain null values will be grouped into the same partition or whether they - * will have an unknown / random distribution. If a partitioning does not require nulls to be - * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered - * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot - * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a - * symmetric relation. * - * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows - * produced by `A` could have also been produced by `B`. + * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if + * the [[Partitioning]] only have one partition. Implementations can overwrite this method with + * special logic. */ - def guarantees(other: Partitioning): Boolean = this == other -} - -object Partitioning { - def allCompatible(partitionings: Seq[Partitioning]): Boolean = { - // Note: this assumes transitivity - partitionings.sliding(2).map { - case Seq(a) => true - case Seq(a, b) => - if (a.numPartitions != b.numPartitions) { - assert(!a.compatibleWith(b) && !b.compatibleWith(a)) - false - } else { - a.compatibleWith(b) && b.compatibleWith(a) - } - }.forall(_ == true) - } -} - -case class UnknownPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { + def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true + case AllTuples => numPartitions == 1 case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false } +case class UnknownPartitioning(numPartitions: Int) extends Partitioning + /** * Represents a partitioning where rows are distributed evenly across output partitions * by starting from a random target partition number and distributing rows in a round-robin * fashion. This partitioning is used when implementing the DataFrame.repartition() operator. 
*/ -case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false -} +case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) + case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } - - override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 - - override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering, desiredPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case h: HashClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { + case (l, r) => l.semanticEquals(r) + } + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } /** @@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, desiredPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match 
{ + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } } @@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) - /** - * Returns true if any `partitioning` of this collection is compatible with - * the given [[Partitioning]]. - */ - override def compatibleWith(other: Partitioning): Boolean = - partitionings.exists(_.compatibleWith(other)) - - /** - * Returns true if any `partitioning` of this collection guarantees - * the given [[Partitioning]]. - */ - override def guarantees(other: Partitioning): Boolean = - partitionings.exists(_.guarantees(other)) - override def toString: String = { partitionings.map(_.toString).mkString("(", " or ", ")") } @@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { case BroadcastDistribution(m) if m == mode => true case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning(m) if m == mode => true - case _ => false - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala deleted file mode 100644 index 5b802ccc637dd..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} - -class PartitioningSuite extends SparkFunSuite { - test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { - val expressions = Seq(Literal(2), Literal(3)) - // Consider two HashPartitionings that have the same _set_ of hash expressions but which are - // created with different orderings of those expressions: - val partitioningA = HashPartitioning(expressions, 100) - val partitioningB = HashPartitioning(expressions.reverse, 100) - // These partitionings are not considered equal: - assert(partitioningA != partitioningB) - // However, they both satisfy the same clustered distribution: - val distribution = ClusteredDistribution(expressions) - assert(partitioningA.satisfies(distribution)) - assert(partitioningB.satisfies(distribution)) - // These partitionings compute different hashcodes for the same input row: - def computeHashCode(partitioning: HashPartitioning): Int = { - val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) - hashExprProj.apply(InternalRow.empty).hashCode() - } - assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) - // Thus, these partitionings are incompatible: - assert(!partitioningA.compatibleWith(partitioningB)) - assert(!partitioningB.compatibleWith(partitioningA)) - assert(!partitioningA.guarantees(partitioningB)) - assert(!partitioningB.guarantees(partitioningA)) - - // Just to be sure that we haven't cheated by having these methods always return false, - // check that identical partitionings are still compatible with and guarantee each other: - assert(partitioningA === partitioningA) - assert(partitioningA.guarantees(partitioningA)) - assert(partitioningA.compatibleWith(partitioningA)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 787c1cfbfb3d8..82300efc01632 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! - /** Specifies any partition requirements on the input data for this operator. */ + /** + * Specifies the data distribution requirements of all the children for this operator. By default + * it's [[UnspecifiedDistribution]] for each child, which means each child can have any + * distribution. + * + * If an operator overwrites this method, and specifies distribution requirements(excluding + * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark + * guarantees that the outputs of these children will have same number of partitions, so that the + * operator can safely zip partitions of these children's result RDDs. 
+   * Some operators can leverage this guarantee to satisfy interesting requirements, e.g., a
+   * non-broadcast join can specify HashClusteredDistribution(a,b) for its left child and
+   * HashClusteredDistribution(c,d) for its right child; it is then guaranteed that the left and
+   * right children are co-partitioned by a,b/c,d, which means tuples with the same values for
+   * those keys land in partitions with the same index, e.g., (a=1,b=2) and (c=1,d=2) are both in
+   * the second partition of the left and right children.
+   */
  def requiredChildDistribution: Seq[Distribution] =
    Seq.fill(children.size)(UnspecifiedDistribution)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index c8e236be28b42..e3d28388c5470 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -46,23 +46,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
    if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
  }
-  /**
-   * Given a required distribution, returns a partitioning that satisfies that distribution.
-   * @param requiredDistribution The distribution that is required by the operator
-   * @param numPartitions Used when the distribution doesn't require a specific number of partitions
-   */
-  private def createPartitioning(
-      requiredDistribution: Distribution,
-      numPartitions: Int): Partitioning = {
-    requiredDistribution match {
-      case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering, desiredPartitions) =>
-        HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
-      case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
-      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
-    }
-  }
-
  /**
   * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled
   * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]].
@@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
      // shuffle data when we have more than one children because data generated by
      // these children may not be partitioned in the same way.
      // Please see the comment in withCoordinator for more details.
-      val supportsDistribution =
-        requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+      val supportsDistribution = requiredChildDistributions.forall { dist =>
+        dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution]
+      }
      children.length > 1 && supportsDistribution
    }
@@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
          //
          // It will be great to introduce a new Partitioning to represent the post-shuffle
          // partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
- val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } @@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) - // Ensure that the operator's children satisfy their output distribution requirements: + // Ensure that the operator's children satisfy their output distribution requirements. children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + val numPartitions = distribution.requiredNumPartitions + .getOrElse(defaultNumPreShufflePartitions) + ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) } - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { - case UnspecifiedDistribution => false - case BroadcastDistribution(_) => false + // Get the indexes of children which have specified distribution requirements and need to have + // same number of partitions. + val childrenIndexes = requiredChildDistributions.zipWithIndex.filter { + case (UnspecifiedDistribution, _) => false + case (_: BroadcastDistribution, _) => false case _ => true - } - if (children.length > 1 - && requiredChildDistributions.exists(requireCompatiblePartitioning) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + }.map(_._2) - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + val childrenNumPartitions = + childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet + + if (childrenNumPartitions.size > 1) { + // Get the number of partitions which is explicitly required by the distributions. + val requiredNumPartitions = { + val numPartitionsSet = childrenIndexes.flatMap { + index => requiredChildDistributions(index).requiredNumPartitions + }.toSet + assert(numPartitionsSet.size <= 1, + s"$operator have incompatible requirements of the number of partitions for its children") + numPartitionsSet.headOption } - children = if (useExistingPartitioning) { - // We do not need to shuffle any child's output. - children - } else { - // We need to shuffle at least one child's output. - // Now, we will determine the number of partitions that will be used by created - // partitioning schemes. 
- val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. - if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions - } + val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max) - children.zip(requiredChildDistributions).map { - case (child, distribution) => - val targetPartitioning = createPartitioning(distribution, numPartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - child match { - // If child is an exchange, we replace it with - // a new one having targetPartitioning. - case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) - case _ => ShuffleExchangeExec(targetPartitioning, child) - } + children = children.zip(requiredChildDistributions).zipWithIndex.map { + case ((child, distribution), index) if childrenIndexes.contains(index) => + if (child.outputPartitioning.numPartitions == targetNumPartitions) { + child + } else { + val defaultPartitioning = distribution.createPartitioning(targetNumPartitions) + child match { + // If child is an exchange, we replace it with a new one having defaultPartitioning. + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c) + case _ => ShuffleExchangeExec(defaultPartitioning, child) + } } - } + + case ((child, _), _) => child } } @@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchangeExec(partitioning, child, _) => - child.children match { - case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => - if (childPartitioning.guarantees(partitioning)) child else operator + // TODO: remove this after we create a physical operator for `RepartitionByExpression`. 
+ case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => + child.outputPartitioning match { + case lower: HashPartitioning if upper.semanticEquals(lower) => child case _ => operator } case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 66e8031bb5191..897a4dae39f32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -46,7 +46,7 @@ case class ShuffledHashJoinExec( "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 94405410cce90..2de2f30eb05d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -78,7 +78,7 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d1bd8a7076863..03d1bbf2ab882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -456,7 +456,7 @@ case class CoGroupExec( right: SparkPlan) extends BinaryExecNode with ObjectProducerExec { override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil override def requiredChildOrdering: Seq[Seq[SortOrder]] = leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b50642d275ba8..f8b26f5b28cc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -260,11 +260,16 @@ class PlannerSuite extends SharedSQLContext { // do they satisfy the distribution requirements? As a result, we need at least four test cases. 
private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { - if (outputPlan.children.length > 1 - && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { - val childPartitionings = outputPlan.children.map(_.outputPartitioning) - if (!Partitioning.allCompatible(childPartitionings)) { - fail(s"Partitionings are not compatible: $childPartitionings") + if (outputPlan.children.length > 1) { + val childPartitionings = outputPlan.children.zip(outputPlan.requiredChildDistribution) + .filter { + case (_, UnspecifiedDistribution) => false + case (_, _: BroadcastDistribution) => false + case _ => true + }.map(_._1.outputPartitioning) + + if (childPartitionings.map(_.numPartitions).toSet.size > 1) { + fail(s"Partitionings doesn't have same number of partitions: $childPartitionings") } } outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { @@ -274,40 +279,7 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { - // Consider an operator that requires inputs that are clustered by two expressions (e.g. - // sort merge join where there are multiple columns in the equi-join condition) - val clusteringA = Literal(1) :: Nil - val clusteringB = Literal(2) :: Nil - val distribution = ClusteredDistribution(clusteringA ++ clusteringB) - // Say that the left and right inputs are each partitioned by _one_ of the two join columns: - val leftPartitioning = HashPartitioning(clusteringA, 1) - val rightPartitioning = HashPartitioning(clusteringB, 1) - // Individually, each input's partitioning satisfies the clustering distribution: - assert(leftPartitioning.satisfies(distribution)) - assert(rightPartitioning.satisfies(distribution)) - // However, these partitionings are not compatible with each other, so we still need to - // repartition both inputs prior to performing the join: - assert(!leftPartitioning.compatibleWith(rightPartitioning)) - assert(!rightPartitioning.compatibleWith(leftPartitioning)) - val inputPlan = DummySparkPlan( - children = Seq( - DummySparkPlan(outputPartitioning = leftPartitioning), - DummySparkPlan(outputPartitioning = rightPartitioning) - ), - requiredChildDistribution = Seq(distribution, distribution), - requiredChildOrdering = Seq(Seq.empty, Seq.empty) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { - fail(s"Exchange should have been added:\n$outputPlan") - } - } - test("EnsureRequirements with child partitionings with different numbers of output partitions") { - // This is similar to the previous test, except it checks that partitionings are not compatible - // unless they produce the same number of partitions. 
val clustering = Literal(1) :: Nil val distribution = ClusteredDistribution(clustering) val inputPlan = DummySparkPlan( @@ -386,18 +358,15 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + test("EnsureRequirements eliminates Exchange if child has same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { @@ -407,17 +376,13 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements does not eliminate Exchange with different partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - // Number of partitions differ - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { From 40b983c3b44b6771f07302ce87987fa4716b5ebf Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 8 Jan 2018 23:49:07 +0800 Subject: [PATCH 594/712] [SPARK-22952][CORE] Deprecate stageAttemptId in favour of stageAttemptNumber ## What changes were proposed in this pull request? 1. Deprecate attemptId in StageInfo and add `def attemptNumber() = attemptId` 2. Replace usage of stageAttemptId with stageAttemptNumber ## How was this patch tested? I manually checked the compiler warning info Author: Xianjin YE Closes #20178 from advancedxy/SPARK-22952. 
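For illustration, the change follows the usual deprecate-and-alias pattern. A minimal standalone sketch is below; `StageInfoLike` and `DeprecationSketch` are made-up names for the example, not the actual Spark classes:

```scala
// Hypothetical stand-in for StageInfo: the old field stays readable but is marked
// deprecated, and a new accessor with the preferred name simply delegates to it.
class StageInfoLike(
    val stageId: Int,
    @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int) {

  // Preferred accessor going forward; listener-style code should key stages by
  // (stageId, attemptNumber) rather than (stageId, attemptId).
  def attemptNumber(): Int = attemptId
}

object DeprecationSketch {
  def main(args: Array[String]): Unit = {
    val info = new StageInfoLike(stageId = 3, attemptId = 1)
    println((info.stageId, info.attemptNumber())) // prints (3,1)
  }
}
```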
--- .../apache/spark/scheduler/DAGScheduler.scala | 15 +++--- .../apache/spark/scheduler/StageInfo.scala | 4 +- .../spark/scheduler/StatsReportListener.scala | 2 +- .../spark/status/AppStatusListener.scala | 7 +-- .../org/apache/spark/status/LiveEntity.scala | 4 +- .../spark/ui/scope/RDDOperationGraph.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 54 ++++++++++--------- .../execution/ui/SQLAppStatusListener.scala | 2 +- 9 files changed, 51 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2498d4808e91..199937b8c27af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -815,7 +815,8 @@ class DAGScheduler( private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { // Note that there is a chance that this task is launched after the stage is cancelled. // In that case, we wouldn't have the stage anymore in stageIdToStage. - val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val stageAttemptId = + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } @@ -1050,7 +1051,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) stage.pendingPartitions += id - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1060,7 +1061,7 @@ class DAGScheduler( val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, + new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1076,7 +1077,7 @@ class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run @@ -1245,7 +1246,7 @@ class DAGScheduler( val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (stageIdToStage(task.stageId).latestInfo.attemptId == task.stageAttemptId) { + if (stageIdToStage(task.stageId).latestInfo.attemptNumber == task.stageAttemptId) { // This task was for the currently running attempt of the stage. 
Since the task // completed successfully from the perspective of the TaskSetManager, mark it as // no longer pending (the TaskSetManager may consider the task complete even @@ -1324,10 +1325,10 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) - if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + - s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index c513ed36d1680..903e25b7986f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - val attemptId: Int, + @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + def attemptNumber(): Int = attemptId + private[spark] def getStatusString: String = { if (completionTime.isDefined) { if (failureReason.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index 3c8cab7504c17..3c7af4f6146fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -79,7 +79,7 @@ class StatsReportListener extends SparkListener with Logging { x => info.completionTime.getOrElse(System.currentTimeMillis()) - x ).getOrElse("-") - s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Stage(${info.stageId}, ${info.attemptNumber}); Name: '${info.name}'; " + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + s"Took: $timeTaken msec" } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 487a782e865e8..88b75ddd5993a 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -529,7 +529,8 @@ private[spark] class AppStatusListener( } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { - val maybeStage = Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId))) + val maybeStage = + Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))) maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -785,7 +786,7 @@ private[spark] class AppStatusListener( } private def getOrCreateStage(info: StageInfo): LiveStage = { - val stage = liveStages.computeIfAbsent((info.stageId, info.attemptId), + val stage = liveStages.computeIfAbsent((info.stageId, 
info.attemptNumber), new Function[(Int, Int), LiveStage]() { override def apply(key: (Int, Int)): LiveStage = new LiveStage() }) @@ -912,7 +913,7 @@ private[spark] class AppStatusListener( private def cleanupTasks(stage: LiveStage): Unit = { val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { - val stageKey = Array(stage.info.stageId, stage.info.attemptId) + val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) .last(stageKey) diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 52e83f250d34e..305c2fafa6aac 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -412,14 +412,14 @@ private class LiveStage extends LiveEntity { def executorSummary(executorId: String): LiveExecutorStageSummary = { executorSummaries.getOrElseUpdate(executorId, - new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + new LiveExecutorStageSummary(info.stageId, info.attemptNumber, executorId)) } def toApi(): v1.StageData = { new v1.StageData( status, info.stageId, - info.attemptId, + info.attemptNumber, info.numTasks, activeTasks, diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 827a8637b9bd2..948858224d724 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -116,7 +116,7 @@ private[spark] object RDDOperationGraph extends Logging { // Use a special prefix here to differentiate this cluster from other operation clusters val stageClusterId = STAGE_CLUSTER_PREFIX + stage.stageId val stageClusterName = s"Stage ${stage.stageId}" + - { if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" } + { if (stage.attemptNumber == 0) "" else s" (attempt ${stage.attemptNumber})" } val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName) var rootNodeCount = 0 diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 5e60218c5740b..ff83301d631c4 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -263,7 +263,7 @@ private[spark] object JsonProtocol { val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ - ("Stage Attempt ID" -> stageInfo.attemptId) ~ + ("Stage Attempt ID" -> stageInfo.attemptNumber) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 997c7de8dd02b..b8c84e24c2c3f 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -195,7 +195,9 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val s1Tasks = createTasks(4, execIds) s1Tasks.foreach { task => - 
listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, + stages.head.attemptNumber, + task)) } assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) @@ -213,10 +215,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.info.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptId) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) + assert(wrapper.stageAttemptId === stages.head.attemptNumber) + assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, + val runtime = Array[AnyRef](stages.head.stageId: JInteger, + stages.head.attemptNumber: JInteger, -1L: JLong) assert(Arrays.equals(wrapper.runtime, runtime)) @@ -237,7 +240,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { Some(1L), None, true, false, None) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) } check[StageDataWrapper](key(stages.head)) { stage => @@ -254,12 +257,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskResultLost, s1Tasks.head, null)) time += 1 val reattempt = newAttempt(s1Tasks.head, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt)) assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) @@ -289,7 +292,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val killed = s1Tasks.drop(1).head killed.finishTime = time killed.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskKilled("killed"), killed, null)) check[JobDataWrapper](1) { job => @@ -311,13 +314,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val denied = newAttempt(killed, nextTaskId()) val denyReason = TaskCommitDenied(1, 1, 1) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, denied)) time += 1 denied.finishTime = time denied.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", denyReason, denied, null)) check[JobDataWrapper](1) { job => @@ -337,7 +340,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a new attempt. 
val reattempt2 = newAttempt(denied, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt2)) // Succeed all tasks in stage 1. @@ -350,7 +353,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 pending.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", Success, task, s1Metrics)) } @@ -414,13 +417,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val s2Tasks = createTasks(4, execIds) s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, + stages.last.attemptNumber, + task)) } time += 1 s2Tasks.foreach { task => task.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptNumber, "taskType", TaskResultLost, task, null)) } @@ -455,7 +460,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // - Re-submit stage 2, all tasks, and succeed them and the stage. val oldS2 = stages.last - val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptNumber + 1, oldS2.name, oldS2.numTasks, oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) time += 1 @@ -466,14 +471,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val newS2Tasks = createTasks(4, execIds) newS2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptNumber, task)) } time += 1 newS2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, - task, null)) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptNumber, "taskType", + Success, task, null)) } time += 1 @@ -522,14 +527,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val j2s2Tasks = createTasks(4, execIds) j2s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, + j2Stages.last.attemptNumber, task)) } time += 1 j2s2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptNumber, "taskType", Success, task, null)) } @@ -919,13 +925,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val tasks = createTasks(2, Array("1")) tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) // Start a 3rd task. 
The finished tasks should be deleted. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -934,7 +940,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a 4th task. The first task should be deleted, even if it's still running. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -960,7 +966,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } - private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d8adbe7bee13e..73a105266e1c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -99,7 +99,7 @@ class SQLAppStatusListener( // Reset the metrics tracking object for the new attempt. Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => metrics.taskMetrics.clear() - metrics.attemptId = event.stageInfo.attemptId + metrics.attemptId = event.stageInfo.attemptNumber } } From eed82a0b211352215316ec70dc48aefc013ad0b2 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 8 Jan 2018 13:01:45 -0800 Subject: [PATCH 595/712] [SPARK-22992][K8S] Remove assumption of the DNS domain ## What changes were proposed in this pull request? Remove the use of FQDN to access the driver because it assumes that it's set up in a DNS zone - `cluster.local` which is common but not ubiquitous Note that we already access the in-cluster API server through `kubernetes.default.svc`, so, by extension, this should work as well. The alternative is to introduce DNS zones for both of those addresses. ## How was this patch tested? Unit tests cc vanzin liyinan926 mridulm mccheah Author: foxish Closes #20187 from foxish/cluster.local. 
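As a rough sketch of what the hostname change amounts to (illustrative helpers only, not code from this patch), the driver address simply drops the DNS-zone suffix:

```scala
// Illustrative only: contrasts the old FQDN, which hard-codes the "cluster.local" zone,
// with the new namespace-relative form that resolves regardless of the cluster's DNS zone.
object DriverHostnameSketch {
  def fqdn(service: String, namespace: String): String =
    s"$service.$namespace.svc.cluster.local" // old form: assumes the cluster.local zone

  def zoneAgnostic(service: String, namespace: String): String =
    s"$service.$namespace.svc" // new form: zone-agnostic, like kubernetes.default.svc

  def main(args: Array[String]): Unit = {
    println(zoneAgnostic("spark-driver-svc", "my-namespace"))
    // spark-driver-svc.my-namespace.svc
  }
}
```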
--- .../deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala | 2 +- .../k8s/submit/steps/DriverServiceBootstrapStepSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala index eb594e4f16ec0..34af7cde6c1a9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -83,7 +83,7 @@ private[spark] class DriverServiceBootstrapStep( .build() val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" + val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" val resolvedSparkConf = driverSpec.driverSparkConf.clone() .set(DRIVER_HOST_KEY, driverHostname) .set("spark.driver.port", driverPort.toString) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala index 006ce2668f8a0..78c8c3ba1afbd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala @@ -85,7 +85,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } @@ -120,7 +120,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } From 4f7e75883436069c2d9028c4cd5daa78e8d59560 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 8 Jan 2018 13:24:08 -0800 Subject: [PATCH 596/712] [SPARK-22912] v2 data source support in MicroBatchExecution ## What changes were proposed in this pull request? Support for v2 data sources in microbatch streaming. ## How was this patch tested? A very basic new unit test on the toy v2 implementation of rate source. Once we have a v1 source fully migrated to v2, we'll need to do more detailed compatibility testing. Author: Jose Torres Closes #20097 from jose-torres/v2-impl. 
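For context, the user-facing API is unchanged; a microbatch query over the rate source still looks like the sketch below (this assumes a build containing this patch, and whether the v1 or v2 reader is planned depends on which interfaces the registered provider implements):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object RateMicroBatchSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("rate-microbatch-sketch")
      .getOrCreate()

    // Sources that implement MicroBatchReadSupport are now planned through the
    // v2 MicroBatchReader path inside MicroBatchExecution.
    val rates = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .load()

    // Write each micro-batch to the console once per second.
    val query = rates.writeStream
      .format("console")
      .outputMode("append")
      .trigger(Trigger.ProcessingTime("1 second"))
      .start()

    query.awaitTermination()
  }
}
```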
--- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../datasources/v2/DataSourceV2Relation.scala | 10 ++ .../streaming/MicroBatchExecution.scala | 112 ++++++++++++++---- .../streaming/ProgressReporter.scala | 6 +- .../streaming/RateSourceProvider.scala | 10 +- .../execution/streaming/StreamExecution.scala | 4 +- .../streaming/StreamingRelation.scala | 4 +- .../continuous/ContinuousExecution.scala | 4 +- .../ContinuousRateStreamSource.scala | 17 +-- .../sources/RateStreamSourceV2.scala | 31 ++++- .../sql/streaming/DataStreamReader.scala | 25 +++- .../sql/streaming/StreamingQueryManager.scala | 24 ++-- .../streaming/RateSourceV2Suite.scala | 68 +++++++++-- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 2 +- 15 files changed, 241 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 6cdfe2fae5642..0259c774bbf4a 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -7,3 +7,4 @@ org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 7eb99a645001a..cba20dd902007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,6 +35,16 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. 
+ */ +class StreamingDataSourceV2Relation( + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { new DataSourceV2Relation(reader.readSchema().toAttributes, reader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 9a7a13fcc5806..42240eeb58d4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import java.util.Optional + +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} @@ -24,7 +27,10 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -33,10 +39,11 @@ class MicroBatchExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: Sink, + sink: BaseStreamingSink, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, + extraOptions: Map[String, String], deleteCheckpointOnStop: Boolean) extends StreamExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, @@ -57,6 +64,13 @@ class MicroBatchExecution( var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() + // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a + // map as we go to ensure each identical relation gets the same StreamingExecutionRelation + // object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical + // plan for the data within that batch. + // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, + // since the existing logical plan has already used those attributes. The per-microbatch + // transformation is responsible for replacing attributes with their final values. 
val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSource, _, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -64,19 +78,26 @@ class MicroBatchExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" val source = dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource) - if !v2DataSource.isInstanceOf[MicroBatchReadSupport] => + case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val reader = source.createMicroBatchReader( + Optional.empty(), // user specified schema + metadataPath, + new DataSourceV2Options(options.asJava)) + nextSourceId += 1 + StreamingExecutionRelation(reader, output)(sparkSession) + }) + case s @ StreamingRelationV2(_, _, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = v1DataSource.createSource(metadataPath) + assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) } @@ -192,7 +213,8 @@ class MicroBatchExecution( source.getBatch(start, end) } case nonV1Tuple => - throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple") + // The V2 API does not have the same edge case requiring getBatch to be called + // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets @@ -236,14 +258,27 @@ class MicroBatchExecution( val hasNewData = { awaitProgressLock.lock() try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { case s: Source => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("getOffset") { (s, s.getOffset) } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. 
+ s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) + + (s, Some(s.getEndOffset)) + } }.toMap - availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) if (dataAvailable) { true @@ -317,6 +352,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) + case (reader: MicroBatchReader, off) => + reader.commit(reader.deserializeOffset(off.json)) } } else { throw new IllegalStateException(s"batch $currentBatchId doesn't exist") @@ -357,7 +394,16 @@ class MicroBatchExecution( s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) + Some(source -> batch.logicalPlan) + case (reader: MicroBatchReader, available) + if committedOffsets.get(reader).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + reader.setOffsetRange( + toJava(current), + Optional.of(available.asInstanceOf[OffsetV2])) + logDebug(s"Retrieving data from $reader: $current -> $available") + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -365,15 +411,14 @@ class MicroBatchExecution( // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. - val withNewSources = logicalPlan transform { + val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => - newData.get(source).map { data => - val newPlan = data.logicalPlan - assert(output.size == newPlan.output.size, + newData.get(source).map { dataPlan => + assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newPlan.output, ",")}") - replacements ++= output.zip(newPlan.output) - newPlan + s"${Utils.truncatedString(dataPlan.output, ",")}") + replacements ++= output.zip(dataPlan.output) + dataPlan }.getOrElse { LocalRelation(output, isStreaming = true) } @@ -381,7 +426,7 @@ class MicroBatchExecution( // Rewire the plan to use the new attributes that were returned by the source. 
val replacementMap = AttributeMap(replacements) - val triggerLogicalPlan = withNewSources transformAllExpressions { + val newAttributePlan = newBatchesPlan transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => @@ -392,6 +437,20 @@ class MicroBatchExecution( cd.dataType, cd.timeZoneId) } + val triggerLogicalPlan = sink match { + case _: Sink => newAttributePlan + case s: MicroBatchWriteSupport => + val writer = s.createMicroBatchWriter( + s"$runId", + currentBatchId, + newAttributePlan.schema, + outputMode, + new DataSourceV2Options(extraOptions.asJava)) + assert(writer.isPresent, "microbatch writer must always be present") + WriteToDataSourceV2(writer.get, newAttributePlan) + case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") + } + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -409,7 +468,12 @@ class MicroBatchExecution( reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { - sink.addBatch(currentBatchId, nextBatch) + sink match { + case s: Sink => s.addBatch(currentBatchId, nextBatch) + case s: MicroBatchWriteSupport => + // This doesn't accumulate any data - it just forces execution of the microbatch writer. + nextBatch.collect() + } } } @@ -421,4 +485,8 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } + + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { + Optional.ofNullable(scalaOption.orNull) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1c9043613cb69..d1e5be9c12762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -53,7 +53,7 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[BaseStreamingSource, DataFrame] + protected def newData: Map[BaseStreamingSource, LogicalPlan] protected def availableOffsets: StreamProgress protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] @@ -225,8 +225,8 @@ trait ProgressReporter extends Logging { // // 3. For each source, we sum the metrics of the associated execution plan leaves. 
// - val logicalPlanLeafToSource = newData.flatMap { case (source, df) => - df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index d02cf882b61ac..66eb0169ac1ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -29,12 +29,12 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} @@ -112,7 +112,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister schema: Optional[StructType], checkpointLocation: String, options: DataSourceV2Options): ContinuousReader = { - new ContinuousRateStreamReader(options) + new RateStreamContinuousReader(options) } override def shortName(): String = "rate" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3e76bf7b7ca8f..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -163,7 +163,7 @@ abstract class StreamExecution( var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[BaseStreamingSource, DataFrame] = _ + protected var newData: Map[BaseStreamingSource, LogicalPlan] = _ @volatile protected var streamDeathCause: StreamingQueryException = null @@ -418,7 +418,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. 
*/ - private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a9d50e3a112e7..a0ee683a895d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -61,7 +61,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ case class StreamingExecutionRelation( - source: Source, + source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) extends LeafNode { @@ -92,7 +92,7 @@ case class StreamingRelationV2( sourceName: String, extraOptions: Map[String, String], output: Seq[Attribute], - v1DataSource: DataSource)(session: SparkSession) + v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2843ab13bde2b..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} @@ -174,7 +174,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) - DataSourceV2Relation(newOutput, reader) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index c9aa78a5a2e28..b4b21e7d2052f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -case class ContinuousRateStreamPartitionOffset( +case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class ContinuousRateStreamReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceV2Options) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats @@ -48,7 +48,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { assert(offsets.length == numPartitions) val tuples = offsets.map { - case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => + case RateStreamPartitionOffset(i, currVal, nextRead) => (i, ValueRunTimeMsPair(currVal, nextRead)) } RateStreamOffset(Map(tuples: _*)) @@ -86,7 +86,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. 
- RateStreamReadTask( + RateStreamContinuousReadTask( start.value, start.runTimeMs, i, @@ -101,7 +101,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) } -case class RateStreamReadTask( +case class RateStreamContinuousReadTask( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -109,10 +109,11 @@ case class RateStreamReadTask( rowsPerSecond: Double) extends ReadTask[Row] { override def createDataReader(): DataReader[Row] = - new RateStreamDataReader(startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) + new RateStreamContinuousDataReader( + startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamDataReader( +class RateStreamContinuousDataReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -151,5 +152,5 @@ class RateStreamDataReader( override def close(): Unit = {} override def getOffset(): PartitionOffset = - ContinuousRateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) + RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 97bada08bcd2b..c0ed12cec25ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -28,17 +28,38 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamV2Reader(options: DataSourceV2Options) +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceV2Options) extends MicroBatchReader { implicit val defaultFormats: DefaultFormats = DefaultFormats - val clock = new SystemClock + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. 
+ if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } private val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt @@ -111,7 +132,7 @@ class RateStreamV2Reader(options: DataSourceV2Options) val packedRows = mutable.ListBuffer[(Long, Long)]() var outVal = startVal + numPartitions - var outTimeMs = startTimeMs + msPerPartitionBetweenRows + var outTimeMs = startTimeMs while (outVal <= endVal) { packedRows.append((outTimeMs, outVal)) outVal += numPartitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2e92beecf2c17..52f2e2639cd86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.Locale +import java.util.{Locale, Optional} import scala.collection.JavaConverters._ @@ -27,8 +27,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -166,19 +167,31 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) + val v1Relation = ds match { + case _: StreamSourceProvider => Some(StreamingRelation(v1DataSource)) + case _ => None + } ds match { + case s: MicroBatchReadSupport => + val tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + s, source, extraOptions.toMap, + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( - java.util.Optional.ofNullable(userSpecifiedSchema.orNull), + Optional.ofNullable(userSpecifiedSchema.orNull), Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, options) - // Generate the V1 node to catch errors thrown within generation. - StreamingRelation(v1DataSource) Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1DataSource)(sparkSession)) + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. 
Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index b508f4406138f..4b27e0d4ef47b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -29,10 +29,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -240,31 +240,35 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - sink match { - case v1Sink: Sink => - new StreamingQueryWrapper(new MicroBatchExecution( + (sink, trigger) match { + case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) => + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v1Sink, + v2Sink, trigger, triggerClock, outputMode, + extraOptions, deleteCheckpointOnStop)) - case v2Sink: ContinuousWriteSupport => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) - new StreamingQueryWrapper(new ContinuousExecution( + case (_: MicroBatchWriteSupport, _) | (_: Sink, _) => + new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v2Sink, + sink, trigger, triggerClock, outputMode, extraOptions, deleteCheckpointOnStop)) + case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => + throw new AnalysisException( + "Sink only supports continuous writes, but a continuous trigger was not specified.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index e11705a227f48..85085d43061bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -18,20 +18,64 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, 
RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + test("microbatch - numPartitions propagated") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) val tasks = reader.createReadTasks() @@ -39,7 +83,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - set offset") { - val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) + val reader = new RateStreamMicroBatchReader(DataSourceV2Options.empty()) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -48,7 +92,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - infer offsets") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) reader.setOffsetRange(Optional.empty(), Optional.empty()) @@ -69,7 +113,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - predetermined batch size") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, 
ValueRunTimeMsPair(20, 2000)))) @@ -80,7 +124,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - data read") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { @@ -107,14 +151,14 @@ class RateSourceV2Suite extends StreamTest { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { case ds: ContinuousReadSupport => val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) - assert(reader.isInstanceOf[ContinuousRateStreamReader]) + assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") } } test("continuous data") { - val reader = new ContinuousRateStreamReader( + val reader = new RateStreamContinuousReader( new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) val tasks = reader.createReadTasks() @@ -122,17 +166,17 @@ class RateSourceV2Suite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamReadTask => + case t: RateStreamContinuousReadTask => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamDataReader] + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) assert(r.getOffset() == - ContinuousRateStreamPartitionOffset( + RateStreamPartitionOffset( t.partitionIndex, t.partitionIndex + rowIndex * 2, startTimeMs + (rowIndex + 1) * 100)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4b7f0fbe97d4e..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -105,7 +105,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. */ - def addData(query: Option[StreamExecution]): (Source, Offset) + def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) } /** A trait that can be extended when testing a source. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index eda0d8ad48313..9562c10feafe9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -61,7 +61,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 68ce792b5857f0291154f524ac651036db868bb9 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 9 Jan 2018 10:15:01 +0800 Subject: [PATCH 597/712] [SPARK-22972] Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc ## What changes were proposed in this pull request? Fix the warning: Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc. ## How was this patch tested? test("SPARK-22972: hive orc source") assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") .equals(HiveSerDe.sourceToSerDe("orc"))) Author: xubo245 <601450868@qq.com> Closes #20165 from xubo245/HiveSerDe. --- .../apache/spark/sql/internal/HiveSerDe.scala | 1 + .../sql/hive/orc/HiveOrcSourceSuite.scala | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index b9515ec7bca2a..dac463641cfab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -73,6 +73,7 @@ object HiveSerDe { val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s if s.startsWith("org.apache.spark.sql.hive.orc") => "orc" case s if s.equals("orcfile") => "orc" case s if s.equals("parquetfile") => "parquet" case s if s.equals("avrofile") => "avro" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 17b7d8cfe127e..d556a030e2186 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -62,6 +64,33 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { """.stripMargin) } + test("SPARK-22972: hive orc source") { + val tableName = "normal_orc_as_source_hive" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName + |USING 
org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + + val tableMetadata = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(tableName)) + assert(tableMetadata.storage.inputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(tableMetadata.storage.outputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(tableMetadata.storage.serde == + Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + } + } + test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { val location = Utils.createTempDir() val uri = location.toURI From 849043ce1d28a976659278d29368da0799329db8 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 9 Jan 2018 10:44:21 +0800 Subject: [PATCH 598/712] [SPARK-22990][CORE] Fix method isFairScheduler in JobsTab and StagesTab ## What changes were proposed in this pull request? In current implementation, the function `isFairScheduler` is always false, since it is comparing String with `SchedulingMode` Author: Wang Gengliang Closes #20186 from gengliangwang/isFairScheduler. --- .../src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala | 8 ++++---- .../main/scala/org/apache/spark/ui/jobs/StagesTab.scala | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 99eab1b2a27d8..ff1b75e5c5065 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -34,10 +34,10 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) val killEnabled = parent.killEnabled def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def getSparkUser: String = parent.getSparkUser diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index be05a963f0e68..10b032084ce4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -37,10 +37,10 @@ private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) attachPage(new PoolPage(this)) def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def handleKillRequest(request: HttpServletRequest): Unit = { From f20131dd35939734fe16b0005a086aa72400893b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 9 Jan 2018 11:49:10 +0800 Subject: [PATCH 599/712] [SPARK-22984] Fix incorrect bitmap copying and offset adjustment in GenerateUnsafeRowJoiner ## What changes were proposed in this pull request? This PR fixes a longstanding correctness bug in `GenerateUnsafeRowJoiner`. 
This class was introduced in https://github.com/apache/spark/pull/7821 (July 2015 / Spark 1.5.0+) and is used to combine pairs of UnsafeRows in TungstenAggregationIterator, CartesianProductExec, and AppendColumns. ### Bugs fixed by this patch 1. **Incorrect combining of null-tracking bitmaps**: when concatenating two UnsafeRows, the implementation "Concatenate the two bitsets together into a single one, taking padding into account". If one row has no columns then it has a bitset size of 0, but the code was incorrectly assuming that if the left row had a non-zero number of fields then the right row would also have at least one field, so it was copying invalid bytes and and treating them as part of the bitset. I'm not sure whether this bug was also present in the original implementation or whether it was introduced in https://github.com/apache/spark/pull/7892 (which fixed another bug in this code). 2. **Incorrect updating of data offsets for null variable-length fields**: after updating the bitsets and copying fixed-length and variable-length data, we need to perform adjustments to the offsets pointing the start of variable length fields's data. The existing code was _conditionally_ adding a fixed offset to correct for the new length of the combined row, but it is unsafe to do this if the variable-length field has a null value: we always represent nulls by storing `0` in the fixed-length slot, but this code was incorrectly incrementing those values. This bug was present since the original version of `GenerateUnsafeRowJoiner`. ### Why this bug remained latent for so long The PR which introduced `GenerateUnsafeRowJoiner` features several randomized tests, including tests of the cases where one side of the join has no fields and where string-valued fields are null. However, the existing assertions were too weak to uncover this bug: - If a null field has a non-zero value in its fixed-length data slot then this will not cause problems for field accesses because the null-tracking bitmap should still be correct and we will not try to use the incorrect offset for anything. - If the null tracking bitmap is corrupted by joining against a row with no fields then the corruption occurs in field numbers past the actual field numbers contained in the row. Thus valid `isNullAt()` calls will not read the incorrectly-set bits. The existing `GenerateUnsafeRowJoinerSuite` tests only exercised `.get()` and `isNullAt()`, but didn't actually check the UnsafeRows for bit-for-bit equality, preventing these bugs from failing assertions. It turns out that there was even a [GenerateUnsafeRowJoinerBitsetSuite](https://github.com/apache/spark/blob/03377d2522776267a07b7d6ae9bddf79a4e0f516/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala) but it looks like it also didn't catch this problem because it only tested the bitsets in an end-to-end fashion by accessing them through the `UnsafeRow` interface instead of actually comparing the bitsets' bytes. ### Impact of these bugs - This bug will cause `equals()` and `hashCode()` to be incorrect for these rows, which will be problematic in case`GenerateUnsafeRowJoiner`'s results are used as join or grouping keys. - Chained / repeated invocations of `GenerateUnsafeRowJoiner` may result in reads from invalid null bitmap positions causing fields to incorrectly become NULL (see the end-to-end example below). 
- It looks like this generally only happens in `CartesianProductExec`, which our query optimizer often avoids executing (usually we try to plan a `BroadcastNestedLoopJoin` instead). ### End-to-end test case demonstrating the problem The following query demonstrates how this bug may result in incorrect query results: ```sql set spark.sql.autoBroadcastJoinThreshold=-1; -- Needed to trigger CartesianProductExec create table a as select * from values 1; create table b as select * from values 2; SELECT t3.col1, t1.col1 FROM a t1 CROSS JOIN b t2 CROSS JOIN b t3 ``` This should return `(2, 1)` but instead was returning `(null, 1)`. Column pruning ends up trimming off all columns from `t2`, so when `t2` joins with another table this triggers the bitmap-copying bug. This incorrect bitmap is subsequently copied again when performing the final join, causing the final output to have an incorrectly-set null bit for the first field. ## How was this patch tested? Strengthened the assertions in existing tests in GenerateUnsafeRowJoinerSuite. Also verified that the end-to-end test case which uncovered this now passes. Author: Josh Rosen Closes #20181 from JoshRosen/SPARK-22984-fix-generate-unsaferow-joiner-bitmap-bugs. --- .../codegen/GenerateUnsafeRowJoiner.scala | 52 +++++++++- .../GenerateUnsafeRowJoinerSuite.scala | 95 ++++++++++++++++++- 2 files changed, 138 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index be5f5a73b5d47..febf7b0c96c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => - val bits = if (bitset1Remainder > 0) { + val bits = if (bitset1Remainder > 0 && bitset2Words != 0) { if (i < bitset1Words - 1) { s"$getLong(obj1, offset1 + ${i * 8})" } else if (i == bitset1Words - 1) { @@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } else { // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the // offset in the upper 32 bit of the words, we can just shift the offset to the left by - // 32 and increment that amount in place. + // 32 and increment that amount in place. However, we need to handle the important special + // case of a null field, in which case the offset should be zero and should not have a + // shift added to it. val shift = if (i < schema1.size) { s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" @@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 - s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n" + // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's + // output as a de-facto specification for the internal layout of data. 
+ // + // Null-valued fields will always have a data offset of 0 because + // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's + // position in the fixed-length section of the row. As a result, we must NOT add + // `shift` to the offset for null fields. + // + // We could perform a null-check here by inspecting the null-tracking bitmap, but doing + // so could be expensive and will add significant bloat to the generated code. Instead, + // we'll rely on the invariant "stored offset == 0 for variable-length data type implies + // that the field's value is null." + // + // To establish that this invariant holds, we'll prove that a non-null field can never + // have a stored offset of 0. There are two cases to consider: + // + // 1. The non-null field's data is of non-zero length: reading this field's value + // must read data from the variable-length section of the row, so the stored offset + // will actually be used in address calculation and must be correct. The offsets + // count bytes from the start of the UnsafeRow so these offsets will always be + // non-zero because the storage of the offsets themselves takes up space at the + // start of the row. + // 2. The non-null field's data is of zero length (i.e. its data is empty). In this + // case, we have to worry about the possibility that an arbitrary offset value was + // stored because we never actually read any bytes using this offset and therefore + // would not crash if it was incorrect. The variable-sized data writing paths in + // UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with + // no special handling for the case where `numBytes == 0`. Internally, + // setOffsetAndSize computes the offset without taking the size into account. Thus + // the stored offset is the same non-zero offset that would be used if the field's + // dataSize was non-zero (and in (1) above we've shown that case behaves as we + // expect). + // + // Thus it is safe to perform `existingOffset != 0` checks here in the place of + // more expensive null-bit checks. 
+ s""" + |existingOffset = $getLong(buf, $cursor); + |if (existingOffset != 0) { + | $putLong(buf, $cursor, existingOffset + ($shift << 32)); + |} + """.stripMargin } } val updateOffsets = ctx.splitExpressions( expressions = updateOffset, funcName = "copyBitsetFunc", - arguments = ("long", "numBytesVariableRow1") :: Nil) + arguments = ("long", "numBytesVariableRow1") :: Nil, + makeSplitFunction = (s: String) => "long existingOffset;\n" + s) // ------------------------ Finally, put everything together --------------------------- // val codeBody = s""" @@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyFixedLengthRow2 | $copyVariableLengthRow1 | $copyVariableLengthRow2 + | long existingOffset; | $updateOffsets | | out.pointTo(buf, sizeInBytes); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index f203f25ad10d4..75c6beeb32150 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -22,8 +22,10 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for [[GenerateUnsafeRowJoiner]]. 
@@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { testConcat(64, 64, fixed) } + test("rows with all empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8)) + testConcat(schema, row, schema, row) + } + + test("rows with all empty int arrays") { + val schema = StructType(Seq( + StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType)))) + val emptyIntArray = + ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(emptyIntArray, emptyIntArray)) + testConcat(schema, row, schema, row) + } + + test("alternating empty and non-empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo"))) + testConcat(schema, row, schema, row) + } + test("randomized fix width types") { for (i <- 0 until 20) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) @@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) + testConcat(schema1, row1, schema2, row2) + } + + private def testConcat( + schema1: StructType, + row1: UnsafeRow, + schema2: StructType, + row2: UnsafeRow) { // Run the joiner. val mergedSchema = StructType(schema1 ++ schema2) val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) - val output = concater.join(row1, row2) + val output: UnsafeRow = concater.join(row1, row2) + + // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures + // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written + // correctly. + val expectedOutput: UnsafeRow = { + val joinedRowProjection = UnsafeProjection.create(mergedSchema) + val joined = new JoinedRow() + joinedRowProjection.apply(joined.apply(row1, row2)) + } // Test everything equals ... 
for (i <- mergedSchema.indices) { + val dataType = mergedSchema(i).dataType if (i < schema1.size) { assert(output.isNullAt(i) === row1.isNullAt(i)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row1.get(i, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } else { assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === - row2.get(i - schema1.size, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } } + + + assert( + expectedOutput.getSizeInBytes == output.getSizeInBytes, + "output isn't same size in bytes as slow path") + + // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in + // case this assertion fails: + val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]] + .take(output.getSizeInBytes) + val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]] + .take(expectedOutput.getSizeInBytes) + + val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields()) + val actualBitset = actualBytes.take(bitsetWidth) + val expectedBitset = expectedBytes.take(bitsetWidth) + assert(actualBitset === expectedBitset, "bitsets were not equal") + + val fixedLengthSize = expectedOutput.numFields() * 8 + val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + if (actualFixedLength !== expectedFixedLength) { + actualFixedLength.grouped(8) + .zip(expectedFixedLength.grouped(8)) + .zip(mergedSchema.fields.toIterator) + .foreach { + case ((actual, expected), field) => + assert(actual === expected, s"Fixed length sections are not equal for field $field") + } + fail("Fixed length sections were not equal") + } + + val variableLengthStart = bitsetWidth + fixedLengthSize + val actualVariableLength = actualBytes.drop(variableLengthStart) + val expectedVariableLength = expectedBytes.drop(variableLengthStart) + assert(actualVariableLength === expectedVariableLength, "fixed length sections were not equal") + + assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal") } } From 8486ad419d8f1779e277ec71c39e1516673a83ab Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 21:58:26 -0800 Subject: [PATCH 600/712] [SPARK-21292][DOCS] refreshtable example ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20198 from felixcheung/rrefreshdoc. --- docs/sql-programming-guide.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ccaaf4d5b1fa..72f79d6909ecc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -915,6 +915,14 @@ spark.catalog.refreshTable("my_table") +
    + +{% highlight r %} +refreshTable("my_table") +{% endhighlight %} + +
    +
    {% highlight sql %} @@ -1498,10 +1506,10 @@ that these options will be deprecated in future release as more optimizations ar ## Broadcast Hint for SQL Queries The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. -When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, +When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. When both sides of a join are specified, Spark broadcasts the one having the lower statistics. -Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) +Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
    @@ -1780,7 +1788,7 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. @@ -2167,7 +2175,7 @@ Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are Spark SQL currently does not support the reuse of aggregation. * `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating an aggregate over a fixed window. - + ### Incompatible Hive UDF Below are the scenarios in which Hive and Spark generate different results: From 02214b094390e913f52e71d55c9bb8a81c9e7ef9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 22:08:19 -0800 Subject: [PATCH 601/712] [SPARK-21293][SPARKR][DOCS] structured streaming doc update ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20197 from felixcheung/rwadoc. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 2 +- docs/sparkr.md | 2 +- .../structured-streaming-programming-guide.md | 32 +++++++++++++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 2e662424b25f2..feca617c2554c 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1042,7 +1042,7 @@ unlink(modelPath) ## Structured Streaming -SparkR supports the Structured Streaming API (experimental). +SparkR supports the Structured Streaming API. You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. diff --git a/docs/sparkr.md b/docs/sparkr.md index 997ea60fb6cf0..6685b585a393a 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -596,7 +596,7 @@ The following example shows how to save/load a MLlib model by SparkR. # Structured Streaming -SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) +SparkR supports the Structured Streaming API. 
Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) # R Function Name Conflicts diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 31fcfabb9cacc..de13e281916db 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -827,8 +827,8 @@ df.isStreaming() {% endhighlight %}
    -{% highlight bash %} -Not available. +{% highlight r %} +isStreaming(df) {% endhighlight %}
    @@ -885,6 +885,19 @@ windowedCounts = words.groupBy( ).count() {% endhighlight %} + +
    +{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
    @@ -959,6 +972,21 @@ windowedCounts = words \ .count() {% endhighlight %} + +
    +{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group + +words <- withWatermark(words, "timestamp", "10 minutes") +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
    From 0959aa581a399279be3f94214bcdffc6a1b6d60a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 9 Jan 2018 16:31:20 +0800 Subject: [PATCH 602/712] [SPARK-23000] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite in Spark 2.3 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ The test suite DataSourceWithHiveMetastoreCatalogSuite of Branch 2.3 always failed in hadoop 2.6 The table `t` exists in `default`, but `runSQLHive` reported the table does not exist. Obviously, Hive client's default database is different. The fix is to clean the environment and use `DEFAULT` as the database. ``` org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' Stacktrace sbt.ForkMain$ForkError: org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:699) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20196 from gatorsmile/testFix. 
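The shape of the fix, sketched below for illustration (hypothetical helper names; the real change lives in `HiveClientImpl.reset()` in the diff that follows): cleanup runs inside `try`/`finally` so the shared Hive client is always switched back to the `default` database, even when dropping tables or databases throws.

```scala
// Illustrative sketch only -- `runSqlHive` and `dropEverything` stand in for the real members.
def reset(runSqlHive: String => Unit, dropEverything: () => Unit): Unit = {
  try {
    dropEverything()           // drop all tables, indexes and non-default databases
  } finally {
    runSqlHive("USE default")  // later unqualified table lookups resolve against `default`
  }
}
```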
--- .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 6 +++++- .../apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7b7f4e0f10210..102f40bacc985 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -823,7 +823,8 @@ private[hive] class HiveClientImpl( } def reset(): Unit = withHiveState { - client.getAllTables("default").asScala.foreach { t => + try { + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).asScala.foreach { index => @@ -837,6 +838,9 @@ private[hive] class HiveClientImpl( logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } + } finally { + runSqlHive("USE default") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 18137e7ea1d63..cf4ce83124d88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -146,6 +146,11 @@ class DataSourceWithHiveMetastoreCatalogSuite 'id cast StringType as 'd2 ).coalesce(1) + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.metadataHive.reset() + } + Seq( "parquet" -> (( "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", From 6a4206ff04746481d7c8e307dfd0d31ff1402555 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Tue, 9 Jan 2018 01:32:48 -0800 Subject: [PATCH 603/712] [SPARK-22998][K8S] Set missing value for SPARK_MOUNTED_CLASSPATH in the executors ## What changes were proposed in this pull request? The environment variable `SPARK_MOUNTED_CLASSPATH` is referenced in the executor's Dockerfile, where its value is added to the classpath of the executor. However, the scheduler backend code missed setting it when creating the executor pods. This PR fixes it. ## How was this patch tested? Unit tested. vanzin Can you help take a look? Thanks! foxish Author: Yinan Li Closes #20193 from liyinan926/master. 
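A minimal sketch of the missing piece (the variable name comes from the description above and the default jars directory from the tests below; the surrounding pod-builder code is elided): one more entry in the environment handed to the executor container.

```scala
// Sketch only: the executor Dockerfile reads SPARK_MOUNTED_CLASSPATH, but the pod factory
// never set it. The value points at the directory where the executor's jars are staged.
val executorJarsDownloadDir = "/var/spark-data/spark-jars"
val executorEnv: Map[String, String] = Map(
  "SPARK_EXECUTOR_ID"       -> "1",                          // example value
  "SPARK_MOUNTED_CLASSPATH" -> s"$executorJarsDownloadDir/*" // previously missing
)
```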
--- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 ++++- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 066d7e9f70ca5..bcacb3934d36a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -94,6 +94,8 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) + /** * Configure and construct an executor pod with the given parameters. */ @@ -145,7 +147,8 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) + (ENV_EXECUTOR_ID, executorId), + (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 884da8aabd880..7cfbe54c95390 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -197,7 +197,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + ENV_EXECUTOR_POD_IP -> null, + ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) From f44ba910f58083458e1133502e193a9d6f2bf766 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 9 Jan 2018 21:48:14 +0800 Subject: [PATCH 604/712] [SPARK-16060][SQL] Support Vectorized ORC Reader ## What changes were proposed in this pull request? This PR adds an ORC columnar-batch reader to native `OrcFileFormat`. Since both Spark `ColumnarBatch` and ORC `RowBatch` are used together, it is faster than the current Spark implementation. This replaces the prior PR, #17924. Also, this PR adds `OrcReadBenchmark` to show the performance improvement. ## How was this patch tested? Pass the existing test cases. Author: Dongjoon Hyun Closes #19943 from dongjoon-hyun/SPARK-16060. 
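A quick way to exercise the new code path from user code (hypothetical session; the output path is a placeholder): the columnar-batch reader applies to the native ORC implementation and is gated by the new `spark.sql.orc.enableVectorizedReader` flag added below (default `true`).

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("orc-vectorized-demo").getOrCreate()

// The columnar-batch reader belongs to the native OrcFileFormat.
spark.conf.set("spark.sql.orc.impl", "native")
// New in this patch; set to false to fall back to row-at-a-time reads.
spark.conf.set("spark.sql.orc.enableVectorizedReader", "true")

spark.range(0, 1000).write.mode("overwrite").orc("/tmp/orc-vectorized-demo")
spark.read.orc("/tmp/orc-vectorized-demo").selectExpr("sum(id)").show()
```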
--- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../orc/OrcColumnarBatchReader.java | 523 ++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 75 ++- .../execution/datasources/orc/OrcUtils.scala | 7 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 435 +++++++++++++++ 5 files changed, 1022 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5c61f10bb71ad..74949db883f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -386,6 +386,11 @@ object SQLConf { .checkValues(Set("hive", "native")) .createWithDefault("native") + val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") + .doc("Enables vectorized orc decoding.") + .booleanConf + .createWithDefault(true) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf @@ -1183,6 +1188,8 @@ class SQLConf extends Serializable with Logging { def orcCompressionCodec: String = getConf(ORC_COMPRESSION) + def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java new file mode 100644 index 0000000000000..5c28d0e6e507a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.io.IOException; +import java.util.stream.IntStream; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.*; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +/** + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. + * After creating, `initialize` and `initBatch` should be called sequentially. + */ +public class OrcColumnarBatchReader extends RecordReader { + + /** + * The default size of batch. We use this value for both ORC and Spark consistently + * because they have different default values like the following. + * + * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 + * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 + */ + public static final int DEFAULT_SIZE = 4 * 1024; + + // ORC File Reader + private Reader reader; + + // Vectorized ORC Row Batch + private VectorizedRowBatch batch; + + /** + * The column IDs of the physical ORC file schema which are required by this reader. + * -1 means this required column doesn't exist in the ORC file. + */ + private int[] requestedColIds; + + // Record reader from ORC row batch. + private org.apache.orc.RecordReader recordReader; + + private StructField[] requiredFields; + + // The result columnar batch for vectorized execution by whole-stage codegen. + private ColumnarBatch columnarBatch; + + // Writable column vectors of the result columnar batch. + private WritableColumnVector[] columnVectors; + + /** + * The memory mode of the columnarBatch + */ + private final MemoryMode MEMORY_MODE; + + public OrcColumnarBatchReader(boolean useOffHeap) { + MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + } + + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + return columnarBatch; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return recordReader.getProgress(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + return nextBatch(); + } + + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + } + + /** + * Initialize ORC file reader and batch record reader. + * Please note that `initBatch` is needed to be called after this. 
+ */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + FileSplit fileSplit = (FileSplit)inputSplit; + Configuration conf = taskAttemptContext.getConfiguration(); + reader = OrcFile.createReader( + fileSplit.getPath(), + OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) + .filesystem(fileSplit.getPath().getFileSystem(conf))); + + Reader.Options options = + OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); + recordReader = reader.rows(options); + } + + /** + * Initialize columnar batch by setting required schema and partition information. + * With this information, this creates ColumnarBatch with the full schema. + */ + public void initBatch( + TypeDescription orcSchema, + int[] requestedColIds, + StructField[] requiredFields, + StructType partitionSchema, + InternalRow partitionValues) { + batch = orcSchema.createRowBatch(DEFAULT_SIZE); + assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. + + this.requiredFields = requiredFields; + this.requestedColIds = requestedColIds; + assert(requiredFields.length == requestedColIds.length); + + StructType resultSchema = new StructType(requiredFields); + for (StructField f : partitionSchema.fields()) { + resultSchema = resultSchema.add(f); + } + + int capacity = DEFAULT_SIZE; + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + } + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].setIsConstant(); + } + } + } + + /** + * Return true if there exists more data in the next batch. If exists, prepare the next batch + * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. 
+ */ + private boolean nextBatch() throws IOException { + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); + + recordReader.nextBatch(batch); + int batchSize = batch.size; + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + for (int i = 0; i < requiredFields.length; i++) { + StructField field = requiredFields[i]; + WritableColumnVector toColumn = columnVectors[i]; + + if (requestedColIds[i] >= 0) { + ColumnVector fromColumn = batch.cols[requestedColIds[i]]; + + if (fromColumn.isRepeating) { + putRepeatingValues(batchSize, field, fromColumn, toColumn); + } else if (fromColumn.noNulls) { + putNonNullValues(batchSize, field, fromColumn, toColumn); + } else { + putValues(batchSize, field, fromColumn, toColumn); + } + } + } + return true; + } + + private void putRepeatingValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + if (fromColumn.isNull[0]) { + toColumn.putNulls(0, batchSize); + } else { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1); + } else if (type instanceof ByteType) { + toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof ShortType) { + toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof IntegerType || type instanceof DateType) { + toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof TimestampType) { + toColumn.putLongs(0, batchSize, + fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0)); + } else if (type instanceof FloatType) { + toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int size = data.vector[0].length; + arrayData.reserve(size); + arrayData.putBytes(0, size, data.vector[0], 0); + for (int index = 0; index < batchSize; index++) { + toColumn.putArray(index, 0, size); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + putDecimalWritables( + toColumn, + batchSize, + decimalType.precision(), + decimalType.scale(), + ((DecimalColumnVector)fromColumn).vector[0]); + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + } + + private void putNonNullValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putBoolean(index, data[index] == 1); + } + } else if (type instanceof ByteType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putByte(index, (byte)data[index]); + } + } else if (type instanceof ShortType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; 
index < batchSize; index++) { + toColumn.putShort(index, (short)data[index]); + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putInt(index, (int)data[index]); + } + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0); + } else if (type instanceof TimestampType) { + TimestampColumnVector data = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + toColumn.putLong(index, fromTimestampColumnVector(data, index)); + } + } else if (type instanceof FloatType) { + double[] data = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putFloat(index, (float)data[index]); + } + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = ((BytesColumnVector)fromColumn); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(data.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { + arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]); + toColumn.putArray(index, pos, data.length[index]); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + DecimalColumnVector data = ((DecimalColumnVector)fromColumn); + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + data.vector[index]); + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + private void putValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putBoolean(index, vector[index] == 1); + } + } + } else if (type instanceof ByteType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putByte(index, (byte)vector[index]); + } + } + } else if (type instanceof ShortType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putShort(index, (short)vector[index]); + } + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putInt(index, (int)vector[index]); + } + } + } else if (type instanceof LongType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { 
+ if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, vector[index]); + } + } + } else if (type instanceof TimestampType) { + TimestampColumnVector vector = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, fromTimestampColumnVector(vector, index)); + } + } + } else if (type instanceof FloatType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putFloat(index, (float)vector[index]); + } + } + } else if (type instanceof DoubleType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putDouble(index, vector[index]); + } + } + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector vector = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(vector.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + arrayData.putBytes(pos, vector.length[index], vector.vector[index], vector.start[index]); + toColumn.putArray(index, pos, vector.length[index]); + } + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + vector[index]); + } + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + /** + * Returns the number of micros since epoch from an element of TimestampColumnVector. + */ + private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { + return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + } + + /** + * Put a `HiveDecimalWritable` to a `WritableColumnVector`. + */ + private static void putDecimalWritable( + WritableColumnVector toColumn, + int index, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInt(index, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLong(index, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.putArray(index, index * 16, bytes.length); + } + } + + /** + * Put `HiveDecimalWritable`s to a `WritableColumnVector`. 
+ */ + private static void putDecimalWritables( + WritableColumnVector toColumn, + int size, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInts(0, size, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLongs(0, size, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(bytes.length); + arrayData.putBytes(0, bytes.length, bytes, 0); + for (int index = 0; index < size; index++) { + toColumn.putArray(index, 0, bytes.length); + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index f7471cd7debce..b8bacfa1838ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -118,6 +118,13 @@ class OrcFileFormat } } + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + override def isSplitable( sparkSession: SparkSession, options: Map[String, String], @@ -139,6 +146,11 @@ class OrcFileFormat } } + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -146,8 +158,14 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(filePath, readerOptions) + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + isCaseSensitive, dataSchema, requiredSchema, reader, conf) if (requestedColIdsOrEmptyFile.isEmpty) { Iterator.empty @@ -155,29 +173,46 @@ class OrcFileFormat val requestedColIds = requestedColIdsOrEmptyFile.get assert(requestedColIds.length == requiredSchema.length, "[BUG] requested column IDs do not match required schema") - conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + val taskConf = new Configuration(conf) + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, requestedColIds.filter(_ != -1).sorted.mkString(",")) - val fileSplit = - new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) - val 
taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - - val orcRecordReader = new OrcInputFormat[OrcStruct] - .createRecordReader(fileSplit, taskAttemptContext) - val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes - val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) - - if (partitionSchema.length == 0) { - iter.map(value => unsafeProjection(deserializer.deserialize(value))) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + val batchReader = + new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + batchReader.initialize(fileSplit, taskAttemptContext) + batchReader.initBatch( + reader.getSchema, + requestedColIds, + requiredSchema.fields, + partitionSchema, + file.partitionValues) + + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + iter.asInstanceOf[Iterator[InternalRow]] } else { - val joinedRow = new JoinedRow() - iter.map(value => - unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + + if (partitionSchema.length == 0) { + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } else { + val joinedRow = new JoinedRow() + iter.map(value => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index b03ee06d04a16..13a23996f4ade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -80,11 +80,8 @@ object OrcUtils extends Logging { isCaseSensitive: Boolean, dataSchema: StructType, requiredSchema: StructType, - file: Path, + reader: Reader, conf: Configuration): Option[Array[Int]] = { - val fs = file.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) val orcFieldNames = reader.getSchema.getFieldNames.asScala if (orcFieldNames.isEmpty) { // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala new file mode 100644 index 0000000000000..37ed846acd1eb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure ORC read performance. + * + * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. + */ +// scalastyle:off line.size.limit +object OrcReadBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("OrcReadBenchmark") + .config(conf) + .getOrCreate() + + // Set default configs. Individual cases will change them if necessary. 
+ spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName + + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val dirORC = dir.getCanonicalPath + + if (partition.isDefined) { + df.write.partitionBy(partition.get).orc(dirORC) + } else { + df.write.orc(dirORC) + } + + spark.read.format(NATIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("nativeOrcTable") + spark.read.format(HIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("hiveOrcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1192 / 1221 13.2 75.8 1.0X + Native ORC Vectorized 161 / 170 97.5 10.3 7.4X + Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1287 / 1333 12.2 81.8 1.0X + Native ORC Vectorized 164 / 172 95.6 10.5 7.8X + Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1304 / 1388 12.1 82.9 1.0X + Native ORC Vectorized 227 / 240 69.3 14.4 5.7X + Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 
1331 / 1357 11.8 84.6 1.0X + Native ORC Vectorized 289 / 297 54.4 18.4 4.6X + Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1410 / 1428 11.2 89.7 1.0X + Native ORC Vectorized 328 / 335 48.0 20.8 4.3X + Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1467 / 1485 10.7 93.3 1.0X + Native ORC Vectorized 402 / 411 39.1 25.6 3.6X + Hive built-in ORC 2023 / 2042 7.8 128.6 0.7X + */ + sqlBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2729 / 2744 3.8 260.2 1.0X + Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X + Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Read data column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read data column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read partition column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read both columns - 
Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X + Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X + Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X + Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X + Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X + Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X + Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X + Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X + Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1325 / 1328 7.9 126.4 1.0X + Native ORC Vectorized 320 / 330 32.8 30.5 4.1X + Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + val benchmark = new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + 
benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2553 / 2554 4.1 243.4 1.0X + Native ORC Vectorized 953 / 954 11.0 90.9 2.7X + Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + + String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2389 / 2408 4.4 227.8 1.0X + Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X + Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + + String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1295 / 1311 8.1 123.5 1.0X + Native ORC Vectorized 449 / 457 23.4 42.8 2.9X + Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1103 / 1124 1.0 1052.0 1.0X + Native ORC Vectorized 92 / 100 11.4 87.9 12.0X + Hive built-in ORC 383 / 390 2.7 365.4 2.9X + + SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2245 / 2250 0.5 2141.0 1.0X + Native ORC Vectorized 157 / 165 6.7 150.2 14.3X + Hive built-in ORC 587 / 593 1.8 559.4 3.8X + + SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 3343 / 3350 0.3 3188.3 1.0X + Native ORC Vectorized 265 / 280 3.9 253.2 12.6X + Hive built-in ORC 828 / 842 1.3 789.8 4.0X + */ + sqlBenchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + 
numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + repeatedStringScanBenchmark(1024 * 1024 * 10) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + columnsBenchmark(1024 * 1024 * 1, 100) + columnsBenchmark(1024 * 1024 * 1, 200) + columnsBenchmark(1024 * 1024 * 1, 300) + } +} +// scalastyle:on line.size.limit From 2250cb75b99d257e698fe5418a51d8cddb4d5104 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 Jan 2018 21:58:55 +0800 Subject: [PATCH 605/712] [SPARK-22981][SQL] Fix incorrect results of Casting Struct to String ## What changes were proposed in this pull request? This pr fixed the issue when casting structs into strings; ``` scala> val df = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") scala> df.write.saveAsTable("t") scala> sql("SELECT CAST(a AS STRING) FROM t").show +-------------------+ | a| +-------------------+ |[0,1,1800000001,61]| |[0,2,1800000001,62]| +-------------------+ ``` This pr modified the result into; ``` +------+ | a| +------+ |[1, a]| |[2, b]| +------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20176 from maropu/SPARK-22981. --- .../spark/sql/catalyst/expressions/Cast.scala | 71 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 16 +++++ 2 files changed, 87 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f2de4c8e30bec..f21aa1e9e3135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -259,6 +259,29 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case StructType(fields) => + buildCast[InternalRow](_, row => { + val builder = new UTF8StringBuilder + builder.append("[") + if (row.numFields > 0) { + val st = fields.map(_.dataType) + val toUTF8StringFuncs = st.map(castToString) + if (!row.isNullAt(0)) { + builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < row.numFields) { + builder.append(",") + if (!row.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -732,6 +755,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeStructToStringBuilder( + st: Seq[DataType], + row: String, + buffer: String, + ctx: CodegenContext): String = { + val structToStringCode = st.zipWithIndex.map { case (ft, i) => + val fieldToStringCode = castToStringCode(ft, ctx) + val field = ctx.freshName("field") + val fieldStr = ctx.freshName("fieldStr") + s""" + |${if (i != 0) s"""$buffer.append(",");""" else ""} + |if (!$row.isNullAt($i)) { + | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | + | // Append $i field into the string buffer + | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | UTF8String $fieldStr = null; + | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} + | $buffer.append($fieldStr); + |} + """.stripMargin + 
} + + val writeStructCode = ctx.splitExpressions( + expressions = structToStringCode, + funcName = "fieldToString", + arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + + s""" + |$buffer.append("["); + |$writeStructCode + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -765,6 +823,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case StructType(fields) => + (c, evPrim, evNull) => { + val row = ctx.freshName("row") + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + s""" + |InternalRow $row = $c; + |$bufferClass $buffer = new $bufferClass(); + |$writeStructCode + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1445bb8a97d40..5b25bdf907c3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -906,4 +906,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") } + + test("SPARK-22981 Cast struct to string") { + val ret1 = cast(Literal.create((1, "a", 0.1)), StringType) + checkEvaluation(ret1, "[1, a, 0.1]") + val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) + checkEvaluation(ret2, "[1,, a]") + val ret3 = cast(Literal.create( + (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) + checkEvaluation(ret3, "[2014-12-03, 2014-12-03 15:05:00]") + val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType) + checkEvaluation(ret4, "[[1, a], 5, 0.1]") + val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType) + checkEvaluation(ret5, "[[1, 2, 3], a, 0.1]") + val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) + checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]") + } } From 96ba217a06fbe1dad703447d7058cb7841653861 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 10:15:27 +0800 Subject: [PATCH 606/712] [SPARK-23005][CORE] Improve RDD.take on small number of partitions ## What changes were proposed in this pull request? In current implementation of RDD.take, we overestimate the number of partitions we need to try by 50%: `(1.5 * num * partsScanned / buf.size).toInt` However, when the number is small, the result of `.toInt` is not what we want. E.g, 2.9 will become 2, which should be 3. Use Math.ceil to fix the problem. Also clean up the code in RDD.scala. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20200 from gengliangwang/Take. 
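To make the rounding point concrete, a small sketch with made-up numbers (not part of this patch), showing only the `toInt` truncation versus `Math.ceil` behaviour described above:

```scala
// Hypothetical state: 2 partitions scanned so far, 30 of the requested rows buffered, 29 still needed.
val partsScanned = 2
val bufSize = 30
val left = 29

val interpolated = 1.5 * left * partsScanned / bufSize  // 2.9 partitions estimated for the next pass
val truncated = interpolated.toInt                       // 2: undershoots and can cost an extra job
val roundedUp = math.ceil(interpolated).toInt            // 3: what the patched code computes
```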
--- .../main/scala/org/apache/spark/rdd/RDD.scala | 27 +++++++++---------- .../spark/sql/execution/SparkPlan.scala | 5 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8798dfc925362..7859781e98223 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - @transient var name: String = null + @transient var name: String = _ /** Assign a name to this RDD */ def setName(_name: String): this.type = { @@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag]( // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null + private var dependencies_ : Seq[Dependency[_]] = _ + @transient private var partitions_ : Array[Partition] = _ /** An Option holding our checkpoint RDD, if we are checkpointed */ private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) @@ -297,7 +297,7 @@ abstract class RDD[T: ClassTag]( private[spark] def getNarrowAncestors: Seq[RDD[_]] = { val ancestors = new mutable.HashSet[RDD[_]] - def visit(rdd: RDD[_]) { + def visit(rdd: RDD[_]): Unit = { val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]]) val narrowParents = narrowDependencies.map(_.rdd) val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains) @@ -449,7 +449,7 @@ abstract class RDD[T: ClassTag]( if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { - var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions) + var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. @@ -951,7 +951,7 @@ abstract class RDD[T: ClassTag]( def collectPartition(p: Int): Array[T] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } - (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) + partitions.indices.iterator.flatMap(i => collectPartition(i)) } /** @@ -1338,6 +1338,7 @@ abstract class RDD[T: ClassTag]( // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1L + val left = num - buf.size if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. 
// Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1345,13 +1346,12 @@ abstract class RDD[T: ClassTag]( if (buf.isEmpty) { numPartsToTry = partsScanned * scaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor) } } - val left = num - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) @@ -1677,8 +1677,7 @@ abstract class RDD[T: ClassTag]( // an RDD and its parent in every batch, in which case the parent may never be checkpointed // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). private val checkpointAllMarkedAncestors = - Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) - .map(_.toBoolean).getOrElse(false) + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean) /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { @@ -1686,7 +1685,7 @@ abstract class RDD[T: ClassTag]( } /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */ - protected[spark] def parent[U: ClassTag](j: Int) = { + protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = { dependencies(j).rdd.asInstanceOf[RDD[U]] } @@ -1754,7 +1753,7 @@ abstract class RDD[T: ClassTag]( * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ - protected def clearDependencies() { + protected def clearDependencies(): Unit = { dependencies_ = null } @@ -1790,7 +1789,7 @@ abstract class RDD[T: ClassTag]( val lastDepStrings = debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) - (frontDepStrings ++ lastDepStrings) + frontDepStrings ++ lastDepStrings } } // The first RDD in the dependency stack has no parents, so no need for a +- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 82300efc01632..398758a3331b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -351,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (buf.isEmpty) { numPartsToTry = partsScanned * limitScaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1) + val left = n - buf.size + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) } } From 6f169ca9e1444fe8fd1ab6f3fbf0a8be1670f1b5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 10:20:34 +0800 Subject: [PATCH 607/712] [MINOR] fix a typo in BroadcastJoinSuite ## What changes were proposed in this pull request? `BroadcastNestedLoopJoinExec` should be `BroadcastHashJoinExec` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20202 from cloud-fan/typo. 
--- .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 6da46ea3480b3..0bcd54e1fceab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -318,7 +318,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { case b: BroadcastNestedLoopJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) - case b: BroadcastNestedLoopJoinExec => + case b: BroadcastHashJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) case w: WholeStageCodegenExec => From 7bcc2666810cefc85dfa0d6679ac7a0de9e23154 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:00:07 +0900 Subject: [PATCH 608/712] [SPARK-23018][PYTHON] Fix createDataFrame from Pandas timestamp series assignment ## What changes were proposed in this pull request? This fixes createDataFrame from Pandas to only assign modified timestamp series back to a copied version of the Pandas DataFrame. Previously, if the Pandas DataFrame was only a reference (e.g. a slice of another) each series will still get assigned back to the reference even if it is not a modified timestamp column. This caused the following warning "SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame." ## How was this patch tested? existing tests Author: Bryan Cutler Closes #20213 from BryanCutler/pyspark-createDataFrame-copy-slice-warn-SPARK-23018. --- python/pyspark/sql/session.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6052fa9e84096..3e4574729a631 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -459,21 +459,23 @@ def _convert_from_pandas(self, pdf, schema, timezone): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? 
if isinstance(field.dataType, TimestampType): s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) - if not copied and s is not pdf[field.name]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[field.name] = s + if s is not pdf[field.name]: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s else: for column, series in pdf.iteritems(): - s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) - if not copied and s is not pdf[column]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[column] = s + s = _check_series_convert_timestamps_tz_local(series, timezone) + if s is not series: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) From e5998372487af20114e160264a594957344ff433 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:55:24 +0900 Subject: [PATCH 609/712] [SPARK-23009][PYTHON] Fix for non-str col names to createDataFrame from Pandas ## What changes were proposed in this pull request? This fixes the case when calling `SparkSession.createDataFrame` using a Pandas DataFrame that has non-str column labels. The column name conversion logic to handle non-string or unicode column labels in Python 2 is: ``` if column is not any type of string: name = str(column) else if column is unicode in Python 2: name = column.encode('utf-8') ``` ## How was this patch tested? Added a new test with a Pandas DataFrame that has int column labels Author: Bryan Cutler Closes #20210 from BryanCutler/python-createDataFrame-int-col-error-SPARK-23009.
--- python/pyspark/sql/session.py | 4 +++- python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 3e4574729a631..604021c1f45cc 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -648,7 +648,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns] + schema = [str(x) if not isinstance(x, basestring) else + (x.encode('utf-8') if not isinstance(x, str) else x) + for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13576ff57001b..80a94a91a87b3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3532,6 +3532,15 @@ def test_toPandas_with_array_type(self): self.assertTrue(expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]) + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): From edf0a48c2ec696b92ed6a96dcee6eeb1a046b20b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 15:01:11 +0800 Subject: [PATCH 610/712] [SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel ## What changes were proposed in this pull request? This patch fixes a severe asynchronous IO bug in Spark's Netty-based file transfer code. At a high-level, the problem is that an unsafe asynchronous `close()` of a pipe's source channel creates a race condition where file transfer code closes a file descriptor then attempts to read from it. If the closed file descriptor's number has been reused by an `open()` call then this invalid read may cause unrelated file operations to return incorrect results. **One manifestation of this problem is incorrect query results.** For a high-level overview of how file download works, take a look at the control flow in `NettyRpcEnv.openChannel()`: this code creates a pipe to buffer results, then submits an asynchronous stream request to a lower-level TransportClient. The callback passes received data to the sink end of the pipe. The source end of the pipe is passed back to the caller of `openChannel()`. Thus `openChannel()` returns immediately and callers interact with the returned pipe source channel. Because the underlying stream request is asynchronous, errors may occur after `openChannel()` has returned and after that method's caller has started to `read()` from the returned channel. For example, if a client requests an invalid stream from a remote server then the "stream does not exist" error may not be received from the remote server until after `openChannel()` has returned. 
In order to be able to propagate the "stream does not exist" error to the file-fetching application thread, this code wraps the pipe's source channel in a special `FileDownloadChannel` which adds a `setError(t: Throwable)` method, then calls this `setError()` method in the FileDownloadCallback's `onFailure` method. It is possible for `FileDownloadChannel`'s `read()` and `setError()` methods to be called concurrently from different threads: the `setError()` method is called from within the Netty RPC system's stream callback handlers, while the `read()` methods are called from higher-level application code performing remote stream reads. The problem lies in `setError()`: the existing code closed the wrapped pipe source channel. Because `read()` and `setError()` occur in different threads, this means it is possible for one thread to be calling `source.read()` while another asynchronously calls `source.close()`. Java's IO libraries do not guarantee that this will be safe and, in fact, it's possible for these operations to interleave in such a way that a lower-level `read()` system call occurs right after a `close()` call. In the best case, this fails as a read of a closed file descriptor; in the worst case, the file descriptor number has been re-used by an intervening `open()` operation and the read corrupts the result of an unrelated file IO operation being performed by a different thread. The solution here is to remove the `stream.close()` call in `onError()`: the thread that is performing the `read()` calls is responsible for closing the stream in a `finally` block, so there's no need to close it here. If that thread is blocked in a `read()` then it will become unblocked when the sink end of the pipe is closed in `FileDownloadCallback.onFailure()`. After making this change, we also need to refine the `read()` method to always check for a `setError()` result, even if the underlying channel `read()` call has succeeded. This patch also makes a slight cleanup to a dodgy-looking `catch e: Exception` block to use a safer `try-finally` error handling idiom. This bug was introduced in SPARK-11956 / #9941 and is present in Spark 1.6.0+. ## How was this patch tested? This fix was tested manually against a workload which non-deterministically hit this bug. Author: Josh Rosen Closes #20179 from JoshRosen/SPARK-22982-fix-unsafe-async-io-in-file-download-channel.
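A condensed sketch of the resulting safe pattern (simplified from the change below; stream-closing details are elided): `setError()` only records the failure, the reader is unblocked by the sink side of the pipe being closed elsewhere, and `read()` re-checks the recorded error even when the underlying channel read succeeded.

```
import java.nio.ByteBuffer
import java.nio.channels.{Pipe, ReadableByteChannel}
import scala.util.{Failure, Success, Try}

// Simplified sketch of the fixed wrapper: no source.close() in setError(), and the
// recorded remote error wins over whatever the underlying read() returned or threw.
class FileDownloadChannelSketch(source: Pipe.SourceChannel) extends ReadableByteChannel {
  @volatile private var error: Throwable = _

  def setError(e: Throwable): Unit = {
    // Record the failure only; closing `source` here would race with a concurrent read().
    error = e
  }

  override def read(dst: ByteBuffer): Int = Try(source.read(dst)) match {
    case _ if error != null => throw error   // propagate the remote failure first
    case Success(bytesRead) => bytesRead
    case Failure(readErr)   => throw readErr
  }

  override def isOpen: Boolean = source.isOpen
  override def close(): Unit = source.close() // the reading thread closes it in a finally block
}
```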
--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 37 +++++++++++-------- .../shuffle/IndexShuffleBlockResolver.scala | 21 +++++++++-- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index f951591e02a5c..a2936d6ad539c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv( val pipe = Pipe.open() val source = new FileDownloadChannel(pipe.source()) - try { + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) val callback = new FileDownloadCallback(pipe.sink(), source, client) client.stream(parsedUri.getPath(), callback) - } catch { - case e: Exception => - pipe.sink().close() - source.close() - throw e - } + })(catchBlock = { + pipe.sink().close() + source.close() + }) source } @@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv( fileDownloadFactory.createClient(host, port) } - private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel { @volatile private var error: Throwable = _ def setError(e: Throwable): Unit = { + // This setError callback is invoked by internal RPC threads in order to propagate remote + // exceptions to application-level threads which are reading from this channel. When an + // RPC error occurs, the RPC system will call setError() and then will close the + // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe + // sink will cause `source.read()` operations to return EOF, unblocking the application-level + // reading thread. Thus there is no need to actually call `source.close()` here in the + // onError() callback and, in fact, calling it here would be dangerous because the close() + // would be asynchronous with respect to the read() call and could trigger race-conditions + // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic. error = e - source.close() } override def read(dst: ByteBuffer): Int = { Try(source.read(dst)) match { + // See the documentation above in setError(): if an RPC error has occurred then setError() + // will be called to propagate the RPC error and then `source`'s corresponding + // Pipe.SinkChannel will be closed, unblocking this read. 
In that case, we want to propagate + the remote RPC exception (and not any exceptions triggered by the pipe close, such as + ChannelClosedException), hence this `error != null` check: + case _ if error != null => throw error case Success(bytesRead) => bytesRead - case Failure(readErr) => - if (error != null) { - throw error - } else { - throw readErr - } + case Failure(readErr) => throw readErr } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 15540485170d0..266ee42e39cca 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -18,8 +18,8 @@ package org.apache.spark.shuffle import java.io._ - -import com.google.common.io.ByteStreams +import java.nio.channels.Channels +import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging @@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver( // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - val in = new DataInputStream(new FileInputStream(indexFile)) + // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code + // which is incorrectly using our file descriptor then this code will fetch the wrong offsets + // (which may cause a reducer to be sent a different reducer's data). The explicit position + // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this + // class of issue from re-occurring in the future which is why they are left here even though + // SPARK-22982 is fixed. + val channel = Files.newByteChannel(indexFile.toPath) + channel.position(blockId.reduceId * 8) + val in = new DataInputStream(Channels.newInputStream(channel)) try { - ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() + val actualPosition = channel.position() + val expectedPosition = blockId.reduceId * 8 + 16 + if (actualPosition != expectedPosition) { + throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + + s"expected $expectedPosition but actual position was $actualPosition.") + } new FileSegmentManagedBuffer( transportConf, getDataFile(blockId.shuffleId, blockId.mapId), From eaac60a1e20e29084b7151ffca964cfaa5ba99d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 15:16:27 +0800 Subject: [PATCH 611/712] [SPARK-16060][SQL][FOLLOW-UP] add a wrapper solution for vectorized orc reader ## What changes were proposed in this pull request? This is mostly from https://github.com/apache/spark/pull/13775 The wrapper solution is pretty good for string/binary type, as the ORC column vector doesn't keep bytes in a contiguous memory region, and has a significant overhead when copying the data to Spark columnar batch. For other cases, the wrapper solution is almost the same as the current solution. I think we can treat the wrapper solution as a baseline and keep improving the copy-to-Spark solution. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20205 from cloud-fan/orc.
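A small usage sketch of the knob this adds (the path and column name are made up, and this assumes the native vectorized ORC reader is the one in use, since the flag only affects that code path):

```
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("orc-wrapper-vs-copy").getOrCreate()

// Default (false): the reader wraps ORC column vectors directly, avoiding a per-batch copy.
spark.conf.set("spark.sql.orc.copyBatchToSpark", "false")
spark.read.orc("/tmp/example_orc").selectExpr("sum(length(c1))").show()

// Internal escape hatch: force the copy-to-Spark path, e.g. to compare on string-heavy data.
spark.conf.set("spark.sql.orc.copyBatchToSpark", "true")
spark.read.orc("/tmp/example_orc").selectExpr("sum(length(c1))").show()
```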
--- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../datasources/orc/OrcColumnVector.java | 251 ++++++++++++++++++ .../orc/OrcColumnarBatchReader.java | 106 ++++++-- .../datasources/orc/OrcFileFormat.scala | 6 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 236 ++++++++++------ 5 files changed, 490 insertions(+), 116 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 74949db883f7a..36e802a9faa6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -391,6 +391,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") + .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + + "vectorized ORC reader.") + .internal() + .booleanConf + .createWithDefault(false) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java new file mode 100644 index 0000000000000..f94c55d860304 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.math.BigDecimal; + +import org.apache.orc.storage.ql.exec.vector.*; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarVector. 
+ */ +public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private ColumnVector baseData; + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + final private boolean isTimestamp; + + private int batchSize; + + OrcColumnVector(DataType type, ColumnVector vector) { + super(type); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + baseData = vector; + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public void close() { + + } + + @Override + public int numNulls() { + if (baseData.isRepeating) { + if (baseData.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (baseData.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (baseData.isNull[i]) count++; + } + return count; + } + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return baseData.isRepeating ? 0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return baseData.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } + + @Override + public int getInt(int rowId) { + return (int) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + } else { + return longData.vector[index]; + } + } + + @Override + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + 
public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } + + @Override + public int getArrayLength(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getArrayOffset(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + int index = getRowIndex(rowId); + byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector arrayData() { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5c28d0e6e507a..36fdf2bdf84d2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -51,13 +51,13 @@ public class OrcColumnarBatchReader extends RecordReader { /** - * The default size of batch. We use this value for both ORC and Spark consistently - * because they have different default values like the following. + * The default size of batch. We use this value for ORC reader to make it consistent with Spark's + * columnar batch, because their default batch sizes are different like the following: * * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 */ - public static final int DEFAULT_SIZE = 4 * 1024; + private static final int DEFAULT_SIZE = 4 * 1024; // ORC File Reader private Reader reader; @@ -82,13 +82,18 @@ public class OrcColumnarBatchReader extends RecordReader { // Writable column vectors of the result columnar batch. private WritableColumnVector[] columnVectors; - /** - * The memory mode of the columnarBatch - */ + // The wrapped ORC column vectors. It should be null if `copyToSpark` is true. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + // The memory mode of the columnarBatch private final MemoryMode MEMORY_MODE; - public OrcColumnarBatchReader(boolean useOffHeap) { + // Whether or not to copy the ORC columnar batch to Spark columnar batch. + private final boolean copyToSpark; + + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { MEMORY_MODE = useOffHeap ? 
MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + this.copyToSpark = copyToSpark; } @@ -167,27 +172,61 @@ public void initBatch( } int capacity = DEFAULT_SIZE; - if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); - } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); - } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); - columnVectors[i + partitionIdx].setIsConstant(); + if (copyToSpark) { + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } - } - // Initialize the missing columns once. - for (int i = 0; i < requiredFields.length; i++) { - if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); - columnVectors[i].setIsConstant(); + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, capacity); + columnVectors[i].setIsConstant(); + } + } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + } else { + // Just wrap the ORC column vector instead of copying it to Spark column vector. + orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; + + for (int i = 0; i < requiredFields.length; i++) { + DataType dt = requiredFields[i].dataType(); + int colId = requestedColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + DataType dt = partitionSchema.fields()[i].dataType(); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + ColumnVectorUtils.populate(partitionCol, partitionValues, i); + partitionCol.setIsConstant(); + orcVectorWrappers[partitionIdx + i] = partitionCol; + } + } + + columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); } } @@ -196,17 +235,26 @@ public void initBatch( * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. 
*/ private boolean nextBatch() throws IOException { - for (WritableColumnVector vector : columnVectors) { - vector.reset(); - } - columnarBatch.setNumRows(0); - recordReader.nextBatch(batch); int batchSize = batch.size; if (batchSize == 0) { return false; } columnarBatch.setNumRows(batchSize); + + if (!copyToSpark) { + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] != -1) { + ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); + } + } + return true; + } + + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + for (int i = 0; i < requiredFields.length; i++) { StructField field = requiredFields[i]; WritableColumnVector toColumn = columnVectors[i]; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b8bacfa1838ae..2dd314d165348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -150,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -183,8 +185,8 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { - val batchReader = - new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + val batchReader = new OrcColumnarBatchReader( + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 37ed846acd1eb..bf6efa7c4c08c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -86,7 +86,7 @@ object OrcReadBenchmark { } def numericScanBenchmark(values: Int, dataType: DataType): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + val benchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -95,61 +95,73 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") 
{ _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1192 / 1221 13.2 75.8 1.0X - Native ORC Vectorized 161 / 170 97.5 10.3 7.4X - Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + Native ORC MR 1135 / 1171 13.9 72.2 1.0X + Native ORC Vectorized 152 / 163 103.4 9.7 7.5X + Native ORC Vectorized with copy 149 / 162 105.4 9.5 7.6X + Hive built-in ORC 1380 / 1384 11.4 87.7 0.8X SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1287 / 1333 12.2 81.8 1.0X - Native ORC Vectorized 164 / 172 95.6 10.5 7.8X - Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + Native ORC MR 1182 / 1244 13.3 75.2 1.0X + Native ORC Vectorized 145 / 156 108.7 9.2 8.2X + Native ORC Vectorized with copy 148 / 158 106.4 9.4 8.0X + Hive built-in ORC 1591 / 1636 9.9 101.2 0.7X SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1304 / 1388 12.1 82.9 1.0X - Native ORC Vectorized 227 / 240 69.3 14.4 5.7X - Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + Native ORC MR 1271 / 1271 12.4 80.8 1.0X + Native ORC Vectorized 206 / 212 76.3 13.1 6.2X + Native ORC Vectorized with copy 200 / 213 78.8 12.7 6.4X + Hive built-in ORC 1776 / 1787 8.9 112.9 0.7X SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1331 / 1357 11.8 84.6 1.0X - Native ORC Vectorized 289 / 297 54.4 18.4 4.6X - Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + Native ORC MR 1344 / 1355 11.7 85.4 1.0X + Native ORC Vectorized 258 / 268 61.0 16.4 5.2X + Native ORC Vectorized with copy 252 / 257 62.4 16.0 5.3X + Hive built-in ORC 1818 / 1823 8.7 115.6 0.7X SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1410 / 1428 11.2 89.7 1.0X - Native ORC Vectorized 328 / 335 48.0 20.8 4.3X - Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + Native ORC MR 1333 / 1352 11.8 84.8 1.0X + Native ORC Vectorized 310 / 324 50.7 19.7 4.3X + Native ORC Vectorized with copy 312 / 320 50.4 19.9 4.3X + Hive built-in ORC 1904 / 1918 8.3 121.0 0.7X SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1467 / 1485 10.7 93.3 1.0X - Native ORC Vectorized 402 / 411 39.1 25.6 3.6X - Hive built-in 
ORC 2023 / 2042 7.8 128.6 0.7X + Native ORC MR 1408 / 1585 11.2 89.5 1.0X + Native ORC Vectorized 359 / 368 43.8 22.8 3.9X + Native ORC Vectorized with copy 364 / 371 43.2 23.2 3.9X + Hive built-in ORC 1881 / 1954 8.4 119.6 0.7X */ - sqlBenchmark.run() + benchmark.run() } } } @@ -176,19 +188,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2729 / 2744 3.8 260.2 1.0X - Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X - Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + Native ORC MR 2566 / 2592 4.1 244.7 1.0X + Native ORC Vectorized 1098 / 1113 9.6 104.7 2.3X + Native ORC Vectorized with copy 1527 / 1593 6.9 145.6 1.7X + Hive built-in ORC 3561 / 3705 2.9 339.6 0.7X */ benchmark.run() } @@ -205,63 +224,84 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) - benchmark.addCase("Read data column - Native ORC MR") { _ => + benchmark.addCase("Data column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + benchmark.addCase("Data column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read data column - Hive built-in ORC") { _ => + benchmark.addCase("Data column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Data column - Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } - benchmark.addCase("Read partition column - Native ORC MR") { _ => + benchmark.addCase("Partition column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } - benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Partition column - Hive built-in ORC") { _ => spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() } - benchmark.addCase("Read both columns - Native ORC MR") { _ => + benchmark.addCase("Both columns - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { 
spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + benchmark.addCase("Both columns - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + benchmark.addCase("Both column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Both columns - Hive built-in ORC") { _ => spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X - Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X - Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X - Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X - Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X - Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X - Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X - Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X - Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + Data only - Native ORC MR 1447 / 1457 10.9 92.0 1.0X + Data only - Native ORC Vectorized 256 / 266 61.4 16.3 5.6X + Data only - Native ORC Vectorized with copy 263 / 273 59.8 16.7 5.5X + Data only - Hive built-in ORC 1960 / 1988 8.0 124.6 0.7X + Partition only - Native ORC MR 1039 / 1043 15.1 66.0 1.4X + Partition only - Native ORC Vectorized 48 / 53 326.6 3.1 30.1X + Partition only - Native ORC Vectorized with copy 48 / 53 328.4 3.0 30.2X + Partition only - Hive built-in ORC 1234 / 1242 12.7 78.4 1.2X + Both columns - Native ORC MR 1465 / 1475 10.7 93.1 1.0X + Both columns - Native ORC Vectorized 292 / 301 53.9 18.6 5.0X + Both column - Native ORC Vectorized with copy 348 / 354 45.1 22.2 4.2X + Both columns - Hive built-in ORC 2051 / 2060 7.7 130.4 0.7X */ benchmark.run() } @@ -287,19 +327,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1325 / 1328 7.9 126.4 1.0X - Native ORC Vectorized 320 / 330 32.8 30.5 4.1X - Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + Native ORC MR 1271 / 1278 8.3 121.2 1.0X + Native ORC Vectorized 200 / 212 52.4 19.1 6.4X + Native ORC 
Vectorized with copy 342 / 347 30.7 32.6 3.7X + Hive built-in ORC 1874 / 2105 5.6 178.7 0.7X */ benchmark.run() } @@ -331,32 +378,42 @@ object OrcReadBenchmark { "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2553 / 2554 4.1 243.4 1.0X - Native ORC Vectorized 953 / 954 11.0 90.9 2.7X - Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + Native ORC MR 2394 / 2886 4.4 228.3 1.0X + Native ORC Vectorized 699 / 729 15.0 66.7 3.4X + Native ORC Vectorized with copy 959 / 1025 10.9 91.5 2.5X + Hive built-in ORC 3899 / 3901 2.7 371.9 0.6X String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2389 / 2408 4.4 227.8 1.0X - Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X - Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + Native ORC MR 2234 / 2255 4.7 213.1 1.0X + Native ORC Vectorized 854 / 869 12.3 81.4 2.6X + Native ORC Vectorized with copy 1099 / 1128 9.5 104.8 2.0X + Hive built-in ORC 2767 / 2793 3.8 263.9 0.8X String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1295 / 1311 8.1 123.5 1.0X - Native ORC Vectorized 449 / 457 23.4 42.8 2.9X - Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + Native ORC MR 1166 / 1202 9.0 111.2 1.0X + Native ORC Vectorized 338 / 345 31.1 32.2 3.5X + Native ORC Vectorized with copy 418 / 428 25.1 39.9 2.8X + Hive built-in ORC 1730 / 1761 6.1 164.9 0.7X */ benchmark.run() } @@ -364,7 +421,7 @@ object OrcReadBenchmark { } def columnsBenchmark(values: Int, width: Int): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -376,43 +433,52 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT * FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") { _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + 
benchmark.addCase("Hive built-in ORC") { _ => spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1103 / 1124 1.0 1052.0 1.0X - Native ORC Vectorized 92 / 100 11.4 87.9 12.0X - Hive built-in ORC 383 / 390 2.7 365.4 2.9X + Native ORC MR 1050 / 1053 1.0 1001.1 1.0X + Native ORC Vectorized 95 / 101 11.0 90.9 11.0X + Native ORC Vectorized with copy 95 / 102 11.0 90.9 11.0X + Hive built-in ORC 348 / 358 3.0 331.8 3.0X - SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2245 / 2250 0.5 2141.0 1.0X - Native ORC Vectorized 157 / 165 6.7 150.2 14.3X - Hive built-in ORC 587 / 593 1.8 559.4 3.8X + Native ORC MR 2099 / 2108 0.5 2002.1 1.0X + Native ORC Vectorized 179 / 187 5.8 171.1 11.7X + Native ORC Vectorized with copy 176 / 188 6.0 167.6 11.9X + Hive built-in ORC 562 / 581 1.9 535.9 3.7X - SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 3343 / 3350 0.3 3188.3 1.0X - Native ORC Vectorized 265 / 280 3.9 253.2 12.6X - Hive built-in ORC 828 / 842 1.3 789.8 4.0X + Native ORC MR 3221 / 3246 0.3 3071.4 1.0X + Native ORC Vectorized 312 / 322 3.4 298.0 10.3X + Native ORC Vectorized with copy 306 / 320 3.4 291.6 10.5X + Hive built-in ORC 815 / 824 1.3 777.3 4.0X */ - sqlBenchmark.run() + benchmark.run() } } } From 70bcc9d5ae33d6669bb5c97db29087ccead770fb Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 9 Jan 2018 23:32:47 -0800 Subject: [PATCH 612/712] [SPARK-22993][ML] Clarify HasCheckpointInterval param doc ## What changes were proposed in this pull request? Add a note to the `HasCheckpointInterval` parameter doc that clarifies that this setting is ignored when no checkpoint directory has been set on the spark context. ## How was this patch tested? No tests necessary, just a doc update. Author: sethah Closes #20188 from sethah/als_checkpoint_doc. --- R/pkg/R/mllib_recommendation.R | 2 ++ R/pkg/R/mllib_tree.R | 6 ++++++ .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 4 +++- .../org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- python/pyspark/ml/param/_shared_params_code_gen.py | 5 +++-- python/pyspark/ml/param/shared.py | 4 ++-- 6 files changed, 18 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index fa794249085d7..5441c4a4022a9 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -48,6 +48,8 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @param numUserBlocks number of user blocks used to parallelize computation (> 0). 
#' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 89a58bf0aadae..4e5ddf22ee16d 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -161,6 +161,8 @@ print.summary.decisionTree <- function(x) { #' >= 1. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -382,6 +384,8 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -595,6 +599,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a5d57a15317e6..6ad44af9ef7eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -63,7 +63,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + - "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), + "every 10 iterations. 
Note: this setting will be ignored if the checkpoint directory " + + "is not set in the SparkContext", + isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + "will filter out rows with bad values), or error (which will throw an error). More " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 13425dacc9f18..be8b2f273164b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -282,10 +282,10 @@ trait HasOutputCols extends Params { trait HasCheckpointInterval extends Params { /** - * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index d55d209d09398..1d0f60acc6983 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -121,8 +121,9 @@ def get$Name(self): ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, - "TypeConverters.toInt"), + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + + "this setting will be ignored if the checkpoint directory is not set in the SparkContext.", + None, "TypeConverters.toInt"), ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None, "TypeConverters.toFloat"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index e5c5ddfba6c1f..813f7a59f3fd1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -281,10 +281,10 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. 
+ Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. """ - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt) + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasCheckpointInterval, self).__init__() From f340b6b3066033d40b7e163fd5fb68e9820adfb1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 00:45:47 -0800 Subject: [PATCH 613/712] [SPARK-22997] Add additional defenses against use of freed MemoryBlocks ## What changes were proposed in this pull request? This patch modifies Spark's `MemoryAllocator` implementations so that `free(MemoryBlock)` mutates the passed block to clear pointers (in the off-heap case) or null out references to backing `long[]` arrays (in the on-heap case). The goal of this change is to add an extra layer of defense against use-after-free bugs because currently it's hard to detect corruption caused by blind writes to freed memory blocks. ## How was this patch tested? New unit tests in `PlatformSuite`, including new tests for existing functionality because we did not have sufficient mutation coverage of the on-heap memory allocator's pooling logic. Author: Josh Rosen Closes #20191 from JoshRosen/SPARK-22997-add-defenses-against-use-after-free-bugs-in-memory-allocator. 
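The defensive idea is simple enough to sketch on its own. A minimal, hypothetical Scala sketch of the "poison on free" pattern described above (names and sentinel values are illustrative, not the real MemoryAllocator or TaskMemoryManager API):

```
// Hypothetical sketch: free() drops the reference to the backing storage and stamps a
// sentinel, so a later double-free or use-after-free fails fast instead of silently
// corrupting data that now belongs to someone else.
object Sentinels {
  val NoPage = -1
  val FreedInAllocator = -3
}

final class Block(var array: Array[Long], var pageNumber: Int = Sentinels.NoPage) {
  def getLong(i: Int): Long = {
    assert(array != null, "use-after-free: this block has already been freed")
    array(i)
  }
}

object Allocator {
  def allocate(words: Int): Block = new Block(new Array[Long](words))

  def free(block: Block): Unit = {
    assert(block.pageNumber != Sentinels.FreedInAllocator, "double free detected")
    block.array = null                             // drop the reference to the backing storage
    block.pageNumber = Sentinels.FreedInAllocator  // stamp so a second free fails fast
  }
}
```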
--- .../unsafe/memory/HeapMemoryAllocator.java | 35 +++++++++---- .../spark/unsafe/memory/MemoryBlock.java | 21 +++++++- .../unsafe/memory/UnsafeMemoryAllocator.java | 11 ++++ .../spark/unsafe/PlatformUtilSuite.java | 50 ++++++++++++++++++- .../spark/memory/TaskMemoryManager.java | 13 ++++- .../spark/memory/TaskMemoryManagerSuite.java | 29 +++++++++++ 6 files changed, 146 insertions(+), 13 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index cc9cc429643ad..3acfe3696cb1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -31,8 +31,7 @@ public class HeapMemoryAllocator implements MemoryAllocator { @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap<>(); + private final Map>> bufferPoolsBySize = new HashMap<>(); private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; @@ -49,13 +48,14 @@ private boolean shouldPool(long size) { public MemoryBlock allocate(long size) throws OutOfMemoryError { if (shouldPool(size)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(size); if (pool != null) { while (!pool.isEmpty()) { - final WeakReference blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); + final WeakReference arrayReference = pool.pop(); + final long[] array = arrayReference.get(); + if (array != null) { + assert (array.length * 8L >= size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -76,18 +76,35 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { + assert (memory.obj != null) : + "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + + // Mark the page as freed (so we can detect double-frees). + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to null out its reference to the long[] array. 
+ long[] array = (long[]) memory.obj; + memory.setObjAndOffset(null, 0); + if (shouldPool(size)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(size); if (pool == null) { pool = new LinkedList<>(); bufferPoolsBySize.put(size, pool); } - pool.add(new WeakReference<>(memory)); + pool.add(new WeakReference<>(array)); } } else { // Do nothing diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index cd1d378bc1470..c333857358d30 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -26,6 +26,25 @@ */ public class MemoryBlock extends MemoryLocation { + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ + public static final int NO_PAGE_NUMBER = -1; + + /** + * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager. + * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator + * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM + * before being passed to MemoryAllocator.free() (it is an error to allocate a page in + * TaskMemoryManager and then directly free it in a MemoryAllocator without going through + * the TMM freePage() call). + */ + public static final int FREED_IN_TMM_PAGE_NUMBER = -2; + + /** + * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows + * us to detect double-frees. + */ + public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; + private final long length; /** @@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation { * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, * which lives in a different package. */ - public int pageNumber = -1; + public int pageNumber = NO_PAGE_NUMBER; public MemoryBlock(@Nullable Object obj, long offset, long length) { super(obj, offset); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 55bcdf1ed7b06..4368fb615ba1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -38,9 +38,20 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to reset its pointer. + memory.offset = 0; + // Mark the page as freed (so we can detect double-frees). 
+ memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4b141339ec816..62854837b05ed 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -62,6 +62,52 @@ public void overlappingCopyMemory() { } } + @Test + public void onHeapMemoryAllocatorPoolingReUsesLongArrays() { + MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject1 = block1.getBaseObject(); + MemoryAllocator.HEAP.free(block1); + MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject2 = block2.getBaseObject(); + Assert.assertSame(baseObject1, baseObject2); + MemoryAllocator.HEAP.free(block2); + } + + @Test + public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + Assert.assertNotNull(block.getBaseObject()); + MemoryAllocator.HEAP.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test + public void freeingOffHeapMemoryBlockResetsOffset() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + Assert.assertNull(block.getBaseObject()); + Assert.assertNotEquals(0, block.getBaseOffset()); + MemoryAllocator.UNSAFE.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test(expected = AssertionError.class) + public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + MemoryAllocator.HEAP.free(block); + MemoryAllocator.HEAP.free(block); + } + + @Test(expected = AssertionError.class) + public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + MemoryAllocator.UNSAFE.free(block); + MemoryAllocator.UNSAFE.free(block); + } + @Test public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); @@ -71,9 +117,11 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object onheap1BaseObject = onheap1.getBaseObject(); + long onheap1BaseOffset = onheap1.getBaseOffset(); MemoryAllocator.HEAP.free(onheap1); Assert.assertEquals( - Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + Platform.getByte(onheap1BaseObject, onheap1BaseOffset), MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); Assert.assertEquals( diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index e8d3730daa7a4..632d718062212 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -321,8 +321,12 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. 
*/ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != -1) : + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; assert(allocatedPages.get(page.pageNumber)); pageTable[page.pageNumber] = null; synchronized (this) { @@ -332,6 +336,10 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } long pageSize = page.size(); + // Clear the page number before passing the block to the MemoryAllocator's free(). + // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -358,7 +366,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { - assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page"; return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } @@ -424,6 +432,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 46b0516e36141..a0664b30d6cc2 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -21,6 +21,7 @@ import org.junit.Test; import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; public class TaskMemoryManagerSuite { @@ -68,6 +69,34 @@ public void encodePageNumberAndOffsetOnHeap() { Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); } + @Test + public void freeingPageSetsPageNumberToSpecialConstant() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + c.freePage(dataPage); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + } + + @Test(expected = AssertionError.class) + public void freeingPageDirectlyInAllocatorTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = 
manager.allocatePage(256, c); + MemoryAllocator.HEAP.free(dataPage); + } + + @Test(expected = AssertionError.class) + public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256); + manager.freePage(dataPage, c); + } + @Test public void cooperativeSpilling() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); From 344e3aab87178e45957333479a07e07f202ca1fd Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 09:44:30 -0800 Subject: [PATCH 614/712] [SPARK-23019][CORE] Wait until SparkContext.stop() finished in SparkLauncherSuite ## What changes were proposed in this pull request? In current code ,the function `waitFor` call https://github.com/apache/spark/blob/cfcd746689c2b84824745fa6d327ffb584c7a17d/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java#L155 only wait until DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. https://github.com/apache/spark/blob/1c9f95cb771ac78775a77edd1abfeb2d8ae2a124/core/src/main/scala/org/apache/spark/SparkContext.scala#L1924 Thus, in the Jenkins test https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-maven-hadoop-2.6/ , `JdbcRDDSuite` failed because the previous test `SparkLauncherSuite` exit before SparkContext.stop() is finished. To repo: ``` $ build/sbt > project core > testOnly *SparkLauncherSuite *JavaJdbcRDDSuite ``` To Fix: Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM in SparkLauncherSuite. Can' come up with any better solution for now. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20221 from gengliangwang/SPARK-23019. --- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index c2261c204cd45..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; @@ -133,6 +134,10 @@ public void testInProcessLauncher() throws Exception { p.put(e.getKey(), e.getValue()); } System.setProperties(p); + // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. + // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. + // See SPARK-23019 and SparkContext.stop() for details. + TimeUnit.MILLISECONDS.sleep(500); } } From 9b33dfc408de986f4203bb0ac0c3f5c56effd69d Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 10 Jan 2018 14:25:04 -0800 Subject: [PATCH 615/712] [SPARK-22951][SQL] fix aggregation after dropDuplicates on empty data frames ## What changes were proposed in this pull request? (courtesy of liancheng) Spark SQL supports both global aggregation and grouping aggregation. Global aggregation always return a single row with the initial aggregation state as the output, even there are zero input rows. 
Spark implements this by simply checking the number of grouping keys and treats an aggregation as a global aggregation if it has zero grouping keys. However, this simple principle drops the ball in the following case: ```scala spark.emptyDataFrame.dropDuplicates().agg(count($"*") as "c").show() // +---+ // | c | // +---+ // | 1 | // +---+ ``` The reason is that: 1. `df.dropDuplicates()` is roughly translated into something equivalent to: ```scala val allColumns = df.columns.map { col } df.groupBy(allColumns: _*).agg(allColumns.head, allColumns.tail: _*) ``` This translation is implemented in the rule `ReplaceDeduplicateWithAggregate`. 2. `spark.emptyDataFrame` contains zero columns and zero rows. Therefore, rule `ReplaceDeduplicateWithAggregate` makes a confusing transformation roughly equivalent to the following one: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy().agg(Map.empty[String, String]) ``` The above transformation is confusing because the resulting aggregate operator contains no grouping keys (because `emptyDataFrame` contains no columns), and gets recognized as a global aggregation. As a result, Spark SQL allocates a single row filled by the initial aggregation state and uses it as the output, and returns a wrong result. To fix this issue, this PR tweaks `ReplaceDeduplicateWithAggregate` by appending a literal `1` to the grouping key list of the resulting `Aggregate` operator when the input plan contains zero output columns. In this way, `spark.emptyDataFrame.dropDuplicates()` is now translated into a grouping aggregation, roughly depicted as: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy(lit(1)).agg(Map.empty[String, String]) ``` Which is now properly treated as a grouping aggregation and returns the correct answer. ## How was this patch tested? New unit tests added Author: Feng Liu Closes #20174 from liufengdb/fix-duplicate. --- .../sql/catalyst/optimizer/Optimizer.scala | 8 ++++++- .../optimizer/ReplaceOperatorSuite.scala | 10 +++++++- .../spark/sql/DataFrameAggregateSuite.scala | 24 +++++++++++++++++-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index df0af8264a329..c794ba8619322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1222,7 +1222,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) } } - Aggregate(keys, aggCols, child) + // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping + // aggregations by checking the number of grouping keys. The key difference here is that a + // global aggregation always returns at least one row even if there are no input rows. Here + // we append a literal when the grouping key list is empty so that the result aggregate + // operator is properly treated as a grouping aggregation. 
+ val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys + Aggregate(nonemptyKeys, aggCols, child) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 0fa1aaeb9e164..e9701ffd2c54b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -198,6 +198,14 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("add one grouping key if necessary when replace Deduplicate with Aggregate") { + val input = LocalRelation() + val query = Deduplicate(Seq.empty, input) // dropDuplicates() + val optimized = Optimize.execute(query.analyze) + val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input) + comparePlans(optimized, correctAnswer) + } + test("don't replace streaming Deduplicate") { val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true) val attrA = input.output(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 06848e4d2b297..e7776e36702ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import scala.util.Random +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -27,7 +29,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.DecimalType case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -456,7 +458,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), Row(null, null, null, null, null)) @@ -666,4 +667,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(exchangePlans.length == 1) } } + + Seq(true, false).foreach { codegen => + test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " + + s"results when codegen is enabled: $codegen") { + 
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) { + // explicit global aggregations + val emptyAgg = Map.empty[String, String] + checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) + + // global aggregation is converted to grouping aggregation: + assert(spark.emptyDataFrame.dropDuplicates().count() == 0) + } + } + } } From a6647ffbf7a312a3e119a9beef90880cc915aa60 Mon Sep 17 00:00:00 2001 From: Mingjie Tang Date: Thu, 11 Jan 2018 11:51:03 +0800 Subject: [PATCH 616/712] [SPARK-22587] Spark job fails if fs.defaultFS and application jar are different url ## What changes were proposed in this pull request? Two filesystems comparing does not consider the authority of URI. This is specific for WASB file storage system, where userInfo is honored to differentiate filesystems. For example: wasbs://user1xyz.net, wasbs://user2xyz.net would consider as two filesystem. Therefore, we have to add the authority to compare two filesystem, and two filesystem with different authority can not be the same FS. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Mingjie Tang Closes #19885 from merlintang/EAR-7377. --- .../org/apache/spark/deploy/yarn/Client.scala | 24 +++++++++++--- .../spark/deploy/yarn/ClientSuite.scala | 33 +++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 15328d08b3b5c..8cd3cd9746a3a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1421,15 +1421,20 @@ private object Client extends Logging { } /** - * Return whether the two file systems are the same. + * Return whether two URI represent file system are the same */ - private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { - val srcUri = srcFs.getUri() - val dstUri = destFs.getUri() + private[spark] def compareUri(srcUri: URI, dstUri: URI): Boolean = { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + val srcAuthority = srcUri.getAuthority() + val dstAuthority = dstUri.getAuthority() + if (srcAuthority != null && !srcAuthority.equalsIgnoreCase(dstAuthority)) { + return false + } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() @@ -1447,6 +1452,17 @@ private object Client extends Logging { } Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() + + } + + /** + * Return whether the two file systems are the same. 
+ */ + protected def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + + compareUri(srcUri, dstUri) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 9d5f5eb621118..7fa597167f3f0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -357,6 +357,39 @@ class ClientSuite extends SparkFunSuite with Matchers { sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName))) } + private val matching = Seq( + ("files URI match test1", "file:///file1", "file:///file2"), + ("files URI match test2", "file:///c:file1", "file://c:file2"), + ("files URI match test3", "file://host/file1", "file://host/file2"), + ("wasb URI match test", "wasb://bucket1@user", "wasb://bucket1@user/"), + ("hdfs URI match test", "hdfs:/path1", "hdfs:/path1") + ) + + matching.foreach { t => + test(t._1) { + assert(Client.compareUri(new URI(t._2), new URI(t._3)), + s"No match between ${t._2} and ${t._3}") + } + } + + private val unmatching = Seq( + ("files URI unmatch test1", "file:///file1", "file://host/file2"), + ("files URI unmatch test2", "file://host/file1", "file:///file2"), + ("files URI unmatch test3", "file://host/file1", "file://host2/file2"), + ("wasb URI unmatch test1", "wasb://bucket1@user", "wasb://bucket2@user/"), + ("wasb URI unmatch test2", "wasb://bucket1@user", "wasb://bucket1@user2/"), + ("s3 URI unmatch test", "s3a://user@pass:bucket1/", "s3a://user2@pass2:bucket1/"), + ("hdfs URI unmatch test1", "hdfs://namenode1/path1", "hdfs://namenode1:8080/path2"), + ("hdfs URI unmatch test2", "hdfs://namenode1:8020/path1", "hdfs://namenode1:8080/path2") + ) + + unmatching.foreach { t => + test(t._1) { + assert(!Client.compareUri(new URI(t._2), new URI(t._3)), + s"match between ${t._2} and ${t._3}") + } + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 87c98de8b23f0e978958fc83677fdc4c339b7e6a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 18:17:34 +0800 Subject: [PATCH 617/712] [SPARK-23001][SQL] Fix NullPointerException when DESC a database with NULL description ## What changes were proposed in this pull request? When users' DB description is NULL, users might hit `NullPointerException`. This PR is to fix the issue. ## How was this patch tested? Added test cases Author: gatorsmile Closes #20215 from gatorsmile/SPARK-23001. 
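The one-line fix relies on the usual Scala idiom for guarding a possibly-null value returned by a Java API. A minimal sketch (the `LegacyDb` case class is a made-up stand-in, not the Hive metastore type):

```scala
// Made-up stand-in for a Java bean whose getter may return null; not the actual
// Hive metastore Database class.
case class LegacyDb(name: String, description: String)

object NullSafeDescribe extends App {
  // Wrapping the possibly-null field in Option before use is the whole fix:
  // a null description becomes "" instead of surfacing later as an NPE.
  def describeDb(db: LegacyDb): String = Option(db.description).getOrElse("")

  assert(describeDb(LegacyDb("dbWithNullDesc", null)) == "")
  assert(describeDb(LegacyDb("db", "some docs")) == "some docs")
}
```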
--- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ++++++ .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 102f40bacc985..4b923f5235a90 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -330,7 +330,7 @@ private[hive] class HiveClientImpl( Option(client.getDatabase(dbName)).map { d => CatalogDatabase( name = d.getName, - description = d.getDescription, + description = Option(d.getDescription).getOrElse(""), locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 2e35fdeba464d..0a522b6a11c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -107,4 +107,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { .filter(_.contains("Num Buckets")).head assert(bucketString.contains("10")) } + + test("SPARK-23001: NullPointerException when running desc database") { + val catalog = newBasicCatalog() + catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) + assert(catalog.getDatabase("dbWithNullDesc").description == "") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 94473a08dd317..ff90e9dda5f7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -163,6 +163,15 @@ class VersionsSuite extends SparkFunSuite with Logging { client.createDatabase(tempDB, ignoreIfExists = true) } + test(s"$version: createDatabase with null description") { + withTempDir { tmpDir => + val dbWithNullDesc = + CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map()) + client.createDatabase(dbWithNullDesc, ignoreIfExists = true) + assert(client.getDatabase("dbWithNullDesc").description == "") + } + } + test(s"$version: setCurrentDatabase") { client.setCurrentDatabase("default") } From 1c70da3bfbb4016e394de2c73eb0db7cdd9a6968 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 19:41:48 +0800 Subject: [PATCH 618/712] [SPARK-20657][CORE] Speed up rendering of the stages page. There are two main changes to speed up rendering of the tasks list when rendering the stage page. The first one makes the code only load the tasks being shown in the current page of the tasks table, and information related to only those tasks. One side-effect of this change is that the graph that shows task-related events now only shows events for the tasks in the current page, instead of the previously hardcoded limit of "events for the first 1000 tasks". That ends up helping with readability, though. 
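Concretely, the per-page read amounts to asking the KVStore for a sorted, stage-scoped view and materializing only one page of it. A rough sketch, assuming it lives inside Spark's own `org.apache.spark.status` package (where `TaskDataWrapper` and `TaskIndexNames` are defined); the offset/length arithmetic is illustrative, not the verbatim code:

```scala
package org.apache.spark.status

import scala.collection.JavaConverters._

import org.apache.spark.util.kvstore.KVStore

private[spark] object TaskPagingSketch {
  def taskPage(
      store: KVStore,
      stageId: Int,
      attemptId: Int,
      offset: Int,
      length: Int,
      ascending: Boolean): Seq[TaskDataWrapper] = {
    val view = store.view(classOf[TaskDataWrapper])
      .parent(Array(stageId, attemptId))      // restrict to one stage attempt
      .index(TaskIndexNames.EXEC_RUN_TIME)    // one of the new sortable indices
    val ordered = if (ascending) view else view.reverse()
    // skip()/max() make the store return only the rows for the requested page,
    // so a disk-backed store never deserializes tasks that are not being shown.
    ordered.skip(offset).max(length).asScala.toSeq
  }
}
```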
To make sorting efficient when using a disk store, the task wrapper was extended to include many new indices, one for each of the sortable columns in the UI, and metrics for which quantiles are calculated. The second changes the way metric quantiles are calculated for stages. Instead of using the "Distribution" class to process data for all task metrics, which requires scanning all tasks of a stage, the code now uses the KVStore "skip()" functionality to only read tasks that contain interesting information for the quantiles that are desired. This is still not cheap; because there are many metrics that the UI and API track, the code needs to scan the index for each metric to gather the information. Savings come mainly from skipping deserialization when using the disk store, but the in-memory code also seems to be faster than before (most probably because of other changes in this patch). To make subsequent calls faster, some quantiles are cached in the status store. This makes UIs much faster after the first time a stage has been loaded. With the above changes, a lot of code in the UI layer could be simplified. Author: Marcelo Vanzin Closes #20013 from vanzin/SPARK-20657. --- .../apache/spark/util/kvstore/LevelDB.java | 1 + .../spark/status/AppStatusListener.scala | 57 +- .../apache/spark/status/AppStatusStore.scala | 389 +++++--- .../apache/spark/status/AppStatusUtils.scala | 68 ++ .../org/apache/spark/status/LiveEntity.scala | 344 ++++--- .../spark/status/api/v1/StagesResource.scala | 3 +- .../org/apache/spark/status/api/v1/api.scala | 3 + .../org/apache/spark/status/storeTypes.scala | 327 ++++++- .../apache/spark/ui/jobs/ExecutorTable.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 919 ++++++------------ ...mmary_w__custom_quantiles_expectation.json | 3 + ...sk_summary_w_shuffle_read_expectation.json | 3 + ...k_summary_w_shuffle_write_expectation.json | 3 + .../spark/status/AppStatusListenerSuite.scala | 105 +- .../spark/status/AppStatusStoreSuite.scala | 104 ++ .../org/apache/spark/ui/StagePageSuite.scala | 10 +- scalastyle-config.xml | 2 +- 18 files changed, 1361 insertions(+), 986 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 4f9e10ca20066..0e491efac9181 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -83,6 +83,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { if (versionData != null) { long version = serializer.deserializeLong(versionData); if (version != STORE_VERSION) { + close(); throw new UnsupportedStoreVersionException(); } } else { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 88b75ddd5993a..b4edcf23abc09 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -377,6 +377,10 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = 
math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + + val locality = event.taskInfo.taskLocality.toString() + val count = stage.localitySummary.getOrElse(locality, 0L) + 1L + stage.localitySummary = stage.localitySummary ++ Map(locality -> count) maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -433,7 +437,7 @@ private[spark] class AppStatusListener( } task.errorMessage = errorMessage val delta = task.updateMetrics(event.taskMetrics) - update(task, now) + update(task, now, last = true) delta }.orNull @@ -450,7 +454,7 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { - stage.metrics.update(metricsDelta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, metricsDelta) } stage.activeTasks -= 1 stage.completedTasks += completedDelta @@ -486,7 +490,7 @@ private[spark] class AppStatusListener( esummary.failedTasks += failedDelta esummary.killedTasks += killedDelta if (metricsDelta != null) { - esummary.metrics.update(metricsDelta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } maybeUpdate(esummary, now) @@ -604,11 +608,11 @@ private[spark] class AppStatusListener( maybeUpdate(task, now) Option(liveStages.get((sid, sAttempt))).foreach { stage => - stage.metrics.update(delta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, delta) maybeUpdate(stage, now) val esummary = stage.executorSummary(event.execId) - esummary.metrics.update(delta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, delta) maybeUpdate(esummary, now) } } @@ -690,7 +694,7 @@ private[spark] class AppStatusListener( // can update the executor information too. liveRDDs.get(block.rddId).foreach { rdd => if (updatedStorageLevel.isDefined) { - rdd.storageLevel = updatedStorageLevel.get + rdd.setStorageLevel(updatedStorageLevel.get) } val partition = rdd.partition(block.name) @@ -814,7 +818,7 @@ private[spark] class AppStatusListener( /** Update a live entity only if it hasn't been updated in the last configured period. */ private def maybeUpdate(entity: LiveEntity, now: Long): Unit = { - if (liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { + if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { update(entity, now) } } @@ -865,7 +869,7 @@ private[spark] class AppStatusListener( } stages.foreach { s => - val key = s.id + val key = Array(s.info.stageId, s.info.attemptId) kvstore.delete(s.getClass(), key) val execSummaries = kvstore.view(classOf[ExecutorStageSummaryWrapper]) @@ -885,15 +889,15 @@ private[spark] class AppStatusListener( .asScala tasks.foreach { t => - kvstore.delete(t.getClass(), t.info.taskId) + kvstore.delete(t.getClass(), t.taskId) } // Check whether there are remaining attempts for the same stage. If there aren't, then // also delete the RDD graph data. val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) .index("stageId") - .first(s.stageId) - .last(s.stageId) + .first(s.info.stageId) + .last(s.info.stageId) .closeableIterator() val hasMoreAttempts = try { @@ -905,8 +909,10 @@ private[spark] class AppStatusListener( } if (!hasMoreAttempts) { - kvstore.delete(classOf[RDDOperationGraphWrapper], s.stageId) + kvstore.delete(classOf[RDDOperationGraphWrapper], s.info.stageId) } + + cleanupCachedQuantiles(key) } } @@ -919,9 +925,9 @@ private[spark] class AppStatusListener( // Try to delete finished tasks only. 
val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => - !live || t.info.status != TaskState.RUNNING.toString() + !live || t.status != TaskState.RUNNING.toString() } - toDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + toDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-toDelete.size) // If there are more running tasks than the configured limit, delete running tasks. This @@ -930,13 +936,34 @@ private[spark] class AppStatusListener( val remaining = countToDelete - toDelete.size if (remaining > 0) { val runningTasksToDelete = view.max(remaining).iterator().asScala.toList - runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-remaining) } + + // On live applications, cleanup any cached quantiles for the stage. This makes sure that + // quantiles will be recalculated after tasks are replaced with newer ones. + // + // This is not needed in the SHS since caching only happens after the event logs are + // completely processed. + if (live) { + cleanupCachedQuantiles(stageKey) + } } stage.cleaning = false } + private def cleanupCachedQuantiles(stageKey: Array[Int]): Unit = { + val cachedQuantiles = kvstore.view(classOf[CachedQuantile]) + .index("stage") + .first(stageKey) + .last(stageKey) + .asScala + .toList + cachedQuantiles.foreach { q => + kvstore.delete(q.getClass(), q.id) + } + } + /** * Remove at least (retainedSize / 10) items to reduce friction. Because tracking may be done * asynchronously, this method may return 0 in case enough items have been deleted already. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 5a942f5284018..efc28538a33db 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.status.api.v1 import org.apache.spark.ui.scope._ -import org.apache.spark.util.Distribution +import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** @@ -98,7 +98,11 @@ private[spark] class AppStatusStore( val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) .closeableIterator() try { - it.next().info + if (it.hasNext()) { + it.next().info + } else { + throw new NoSuchElementException(s"No stage with id $stageId") + } } finally { it.close() } @@ -110,107 +114,238 @@ private[spark] class AppStatusStore( if (details) stageWithDetails(stage) else stage } + def taskCount(stageId: Int, stageAttemptId: Int): Long = { + store.count(classOf[TaskDataWrapper], "stage", Array(stageId, stageAttemptId)) + } + + def localitySummary(stageId: Int, stageAttemptId: Int): Map[String, Long] = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).locality + } + + /** + * Calculates a summary of the task metrics for the given stage attempt, returning the + * requested quantiles for the recorded metrics. + * + * This method can be expensive if the requested quantiles are not cached; the method + * will only cache certain quantiles (every 0.05 step), so it's recommended to stick to + * those to avoid expensive scans of all task data. 
+ */ def taskSummary( stageId: Int, stageAttemptId: Int, - quantiles: Array[Double]): v1.TaskMetricDistributions = { - - val stage = Array(stageId, stageAttemptId) - - val rawMetrics = store.view(classOf[TaskDataWrapper]) - .index("stage") - .first(stage) - .last(stage) - .asScala - .flatMap(_.info.taskMetrics) - .toList - .view - - def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = - Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) - - // We need to do a lot of similar munging to nested metrics here. For each one, - // we want (a) extract the values for nested metrics (b) make a distribution for each metric - // (c) shove the distribution into the right field in our return type and (d) only return - // a result if the option is defined for any of the tasks. MetricHelper is a little util - // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just - // implement one "build" method, which just builds the quantiles for each field. - - val inputMetrics = - new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics - - def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( - bytesRead = submetricQuantiles(_.bytesRead), - recordsRead = submetricQuantiles(_.recordsRead) - ) - }.build - - val outputMetrics = - new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics - - def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( - bytesWritten = submetricQuantiles(_.bytesWritten), - recordsWritten = submetricQuantiles(_.recordsWritten) - ) - }.build - - val shuffleReadMetrics = - new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = - raw.shuffleReadMetrics - - def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( - readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, - readRecords = submetricQuantiles(_.recordsRead), - remoteBytesRead = submetricQuantiles(_.remoteBytesRead), - remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), - remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), - localBlocksFetched = submetricQuantiles(_.localBlocksFetched), - totalBlocksFetched = submetricQuantiles { s => - s.localBlocksFetched + s.remoteBlocksFetched - }, - fetchWaitTime = submetricQuantiles(_.fetchWaitTime) - ) - }.build - - val shuffleWriteMetrics = - new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = - raw.shuffleWriteMetrics - - def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.bytesWritten), - writeRecords = submetricQuantiles(_.recordsWritten), - writeTime = submetricQuantiles(_.writeTime) - ) - }.build - - new v1.TaskMetricDistributions( + unsortedQuantiles: Array[Double]): Option[v1.TaskMetricDistributions] = { + val stageKey = Array(stageId, stageAttemptId) + val quantiles = unsortedQuantiles.sorted + + // We don't know how many tasks remain in the store that actually have metrics. So scan one + // metric and count how many valid tasks there are. 
Use skip() instead of next() since it's + // cheaper for disk stores (avoids deserialization). + val count = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(TaskIndexNames.EXEC_RUN_TIME) + .first(0L) + .closeableIterator() + ) { it => + var _count = 0L + while (it.hasNext()) { + _count += 1 + it.skip(1) + } + _count + } + } + + if (count <= 0) { + return None + } + + // Find out which quantiles are already cached. The data in the store must match the expected + // task count to be considered, otherwise it will be re-scanned and overwritten. + val cachedQuantiles = quantiles.filter(shouldCacheQuantile).flatMap { q => + val qkey = Array(stageId, stageAttemptId, quantileToString(q)) + asOption(store.read(classOf[CachedQuantile], qkey)).filter(_.taskCount == count) + } + + // If there are no missing quantiles, return the data. Otherwise, just compute everything + // to make the code simpler. + if (cachedQuantiles.size == quantiles.size) { + def toValues(fn: CachedQuantile => Double): IndexedSeq[Double] = cachedQuantiles.map(fn) + + val distributions = new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = toValues(_.executorDeserializeTime), + executorDeserializeCpuTime = toValues(_.executorDeserializeCpuTime), + executorRunTime = toValues(_.executorRunTime), + executorCpuTime = toValues(_.executorCpuTime), + resultSize = toValues(_.resultSize), + jvmGcTime = toValues(_.jvmGcTime), + resultSerializationTime = toValues(_.resultSerializationTime), + gettingResultTime = toValues(_.gettingResultTime), + schedulerDelay = toValues(_.schedulerDelay), + peakExecutionMemory = toValues(_.peakExecutionMemory), + memoryBytesSpilled = toValues(_.memoryBytesSpilled), + diskBytesSpilled = toValues(_.diskBytesSpilled), + inputMetrics = new v1.InputMetricDistributions( + toValues(_.bytesRead), + toValues(_.recordsRead)), + outputMetrics = new v1.OutputMetricDistributions( + toValues(_.bytesWritten), + toValues(_.recordsWritten)), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + toValues(_.shuffleReadBytes), + toValues(_.shuffleRecordsRead), + toValues(_.shuffleRemoteBlocksFetched), + toValues(_.shuffleLocalBlocksFetched), + toValues(_.shuffleFetchWaitTime), + toValues(_.shuffleRemoteBytesRead), + toValues(_.shuffleRemoteBytesReadToDisk), + toValues(_.shuffleTotalBlocksFetched)), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + toValues(_.shuffleWriteBytes), + toValues(_.shuffleWriteRecords), + toValues(_.shuffleWriteTime))) + + return Some(distributions) + } + + // Compute quantiles by scanning the tasks in the store. This is not really stable for live + // stages (e.g. the number of recorded tasks may change while this code is running), but should + // stabilize once the stage finishes. It's also slow, especially with disk stores. 
+ val indices = quantiles.map { q => math.min((q * count).toLong, count - 1) } + + def scanTasks(index: String)(fn: TaskDataWrapper => Long): IndexedSeq[Double] = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(index) + .first(0L) + .closeableIterator() + ) { it => + var last = Double.NaN + var currentIdx = -1L + indices.map { idx => + if (idx == currentIdx) { + last + } else { + val diff = idx - currentIdx + currentIdx = idx + if (it.skip(diff - 1)) { + last = fn(it.next()).toDouble + last + } else { + Double.NaN + } + } + }.toIndexedSeq + } + } + + val computedQuantiles = new v1.TaskMetricDistributions( quantiles = quantiles, - executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), - executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), - executorRunTime = metricQuantiles(_.executorRunTime), - executorCpuTime = metricQuantiles(_.executorCpuTime), - resultSize = metricQuantiles(_.resultSize), - jvmGcTime = metricQuantiles(_.jvmGcTime), - resultSerializationTime = metricQuantiles(_.resultSerializationTime), - memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), - diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), - inputMetrics = inputMetrics, - outputMetrics = outputMetrics, - shuffleReadMetrics = shuffleReadMetrics, - shuffleWriteMetrics = shuffleWriteMetrics - ) + executorDeserializeTime = scanTasks(TaskIndexNames.DESER_TIME) { t => + t.executorDeserializeTime + }, + executorDeserializeCpuTime = scanTasks(TaskIndexNames.DESER_CPU_TIME) { t => + t.executorDeserializeCpuTime + }, + executorRunTime = scanTasks(TaskIndexNames.EXEC_RUN_TIME) { t => t.executorRunTime }, + executorCpuTime = scanTasks(TaskIndexNames.EXEC_CPU_TIME) { t => t.executorCpuTime }, + resultSize = scanTasks(TaskIndexNames.RESULT_SIZE) { t => t.resultSize }, + jvmGcTime = scanTasks(TaskIndexNames.GC_TIME) { t => t.jvmGcTime }, + resultSerializationTime = scanTasks(TaskIndexNames.SER_TIME) { t => + t.resultSerializationTime + }, + gettingResultTime = scanTasks(TaskIndexNames.GETTING_RESULT_TIME) { t => + t.gettingResultTime + }, + schedulerDelay = scanTasks(TaskIndexNames.SCHEDULER_DELAY) { t => t.schedulerDelay }, + peakExecutionMemory = scanTasks(TaskIndexNames.PEAK_MEM) { t => t.peakExecutionMemory }, + memoryBytesSpilled = scanTasks(TaskIndexNames.MEM_SPILL) { t => t.memoryBytesSpilled }, + diskBytesSpilled = scanTasks(TaskIndexNames.DISK_SPILL) { t => t.diskBytesSpilled }, + inputMetrics = new v1.InputMetricDistributions( + scanTasks(TaskIndexNames.INPUT_SIZE) { t => t.inputBytesRead }, + scanTasks(TaskIndexNames.INPUT_RECORDS) { t => t.inputRecordsRead }), + outputMetrics = new v1.OutputMetricDistributions( + scanTasks(TaskIndexNames.OUTPUT_SIZE) { t => t.outputBytesWritten }, + scanTasks(TaskIndexNames.OUTPUT_RECORDS) { t => t.outputRecordsWritten }), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_READS) { m => + m.shuffleLocalBytesRead + m.shuffleRemoteBytesRead + }, + scanTasks(TaskIndexNames.SHUFFLE_READ_RECORDS) { t => t.shuffleRecordsRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_BLOCKS) { t => t.shuffleRemoteBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_LOCAL_BLOCKS) { t => t.shuffleLocalBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_READ_TIME) { t => t.shuffleFetchWaitTime }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS) { t => t.shuffleRemoteBytesRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK) { t => + 
t.shuffleRemoteBytesReadToDisk + }, + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_BLOCKS) { m => + m.shuffleLocalBlocksFetched + m.shuffleRemoteBlocksFetched + }), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_WRITE_SIZE) { t => t.shuffleBytesWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_RECORDS) { t => t.shuffleRecordsWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_TIME) { t => t.shuffleWriteTime })) + + // Go through the computed quantiles and cache the values that match the caching criteria. + computedQuantiles.quantiles.zipWithIndex + .filter { case (q, _) => quantiles.contains(q) && shouldCacheQuantile(q) } + .foreach { case (q, idx) => + val cached = new CachedQuantile(stageId, stageAttemptId, quantileToString(q), count, + executorDeserializeTime = computedQuantiles.executorDeserializeTime(idx), + executorDeserializeCpuTime = computedQuantiles.executorDeserializeCpuTime(idx), + executorRunTime = computedQuantiles.executorRunTime(idx), + executorCpuTime = computedQuantiles.executorCpuTime(idx), + resultSize = computedQuantiles.resultSize(idx), + jvmGcTime = computedQuantiles.jvmGcTime(idx), + resultSerializationTime = computedQuantiles.resultSerializationTime(idx), + gettingResultTime = computedQuantiles.gettingResultTime(idx), + schedulerDelay = computedQuantiles.schedulerDelay(idx), + peakExecutionMemory = computedQuantiles.peakExecutionMemory(idx), + memoryBytesSpilled = computedQuantiles.memoryBytesSpilled(idx), + diskBytesSpilled = computedQuantiles.diskBytesSpilled(idx), + + bytesRead = computedQuantiles.inputMetrics.bytesRead(idx), + recordsRead = computedQuantiles.inputMetrics.recordsRead(idx), + + bytesWritten = computedQuantiles.outputMetrics.bytesWritten(idx), + recordsWritten = computedQuantiles.outputMetrics.recordsWritten(idx), + + shuffleReadBytes = computedQuantiles.shuffleReadMetrics.readBytes(idx), + shuffleRecordsRead = computedQuantiles.shuffleReadMetrics.readRecords(idx), + shuffleRemoteBlocksFetched = + computedQuantiles.shuffleReadMetrics.remoteBlocksFetched(idx), + shuffleLocalBlocksFetched = computedQuantiles.shuffleReadMetrics.localBlocksFetched(idx), + shuffleFetchWaitTime = computedQuantiles.shuffleReadMetrics.fetchWaitTime(idx), + shuffleRemoteBytesRead = computedQuantiles.shuffleReadMetrics.remoteBytesRead(idx), + shuffleRemoteBytesReadToDisk = + computedQuantiles.shuffleReadMetrics.remoteBytesReadToDisk(idx), + shuffleTotalBlocksFetched = computedQuantiles.shuffleReadMetrics.totalBlocksFetched(idx), + + shuffleWriteBytes = computedQuantiles.shuffleWriteMetrics.writeBytes(idx), + shuffleWriteRecords = computedQuantiles.shuffleWriteMetrics.writeRecords(idx), + shuffleWriteTime = computedQuantiles.shuffleWriteMetrics.writeTime(idx)) + store.write(cached) + } + + Some(computedQuantiles) } + /** + * Whether to cache information about a specific metric quantile. We cache quantiles at every 0.05 + * step, which covers the default values used both in the API and in the stages page. 
+ */ + private def shouldCacheQuantile(q: Double): Boolean = (math.round(q * 100) % 5) == 0 + + private def quantileToString(q: Double): String = math.round(q * 100).toString + def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.info).toSeq.reverse + .max(maxTasks).asScala.map(_.toApi).toSeq.reverse } def taskList( @@ -219,18 +354,43 @@ private[spark] class AppStatusStore( offset: Int, length: Int, sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val (indexName, ascending) = sortBy match { + case v1.TaskSorting.ID => + (None, true) + case v1.TaskSorting.INCREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), true) + case v1.TaskSorting.DECREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), false) + } + taskList(stageId, stageAttemptId, offset, length, indexName, ascending) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: Option[String], + ascending: Boolean): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) val base = store.view(classOf[TaskDataWrapper]) val indexed = sortBy match { - case v1.TaskSorting.ID => + case Some(index) => + base.index(index).parent(stageKey) + + case _ => + // Sort by ID, which is the "stage" index. base.index("stage").first(stageKey).last(stageKey) - case v1.TaskSorting.INCREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) - case v1.TaskSorting.DECREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) - .reverse() } - indexed.skip(offset).max(length).asScala.map(_.info).toSeq + + val ordered = if (ascending) indexed else indexed.reverse() + ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + } + + def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { + val stageKey = Array(stageId, attemptId) + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey).last(stageKey) + .asScala.map { exec => (exec.executorId -> exec.info) }.toMap } def rddList(cachedOnly: Boolean = true): Seq[v1.RDDStorageInfo] = { @@ -256,12 +416,6 @@ private[spark] class AppStatusStore( .map { t => (t.taskId, t) } .toMap - val stageKey = Array(stage.stageId, stage.attemptId) - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey) - .last(stageKey).closeableIterator().asScala - .map { exec => (exec.executorId -> exec.info) } - .toMap - new v1.StageData( stage.status, stage.stageId, @@ -295,7 +449,7 @@ private[spark] class AppStatusStore( stage.rddIds, stage.accumulatorUpdates, Some(tasks), - Some(execs), + Some(executorSummary(stage.stageId, stage.attemptId)), stage.killedTasksSummary) } @@ -352,22 +506,3 @@ private[spark] object AppStatusStore { } } - -/** - * Helper for getting distributions from nested metric types. 
- */ -private abstract class MetricHelper[I, O]( - rawMetrics: Seq[v1.TaskMetrics], - quantiles: Array[Double]) { - - def getSubmetrics(raw: v1.TaskMetrics): I - - def build: O - - val data: Seq[I] = rawMetrics.map(getSubmetrics) - - /** applies the given function to all input metrics, and returns the quantiles */ - def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { - Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala new file mode 100644 index 0000000000000..341bd4e0cd016 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +private[spark] object AppStatusUtils { + + def schedulerDelay(task: TaskData): Long = { + if (task.taskMetrics.isDefined && task.duration.isDefined) { + val m = task.taskMetrics.get + schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, + m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) + } else { + 0L + } + } + + def gettingResultTime(task: TaskData): Long = { + gettingResultTime(task.launchTime.getTime(), fetchStart(task), task.duration.getOrElse(-1L)) + } + + def schedulerDelay( + launchTime: Long, + fetchStart: Long, + duration: Long, + deserializeTime: Long, + serializeTime: Long, + runTime: Long): Long = { + math.max(0, duration - runTime - deserializeTime - serializeTime - + gettingResultTime(launchTime, fetchStart, duration)) + } + + def gettingResultTime(launchTime: Long, fetchStart: Long, duration: Long): Long = { + if (fetchStart > 0) { + if (duration > 0) { + launchTime + duration - fetchStart + } else { + System.currentTimeMillis() - fetchStart + } + } else { + 0L + } + } + + private def fetchStart(task: TaskData): Long = { + if (task.resultFetchStart.isDefined) { + task.resultFetchStart.get.getTime() + } else { + -1 + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 305c2fafa6aac..4295e664e131c 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.HashMap +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} @@ -119,7 +121,9 @@ private class LiveTask( 
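(A quick worked check of the two helpers above, with made-up numbers; this is only an illustration, not part of the patch:)

    // launchTime = 1000, fetchStart = 1500, duration = 600, deserialize = 50, serialize = 20, run = 400
    // gettingResultTime = launchTime + duration - fetchStart = 1000 + 600 - 1500 = 100
    // schedulerDelay    = max(0, 600 - 400 - 50 - 20 - 100) = 30
    assert(AppStatusUtils.gettingResultTime(1000L, 1500L, 600L) == 100L)
    assert(AppStatusUtils.schedulerDelay(1000L, 1500L, 600L, 50L, 20L, 400L) == 30L)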
import LiveEntityHelpers._ - private var recordedMetrics: v1.TaskMetrics = null + // The task metrics use a special value when no metrics have been reported. The special value is + // checked when calculating indexed values when writing to the store (see [[TaskDataWrapper]]). + private var metrics: v1.TaskMetrics = createMetrics(default = -1L) var errorMessage: Option[String] = None @@ -129,8 +133,8 @@ private class LiveTask( */ def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { if (metrics != null) { - val old = recordedMetrics - recordedMetrics = new v1.TaskMetrics( + val old = this.metrics + val newMetrics = createMetrics( metrics.executorDeserializeTime, metrics.executorDeserializeCpuTime, metrics.executorRunTime, @@ -141,73 +145,35 @@ private class LiveTask( metrics.memoryBytesSpilled, metrics.diskBytesSpilled, metrics.peakExecutionMemory, - new v1.InputMetrics( - metrics.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead), - new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten), - new v1.ShuffleReadMetrics( - metrics.shuffleReadMetrics.remoteBlocksFetched, - metrics.shuffleReadMetrics.localBlocksFetched, - metrics.shuffleReadMetrics.fetchWaitTime, - metrics.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead), - new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten, - metrics.shuffleWriteMetrics.writeTime, - metrics.shuffleWriteMetrics.recordsWritten)) - if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten) + + this.metrics = newMetrics + + // Only calculate the delta if the old metrics contain valid information, otherwise + // the new metrics are the delta. + if (old.executorDeserializeTime >= 0L) { + subtractMetrics(newMetrics, old) + } else { + newMetrics + } } else { null } } - /** - * Return a new TaskMetrics object containing the delta of the various fields of the given - * metrics objects. This is currently targeted at updating stage data, so it does not - * necessarily calculate deltas for all the fields. 
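(To make the sentinel behaviour above concrete, a small sketch using the LiveEntityHelpers functions introduced later in this file; the values are made up and the snippet is illustrative only:)

    import LiveEntityHelpers._

    val first  = createMetrics(default = -1L)  // the "no metrics reported yet" sentinel
    val update = createMetrics(default = 10L)  // pretend every metric field is 10
    // Mirrors the branch above: only subtract when the old value holds real data.
    val delta = if (first.executorDeserializeTime >= 0L) subtractMetrics(update, first) else update
    // Here delta is just `update`, because the sentinel says nothing was recorded before.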
- */ - private def calculateMetricsDelta( - metrics: v1.TaskMetrics, - old: v1.TaskMetrics): v1.TaskMetrics = { - val shuffleWriteDelta = new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, - 0L, - metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) - - val shuffleReadDelta = new v1.ShuffleReadMetrics( - 0L, 0L, 0L, - metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk - - old.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) - - val inputDelta = new v1.InputMetrics( - metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) - - val outputDelta = new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) - - new v1.TaskMetrics( - 0L, 0L, - metrics.executorRunTime - old.executorRunTime, - metrics.executorCpuTime - old.executorCpuTime, - 0L, 0L, 0L, - metrics.memoryBytesSpilled - old.memoryBytesSpilled, - metrics.diskBytesSpilled - old.diskBytesSpilled, - 0L, - inputDelta, - outputDelta, - shuffleReadDelta, - shuffleWriteDelta) - } - override protected def doUpdate(): Any = { val duration = if (info.finished) { info.duration @@ -215,22 +181,48 @@ private class LiveTask( info.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis())) } - val task = new v1.TaskData( + new TaskDataWrapper( info.taskId, info.index, info.attemptNumber, - new Date(info.launchTime), - if (info.gettingResult) Some(new Date(info.gettingResultTime)) else None, - Some(duration), - info.executorId, - info.host, - info.status, - info.taskLocality.toString(), + info.launchTime, + if (info.gettingResult) info.gettingResultTime else -1L, + duration, + weakIntern(info.executorId), + weakIntern(info.host), + weakIntern(info.status), + weakIntern(info.taskLocality.toString()), info.speculative, newAccumulatorInfos(info.accumulables), errorMessage, - Option(recordedMetrics)) - new TaskDataWrapper(task, stageId, stageAttemptId) + + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGcTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + metrics.peakExecutionMemory, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten, + + stageId, + stageAttemptId) } } @@ -313,50 +305,19 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE } -/** Metrics tracked per stage (both total and per executor). 
*/ -private class MetricsTracker { - var executorRunTime = 0L - var executorCpuTime = 0L - var inputBytes = 0L - var inputRecords = 0L - var outputBytes = 0L - var outputRecords = 0L - var shuffleReadBytes = 0L - var shuffleReadRecords = 0L - var shuffleWriteBytes = 0L - var shuffleWriteRecords = 0L - var memoryBytesSpilled = 0L - var diskBytesSpilled = 0L - - def update(delta: v1.TaskMetrics): Unit = { - executorRunTime += delta.executorRunTime - executorCpuTime += delta.executorCpuTime - inputBytes += delta.inputMetrics.bytesRead - inputRecords += delta.inputMetrics.recordsRead - outputBytes += delta.outputMetrics.bytesWritten - outputRecords += delta.outputMetrics.recordsWritten - shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + - delta.shuffleReadMetrics.remoteBytesRead - shuffleReadRecords += delta.shuffleReadMetrics.recordsRead - shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten - shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten - memoryBytesSpilled += delta.memoryBytesSpilled - diskBytesSpilled += delta.diskBytesSpilled - } - -} - private class LiveExecutorStageSummary( stageId: Int, attemptId: Int, executorId: String) extends LiveEntity { + import LiveEntityHelpers._ + var taskTime = 0L var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 - val metrics = new MetricsTracker() + var metrics = createMetrics(default = 0L) override protected def doUpdate(): Any = { val info = new v1.ExecutorStageSummary( @@ -364,14 +325,14 @@ private class LiveExecutorStageSummary( failedTasks, succeededTasks, killedTasks, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBytesRead + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) @@ -402,7 +363,9 @@ private class LiveStage extends LiveEntity { var firstLaunchTime = Long.MaxValue - val metrics = new MetricsTracker() + var localitySummary: Map[String, Long] = Map() + + var metrics = createMetrics(default = 0L) val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() @@ -435,14 +398,14 @@ private class LiveStage extends LiveEntity { info.completionTime.map(new Date(_)), info.failureReason, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.localBytesRead + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled, @@ -459,13 +422,15 @@ private class LiveStage extends LiveEntity { } override protected def doUpdate(): Any = { - new StageDataWrapper(toApi(), jobIds) + new StageDataWrapper(toApi(), jobIds, 
localitySummary) } } private class LiveRDDPartition(val blockName: String) { + import LiveEntityHelpers._ + // Pointers used by RDDPartitionSeq. @volatile var prev: LiveRDDPartition = null @volatile var next: LiveRDDPartition = null @@ -485,7 +450,7 @@ private class LiveRDDPartition(val blockName: String) { diskUsed: Long): Unit = { value = new v1.RDDPartitionInfo( blockName, - storageLevel, + weakIntern(storageLevel), memoryUsed, diskUsed, executors) @@ -495,6 +460,8 @@ private class LiveRDDPartition(val blockName: String) { private class LiveRDDDistribution(exec: LiveExecutor) { + import LiveEntityHelpers._ + val executorId = exec.executorId var memoryUsed = 0L var diskUsed = 0L @@ -508,7 +475,7 @@ private class LiveRDDDistribution(exec: LiveExecutor) { def toApi(): v1.RDDDataDistribution = { if (lastUpdate == null) { lastUpdate = new v1.RDDDataDistribution( - exec.hostPort, + weakIntern(exec.hostPort), memoryUsed, exec.maxMemory - exec.memoryUsed, diskUsed, @@ -524,7 +491,9 @@ private class LiveRDDDistribution(exec: LiveExecutor) { private class LiveRDD(val info: RDDInfo) extends LiveEntity { - var storageLevel: String = info.storageLevel.description + import LiveEntityHelpers._ + + var storageLevel: String = weakIntern(info.storageLevel.description) var memoryUsed = 0L var diskUsed = 0L @@ -533,6 +502,10 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { private val distributions = new HashMap[String, LiveRDDDistribution]() + def setStorageLevel(level: String): Unit = { + this.storageLevel = weakIntern(level) + } + def partition(blockName: String): LiveRDDPartition = { partitions.getOrElseUpdate(blockName, { val part = new LiveRDDPartition(blockName) @@ -593,6 +566,9 @@ private class SchedulerPool(name: String) extends LiveEntity { private object LiveEntityHelpers { + private val stringInterner = Interners.newWeakInterner[String]() + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { accums .filter { acc => @@ -604,13 +580,119 @@ private object LiveEntityHelpers { .map { acc => new v1.AccumulableInfo( acc.id, - acc.name.orNull, + acc.name.map(weakIntern).orNull, acc.update.map(_.toString()), acc.value.map(_.toString()).orNull) } .toSeq } + /** String interning to reduce the memory usage. 
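(Side note on the interner added below, Guava's Interners: repeated strings such as executor IDs, hosts and statuses collapse to one shared instance, which is the memory win this change is after. A tiny sketch, not part of the patch:)

    import com.google.common.collect.Interners

    val interner = Interners.newWeakInterner[String]()
    val a = interner.intern(new String("SUCCESS"))
    val b = interner.intern(new String("SUCCESS"))
    assert(a eq b)  // both references point at the single canonical copy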
*/ + def weakIntern(s: String): String = { + stringInterner.intern(s) + } + + // scalastyle:off argcount + def createMetrics( + executorDeserializeTime: Long, + executorDeserializeCpuTime: Long, + executorRunTime: Long, + executorCpuTime: Long, + resultSize: Long, + jvmGcTime: Long, + resultSerializationTime: Long, + memoryBytesSpilled: Long, + diskBytesSpilled: Long, + peakExecutionMemory: Long, + inputBytesRead: Long, + inputRecordsRead: Long, + outputBytesWritten: Long, + outputRecordsWritten: Long, + shuffleRemoteBlocksFetched: Long, + shuffleLocalBlocksFetched: Long, + shuffleFetchWaitTime: Long, + shuffleRemoteBytesRead: Long, + shuffleRemoteBytesReadToDisk: Long, + shuffleLocalBytesRead: Long, + shuffleRecordsRead: Long, + shuffleBytesWritten: Long, + shuffleWriteTime: Long, + shuffleRecordsWritten: Long): v1.TaskMetrics = { + new v1.TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new v1.InputMetrics( + inputBytesRead, + inputRecordsRead), + new v1.OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new v1.ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new v1.ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten)) + } + // scalastyle:on argcount + + def createMetrics(default: Long): v1.TaskMetrics = { + createMetrics(default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default) + } + + /** Add m2 values to m1. */ + def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = addMetrics(m1, m2, 1) + + /** Subtract m2 values from m1. 
*/ + def subtractMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = { + addMetrics(m1, m2, -1) + } + + private def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics, mult: Int): v1.TaskMetrics = { + createMetrics( + m1.executorDeserializeTime + m2.executorDeserializeTime * mult, + m1.executorDeserializeCpuTime + m2.executorDeserializeCpuTime * mult, + m1.executorRunTime + m2.executorRunTime * mult, + m1.executorCpuTime + m2.executorCpuTime * mult, + m1.resultSize + m2.resultSize * mult, + m1.jvmGcTime + m2.jvmGcTime * mult, + m1.resultSerializationTime + m2.resultSerializationTime * mult, + m1.memoryBytesSpilled + m2.memoryBytesSpilled * mult, + m1.diskBytesSpilled + m2.diskBytesSpilled * mult, + m1.peakExecutionMemory + m2.peakExecutionMemory * mult, + m1.inputMetrics.bytesRead + m2.inputMetrics.bytesRead * mult, + m1.inputMetrics.recordsRead + m2.inputMetrics.recordsRead * mult, + m1.outputMetrics.bytesWritten + m2.outputMetrics.bytesWritten * mult, + m1.outputMetrics.recordsWritten + m2.outputMetrics.recordsWritten * mult, + m1.shuffleReadMetrics.remoteBlocksFetched + m2.shuffleReadMetrics.remoteBlocksFetched * mult, + m1.shuffleReadMetrics.localBlocksFetched + m2.shuffleReadMetrics.localBlocksFetched * mult, + m1.shuffleReadMetrics.fetchWaitTime + m2.shuffleReadMetrics.fetchWaitTime * mult, + m1.shuffleReadMetrics.remoteBytesRead + m2.shuffleReadMetrics.remoteBytesRead * mult, + m1.shuffleReadMetrics.remoteBytesReadToDisk + + m2.shuffleReadMetrics.remoteBytesReadToDisk * mult, + m1.shuffleReadMetrics.localBytesRead + m2.shuffleReadMetrics.localBytesRead * mult, + m1.shuffleReadMetrics.recordsRead + m2.shuffleReadMetrics.recordsRead * mult, + m1.shuffleWriteMetrics.bytesWritten + m2.shuffleWriteMetrics.bytesWritten * mult, + m1.shuffleWriteMetrics.writeTime + m2.shuffleWriteMetrics.writeTime * mult, + m1.shuffleWriteMetrics.recordsWritten + m2.shuffleWriteMetrics.recordsWritten * mult) + } + } /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 3b879545b3d2e..96249e4bfd5fa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -87,7 +87,8 @@ private[v1] class StagesResource extends BaseAppResource { } } - ui.store.taskSummary(stageId, stageAttemptId, quantiles) + ui.store.taskSummary(stageId, stageAttemptId, quantiles).getOrElse( + throw new NotFoundException(s"No tasks reported metrics for $stageId / $stageAttemptId yet.")) } @GET diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 45eaf935fb083..7d8e4de3c8efb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -261,6 +261,9 @@ class TaskMetricDistributions private[spark]( val resultSize: IndexedSeq[Double], val jvmGcTime: IndexedSeq[Double], val resultSerializationTime: IndexedSeq[Double], + val gettingResultTime: IndexedSeq[Double], + val schedulerDelay: IndexedSeq[Double], + val peakExecutionMemory: IndexedSeq[Double], val memoryBytesSpilled: IndexedSeq[Double], val diskBytesSpilled: IndexedSeq[Double], diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 1cfd30df49091..c9cb996a55fcc 100644 --- 
a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -17,9 +17,11 @@ package org.apache.spark.status -import java.lang.{Integer => JInteger, Long => JLong} +import java.lang.{Long => JLong} +import java.util.Date import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ @@ -49,10 +51,10 @@ private[spark] class ApplicationEnvironmentInfoWrapper(val info: ApplicationEnvi private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { @JsonIgnore @KVIndex - private[this] val id: String = info.id + private def id: String = info.id @JsonIgnore @KVIndex("active") - private[this] val active: Boolean = info.isActive + private def active: Boolean = info.isActive @JsonIgnore @KVIndex("host") val host: String = info.hostPort.split(":")(0) @@ -69,51 +71,271 @@ private[spark] class JobDataWrapper( val skippedStages: Set[Int]) { @JsonIgnore @KVIndex - private[this] val id: Int = info.jobId + private def id: Int = info.jobId } private[spark] class StageDataWrapper( val info: StageData, - val jobIds: Set[Int]) { + val jobIds: Set[Int], + @JsonDeserialize(contentAs = classOf[JLong]) + val locality: Map[String, Long]) { @JsonIgnore @KVIndex - def id: Array[Int] = Array(info.stageId, info.attemptId) + private[this] val id: Array[Int] = Array(info.stageId, info.attemptId) @JsonIgnore @KVIndex("stageId") - def stageId: Int = info.stageId + private def stageId: Int = info.stageId + @JsonIgnore @KVIndex("active") + private def active: Boolean = info.status == StageStatus.ACTIVE + +} + +/** + * Tasks have a lot of indices that are used in a few different places. This object keeps logical + * names for these indices, mapped to short strings to save space when using a disk store. + */ +private[spark] object TaskIndexNames { + final val ACCUMULATORS = "acc" + final val ATTEMPT = "att" + final val DESER_CPU_TIME = "dct" + final val DESER_TIME = "des" + final val DISK_SPILL = "dbs" + final val DURATION = "dur" + final val ERROR = "err" + final val EXECUTOR = "exe" + final val EXEC_CPU_TIME = "ect" + final val EXEC_RUN_TIME = "ert" + final val GC_TIME = "gc" + final val GETTING_RESULT_TIME = "grt" + final val INPUT_RECORDS = "ir" + final val INPUT_SIZE = "is" + final val LAUNCH_TIME = "lt" + final val LOCALITY = "loc" + final val MEM_SPILL = "mbs" + final val OUTPUT_RECORDS = "or" + final val OUTPUT_SIZE = "os" + final val PEAK_MEM = "pem" + final val RESULT_SIZE = "rs" + final val SCHEDULER_DELAY = "dly" + final val SER_TIME = "rst" + final val SHUFFLE_LOCAL_BLOCKS = "slbl" + final val SHUFFLE_READ_RECORDS = "srr" + final val SHUFFLE_READ_TIME = "srt" + final val SHUFFLE_REMOTE_BLOCKS = "srbl" + final val SHUFFLE_REMOTE_READS = "srby" + final val SHUFFLE_REMOTE_READS_TO_DISK = "srbd" + final val SHUFFLE_TOTAL_READS = "stby" + final val SHUFFLE_TOTAL_BLOCKS = "stbl" + final val SHUFFLE_WRITE_RECORDS = "swr" + final val SHUFFLE_WRITE_SIZE = "sws" + final val SHUFFLE_WRITE_TIME = "swt" + final val STAGE = "stage" + final val STATUS = "sta" + final val TASK_INDEX = "idx" } /** - * The task information is always indexed with the stage ID, since that is how the UI and API - * consume it. That means every indexed value has the stage ID and attempt ID included, aside - * from the actual data being indexed. 
+ * Unlike other data types, the task data wrapper does not keep a reference to the API's TaskData. + * That is to save memory, since for large applications there can be a large number of these + * elements (by default up to 100,000 per stage), and every bit of wasted memory adds up. + * + * It also contains many secondary indices, which are used to sort data efficiently in the UI at the + * expense of storage space (and slower write times). */ private[spark] class TaskDataWrapper( - val info: TaskData, + // Storing this as an object actually saves memory; it's also used as the key in the in-memory + // store, so in that case you'd save the extra copy of the value here. + @KVIndexParam + val taskId: JLong, + @KVIndexParam(value = TaskIndexNames.TASK_INDEX, parent = TaskIndexNames.STAGE) + val index: Int, + @KVIndexParam(value = TaskIndexNames.ATTEMPT, parent = TaskIndexNames.STAGE) + val attempt: Int, + @KVIndexParam(value = TaskIndexNames.LAUNCH_TIME, parent = TaskIndexNames.STAGE) + val launchTime: Long, + val resultFetchStart: Long, + @KVIndexParam(value = TaskIndexNames.DURATION, parent = TaskIndexNames.STAGE) + val duration: Long, + @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) + val executorId: String, + val host: String, + @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) + val status: String, + @KVIndexParam(value = TaskIndexNames.LOCALITY, parent = TaskIndexNames.STAGE) + val taskLocality: String, + val speculative: Boolean, + val accumulatorUpdates: Seq[AccumulableInfo], + val errorMessage: Option[String], + + // The following is an exploded view of a TaskMetrics API object. This saves 5 objects + // (= 80 bytes of Java object overhead) per instance of this wrapper. If the first value + // (executorDeserializeTime) is -1L, it means the metrics for this task have not been + // recorded. 
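(For context, not part of the patch: the per-stage indices declared on the fields below are what let the store hand back a pre-sorted page of tasks. A hypothetical read, assuming an AppStatusStore instance named statusStore:)

    // The 20 longest-running tasks of stage 3, attempt 0, sorted via the EXEC_RUN_TIME index.
    val slowest: Seq[v1.TaskData] = statusStore.taskList(
      stageId = 3, stageAttemptId = 0, offset = 0, length = 20,
      sortBy = Some(TaskIndexNames.EXEC_RUN_TIME), ascending = false)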
+ @KVIndexParam(value = TaskIndexNames.DESER_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeTime: Long, + @KVIndexParam(value = TaskIndexNames.DESER_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_RUN_TIME, parent = TaskIndexNames.STAGE) + val executorRunTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.RESULT_SIZE, parent = TaskIndexNames.STAGE) + val resultSize: Long, + @KVIndexParam(value = TaskIndexNames.GC_TIME, parent = TaskIndexNames.STAGE) + val jvmGcTime: Long, + @KVIndexParam(value = TaskIndexNames.SER_TIME, parent = TaskIndexNames.STAGE) + val resultSerializationTime: Long, + @KVIndexParam(value = TaskIndexNames.MEM_SPILL, parent = TaskIndexNames.STAGE) + val memoryBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.DISK_SPILL, parent = TaskIndexNames.STAGE) + val diskBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.PEAK_MEM, parent = TaskIndexNames.STAGE) + val peakExecutionMemory: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_SIZE, parent = TaskIndexNames.STAGE) + val inputBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_RECORDS, parent = TaskIndexNames.STAGE) + val inputRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_SIZE, parent = TaskIndexNames.STAGE) + val outputBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_RECORDS, parent = TaskIndexNames.STAGE) + val outputRecordsWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_LOCAL_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleLocalBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_TIME, parent = TaskIndexNames.STAGE) + val shuffleFetchWaitTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK, + parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesReadToDisk: Long, + val shuffleLocalBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_SIZE, parent = TaskIndexNames.STAGE) + val shuffleBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_TIME, parent = TaskIndexNames.STAGE) + val shuffleWriteTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsWritten: Long, + val stageId: Int, val stageAttemptId: Int) { - @JsonIgnore @KVIndex - def id: Long = info.taskId + def hasMetrics: Boolean = executorDeserializeTime >= 0 + + def toApi: TaskData = { + val metrics = if (hasMetrics) { + Some(new TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new InputMetrics( + inputBytesRead, + inputRecordsRead), + new OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + 
shuffleLocalBytesRead, + shuffleRecordsRead), + new ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten))) + } else { + None + } - @JsonIgnore @KVIndex("stage") - def stage: Array[Int] = Array(stageId, stageAttemptId) + new TaskData( + taskId, + index, + attempt, + new Date(launchTime), + if (resultFetchStart > 0L) Some(new Date(resultFetchStart)) else None, + if (duration > 0L) Some(duration) else None, + executorId, + host, + status, + taskLocality, + speculative, + accumulatorUpdates, + errorMessage, + metrics) + } + + @JsonIgnore @KVIndex(TaskIndexNames.STAGE) + private def stage: Array[Int] = Array(stageId, stageAttemptId) - @JsonIgnore @KVIndex("runtime") - def runtime: Array[AnyRef] = { - val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) - Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.SCHEDULER_DELAY, parent = TaskIndexNames.STAGE) + def schedulerDelay: Long = { + if (hasMetrics) { + AppStatusUtils.schedulerDelay(launchTime, resultFetchStart, duration, executorDeserializeTime, + resultSerializationTime, executorRunTime) + } else { + -1L + } } - @JsonIgnore @KVIndex("startTime") - def startTime: Array[AnyRef] = { - Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.GETTING_RESULT_TIME, parent = TaskIndexNames.STAGE) + def gettingResultTime: Long = { + if (hasMetrics) { + AppStatusUtils.gettingResultTime(launchTime, resultFetchStart, duration) + } else { + -1L + } } - @JsonIgnore @KVIndex("active") - def active: Boolean = info.duration.isEmpty + /** + * Sorting by accumulators is a little weird, and the previous behavior would generate + * insanely long keys in the index. So this implementation just considers the first + * accumulator and its String representation. + */ + @JsonIgnore @KVIndex(value = TaskIndexNames.ACCUMULATORS, parent = TaskIndexNames.STAGE) + private def accumulators: String = { + if (accumulatorUpdates.nonEmpty) { + val acc = accumulatorUpdates.head + s"${acc.name}:${acc.value}" + } else { + "" + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_READS, parent = TaskIndexNames.STAGE) + private def shuffleTotalReads: Long = { + if (hasMetrics) { + shuffleLocalBytesRead + shuffleRemoteBytesRead + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_BLOCKS, parent = TaskIndexNames.STAGE) + private def shuffleTotalBlocks: Long = { + if (hasMetrics) { + shuffleLocalBlocksFetched + shuffleRemoteBlocksFetched + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) + private def error: String = if (errorMessage.isDefined) errorMessage.get else "" } @@ -134,10 +356,13 @@ private[spark] class ExecutorStageSummaryWrapper( val info: ExecutorStageSummary) { @JsonIgnore @KVIndex - val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + private val _id: Array[Any] = Array(stageId, stageAttemptId, executorId) @JsonIgnore @KVIndex("stage") - private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + private def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore + def id: Array[Any] = _id } @@ -203,3 +428,53 @@ private[spark] class AppSummary( def id: String = classOf[AppSummary].getName() } + +/** + * A cached view of a specific quantile for one stage attempt's metrics. 
+ */ +private[spark] class CachedQuantile( + val stageId: Int, + val stageAttemptId: Int, + val quantile: String, + val taskCount: Long, + + // The following fields are an exploded view of a single entry for TaskMetricDistributions. + val executorDeserializeTime: Double, + val executorDeserializeCpuTime: Double, + val executorRunTime: Double, + val executorCpuTime: Double, + val resultSize: Double, + val jvmGcTime: Double, + val resultSerializationTime: Double, + val gettingResultTime: Double, + val schedulerDelay: Double, + val peakExecutionMemory: Double, + val memoryBytesSpilled: Double, + val diskBytesSpilled: Double, + + val bytesRead: Double, + val recordsRead: Double, + + val bytesWritten: Double, + val recordsWritten: Double, + + val shuffleReadBytes: Double, + val shuffleRecordsRead: Double, + val shuffleRemoteBlocksFetched: Double, + val shuffleLocalBlocksFetched: Double, + val shuffleFetchWaitTime: Double, + val shuffleRemoteBytesRead: Double, + val shuffleRemoteBytesReadToDisk: Double, + val shuffleTotalBlocksFetched: Double, + + val shuffleWriteBytes: Double, + val shuffleWriteRecords: Double, + val shuffleWriteTime: Double) { + + @KVIndex @JsonIgnore + def id: Array[Any] = Array(stageId, stageAttemptId, quantile) + + @KVIndex("stage") @JsonIgnore + def stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 41d42b52430a5..95c12b1e73653 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -87,7 +87,9 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { } private def createExecutorTable(stage: StageData) : Seq[Node] = { - stage.executorSummary.getOrElse(Map.empty).toSeq.sortBy(_._1).map { case (k, v) => + val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) + + executorSummary.toSeq.sortBy(_._1).map { case (k, v) => val executor = store.asOption(store.executorSummary(k))
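(The change above swaps the executor table's data source from the pre-aggregated StageData to a direct store lookup. A sketch of that read path, illustrative only, with `store` and `stage` as in the surrounding method:)

    val perExecutor: Map[String, v1.ExecutorStageSummary] =
      store.executorSummary(stage.stageId, stage.attemptId)
    perExecutor.toSeq.sortBy(_._1).foreach { case (execId, summary) =>
      println(s"$execId: ${summary.succeededTasks} succeeded, ${summary.failedTasks} failed")
    }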
    - } - } - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { - getDistributionQuantiles(data).map(d => ) + val summaryTable = metricsSummary.map { metrics => + def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { millis => + } + } - val deserializationTimes = validTasks.map { task => - task.taskMetrics.get.executorDeserializeTime.toDouble - } - val deserializationQuantiles = - +: getFormattedTimeQuantiles(deserializationTimes) - - val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble) - val serviceQuantiles = +: getFormattedTimeQuantiles(serviceTimes) - - val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble) - val gcQuantiles = - +: getFormattedTimeQuantiles(gcTimes) - - val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble) - val serializationQuantiles = - +: getFormattedTimeQuantiles(serializationTimes) - - val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble) - val gettingResultQuantiles = - +: - getFormattedTimeQuantiles(gettingResultTimes) - - val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble) - val peakExecutionMemoryQuantiles = { - +: getFormattedSizeQuantiles(peakExecutionMemory) + def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { size => + } + } - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). - val schedulerDelays = validTasks.map { task => - getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble - } - val schedulerDelayTitle = - val schedulerDelayQuantiles = schedulerDelayTitle +: - getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) - : Seq[Elem] = { - val recordDist = getDistributionQuantiles(records).iterator - getDistributionQuantiles(data).map(d => - - ) + def sizeQuantilesWithRecords( + data: IndexedSeq[Double], + records: IndexedSeq[Double]) : Seq[Node] = { + data.zip(records).map { case (d, r) => + } + } - val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble) - val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble) - val inputQuantiles = +: - getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) + def titleCell(title: String, tooltip: String): Seq[Node] = { + + } - val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble) - val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble) - val outputQuantiles = +: - getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + def simpleTitleCell(title: String): Seq[Node] = - val shuffleReadBlockedTimes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble - } - val shuffleReadBlockedQuantiles = - +: - getFormattedTimeQuantiles(shuffleReadBlockedTimes) - - val shuffleReadTotalSizes = validTasks.map { task => - totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble - } - val shuffleReadTotalRecords = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble - } - val shuffleReadTotalQuantiles = - +: - getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - - val shuffleReadRemoteSizes = validTasks.map { task => - 
task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble - } - val shuffleReadRemoteQuantiles = - +: - getFormattedSizeQuantiles(shuffleReadRemoteSizes) - - val shuffleWriteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble - } + val deserializationQuantiles = titleCell("Task Deserialization Time", + ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - val shuffleWriteRecords = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble - } + val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - val shuffleWriteQuantiles = +: - getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble) - val memoryBytesSpilledQuantiles = +: - getFormattedSizeQuantiles(memoryBytesSpilledSizes) + val serializationQuantiles = titleCell("Result Serialization Time", + ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble) - val diskBytesSpilledQuantiles = +: - getFormattedSizeQuantiles(diskBytesSpilledSizes) + val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ + timeQuantiles(metrics.gettingResultTime) - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - Some(UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false)) + val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", + ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) + + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). 
+ val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ + timeQuantiles(metrics.schedulerDelay) + + def inputQuantiles: Seq[Node] = { + simpleTitleCell("Input Size / Records") ++ + sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) + } + + def outputQuantiles: Seq[Node] = { + simpleTitleCell("Output Size / Records") ++ + sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten) } + def shuffleReadBlockedQuantiles: Seq[Node] = { + titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ + timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) + } + + def shuffleReadTotalQuantiles: Seq[Node] = { + titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ + sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, + metrics.shuffleReadMetrics.readRecords) + } + + def shuffleReadRemoteQuantiles: Seq[Node] = { + titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ + sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) + } + + def shuffleWriteQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle Write Size / Records") ++ + sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, + metrics.shuffleWriteMetrics.writeRecords) + } + + def memoryBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) + } + + def diskBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) + } + + val listings: Seq[Seq[Node]] = Seq( + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + + {peakExecutionMemoryQuantiles} + , + if (hasInput(stageData)) {inputQuantiles} else Nil, + if (hasOutput(stageData)) {outputQuantiles} else Nil, + if (hasShuffleRead(stageData)) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", + "Max") + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false) + } + val executorTable = new ExecutorTable(stageData, parent.store) val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

<h4>Accumulators</h4> ++ accumulableTable } else Seq()
+      if (hasAccumulators(stageData)) { <h4>Accumulators</h4>
    ++ accumulableTable } else Seq() val aggMetrics = taskIdsInPage.contains(t.taskId) }, + Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), currentTime) ++

      <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
      <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div>
    ++ @@ -593,10 +523,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskInfo, currentTime) + val gettingResultTime = AppStatusUtils.gettingResultTime(taskInfo) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = - metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelay = AppStatusUtils.schedulerDelay(taskInfo) val schedulerDelayProportion = toProportion(schedulerDelay) val executorOverhead = serializationTime + deserializationTime @@ -708,7 +637,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We { if (MAX_TIMELINE_TASKS < tasks.size) { - This stage has more than the maximum number of tasks that can be shown in the + This page has more than the maximum number of tasks that can be shown in the visualization! Only the most recent {MAX_TIMELINE_TASKS} tasks (of {tasks.size} total) are shown. @@ -733,402 +662,49 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } -private[ui] object StagePage { - private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = { - info.resultFetchStart match { - case Some(start) => - info.duration match { - case Some(duration) => - info.launchTime.getTime() + duration - start.getTime() - - case _ => - currentTime - start.getTime() - } - - case _ => - 0L - } - } - - private[ui] def getSchedulerDelay( - info: TaskData, - metrics: TaskMetrics, - currentTime: Long): Long = { - info.duration match { - case Some(duration) => - val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime - math.max( - 0, - duration - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - - case _ => - // The task is still running and the metrics like executorRunTime are not available. - 0L - } - } - -} - -private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) - -private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) - -private[ui] case class TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable: Long, - shuffleReadBlockedTimeReadable: String, - shuffleReadSortable: Long, - shuffleReadReadable: String, - shuffleReadRemoteSortable: Long, - shuffleReadRemoteReadable: String) - -private[ui] case class TaskTableRowShuffleWriteData( - writeTimeSortable: Long, - writeTimeReadable: String, - shuffleWriteSortable: Long, - shuffleWriteReadable: String) - -private[ui] case class TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable: Long, - memoryBytesSpilledReadable: String, - diskBytesSpilledSortable: Long, - diskBytesSpilledReadable: String) - -/** - * Contains all data that needs for sorting and generating HTML. Using this one rather than - * TaskData to avoid creating duplicate contents during sorting the data. 
- */ -private[ui] class TaskTableRowData( - val index: Int, - val taskId: Long, - val attempt: Int, - val speculative: Boolean, - val status: String, - val taskLocality: String, - val executorId: String, - val host: String, - val launchTime: Long, - val duration: Long, - val formatDuration: String, - val schedulerDelay: Long, - val taskDeserializationTime: Long, - val gcTime: Long, - val serializationTime: Long, - val gettingResultTime: Long, - val peakExecutionMemoryUsed: Long, - val accumulators: Option[String], // HTML - val input: Option[TaskTableRowInputData], - val output: Option[TaskTableRowOutputData], - val shuffleRead: Option[TaskTableRowShuffleReadData], - val shuffleWrite: Option[TaskTableRowShuffleWriteData], - val bytesSpilled: Option[TaskTableRowBytesSpilledData], - val error: String, - val logs: Map[String, String]) - private[ui] class TaskDataSource( - tasks: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, + stage: StageData, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { - import StagePage._ + store: AppStatusStore) extends PagedDataSource[TaskData](pageSize) { + import ApiHelper._ // Keep an internal cache of executor log maps so that long task lists render faster. private val executorIdToLogs = new HashMap[String, Map[String, String]]() - // Convert TaskData to TaskTableRowData which contains the final contents to show in the table - // so that we can avoid creating duplicate contents during sorting the data - private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - - private var _slicedTaskIds: Set[Long] = _ + private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = data.size + override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks - override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { - val r = data.slice(from, to) - _slicedTaskIds = r.map(_.taskId).toSet - r - } - - def slicedTaskIds: Set[Long] = _slicedTaskIds - - private def taskRow(info: TaskData): TaskTableRowData = { - val metrics = info.taskMetrics - val duration = info.duration.getOrElse(1L) - val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info, currentTime) - - val externalAccumulableReadable = info.accumulatorUpdates.map { acc => - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}") + override def sliceData(from: Int, to: Int): Seq[TaskData] = { + if (_tasksToShow == null) { + _tasksToShow = store.taskList(stage.stageId, stage.attemptId, from, to - from, + indexName(sortColumn), !desc) } - val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) - - val maybeInput = metrics.map(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)}") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = 
metrics.map(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.map(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.recordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime) - val writeTimeSortable = maybeWriteTime.getOrElse(0L) - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val input = - if (hasInput) { - Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) - } else { - None - } - - val output = - if (hasOutput) { - Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) - } else { - None - } - - val shuffleRead = - if (hasShuffleRead) { - Some(TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable, - shuffleReadBlockedTimeReadable, - shuffleReadSortable, - s"$shuffleReadReadable / $shuffleReadRecords", - shuffleReadRemoteSortable, - shuffleReadRemoteReadable - )) - } else { - None - } - - val shuffleWrite = - if (hasShuffleWrite) { - Some(TaskTableRowShuffleWriteData( - writeTimeSortable, - writeTimeReadable, - shuffleWriteSortable, - s"$shuffleWriteReadable / $shuffleWriteRecords" - )) - } else { - None - } - - val bytesSpilled = - if (hasBytesSpilled) { - Some(TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable, - memoryBytesSpilledReadable, - diskBytesSpilledSortable, - diskBytesSpilledReadable - )) - } else { - None - } - - new TaskTableRowData( - info.index, - info.taskId, - info.attempt, - info.speculative, - info.status, - info.taskLocality.toString, - info.executorId, - info.host, - info.launchTime.getTime(), - duration, - formatDuration, - schedulerDelay, - taskDeserializationTime, - gcTime, - serializationTime, - gettingResultTime, - 
peakExecutionMemoryUsed, - if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, - input, - output, - shuffleRead, - shuffleWrite, - bytesSpilled, - info.errorMessage.getOrElse(""), - executorLogs(info.executorId)) + _tasksToShow } - private def executorLogs(id: String): Map[String, String] = { + def tasks: Seq[TaskData] = _tasksToShow + + def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) } - /** - * Return Ordering according to sortColumn and desc - */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { - val ordering: Ordering[TaskTableRowData] = sortColumn match { - case "Index" => Ordering.by(_.index) - case "ID" => Ordering.by(_.taskId) - case "Attempt" => Ordering.by(_.attempt) - case "Status" => Ordering.by(_.status) - case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID" => Ordering.by(_.executorId) - case "Host" => Ordering.by(_.host) - case "Launch Time" => Ordering.by(_.launchTime) - case "Duration" => Ordering.by(_.duration) - case "Scheduler Delay" => Ordering.by(_.schedulerDelay) - case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime) - case "GC Time" => Ordering.by(_.gcTime) - case "Result Serialization Time" => Ordering.by(_.serializationTime) - case "Getting Result Time" => Ordering.by(_.gettingResultTime) - case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed) - case "Accumulators" => - if (hasAccumulators) { - Ordering.by(_.accumulators.get) - } else { - throw new IllegalArgumentException( - "Cannot sort by Accumulators because of no accumulators") - } - case "Input Size / Records" => - if (hasInput) { - Ordering.by(_.input.get.inputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Input Size / Records because of no inputs") - } - case "Output Size / Records" => - if (hasOutput) { - Ordering.by(_.output.get.outputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Output Size / Records because of no outputs") - } - // ShuffleRead - case "Shuffle Read Blocked Time" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") - } - case "Shuffle Read Size / Records" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") - } - case "Shuffle Remote Reads" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Remote Reads because of no shuffle reads") - } - // ShuffleWrite - case "Write Time" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.writeTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Write Time because of no shuffle writes") - } - case "Shuffle Write Size / Records" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.shuffleWriteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") - } - // BytesSpilled - case "Shuffle Spill (Memory)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Memory) because of no 
spills") - } - case "Shuffle Spill (Disk)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Disk) because of no spills") - } - case "Errors" => Ordering.by(_.error) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } - } - } private[ui] class TaskPagedTable( - conf: SparkConf, + stage: StageData, basePath: String, - data: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedTable[TaskTableRowData] { + store: AppStatusStore) extends PagedTable[TaskData] { + + import ApiHelper._ override def tableId: String = "task-table" @@ -1142,13 +718,7 @@ private[ui] class TaskPagedTable( override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( - data, - hasAccumulators, - hasInput, - hasOutput, - hasShuffleRead, - hasShuffleWrite, - hasBytesSpilled, + stage, currentTime, pageSize, sortColumn, @@ -1180,22 +750,22 @@ private[ui] class TaskPagedTable( ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (hasShuffleRead) { + {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + {if (hasShuffleRead(stage)) { Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), ("Shuffle Read Size / Records", ""), ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ - {if (hasShuffleWrite) { + {if (hasShuffleWrite(stage)) { Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) } else { Nil }} ++ - {if (hasBytesSpilled) { + {if (hasBytesSpilled(stage)) { Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) } else { Nil @@ -1237,7 +807,17 @@ private[ui] class TaskPagedTable(
    {headerRow} } - def row(task: TaskTableRowData): Seq[Node] = { + def row(task: TaskData): Seq[Node] = { + def formatDuration(value: Option[Long], hideZero: Boolean = false): String = { + value.map { v => + if (v > 0 || !hideZero) UIUtils.formatDuration(v) else "" + }.getOrElse("") + } + + def formatBytes(value: Option[Long]): String = { + Utils.bytesToString(value.getOrElse(0L)) + } + @@ -1249,62 +829,98 @@ private[ui] class TaskPagedTable(
    {task.host}
    { - task.logs.map { + dataSource.executorLogs(task.executorId).map { case (logName, logUrl) => } }
    - - + + - {if (task.accumulators.nonEmpty) { - + {if (hasAccumulators(stage)) { + accumulatorsInfo(task) }} - {if (task.input.nonEmpty) { - + {if (hasInput(stage)) { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead) + val records = m.inputMetrics.recordsRead + + } }} - {if (task.output.nonEmpty) { - + {if (hasOutput(stage)) { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten) + val records = m.outputMetrics.recordsWritten + + } }} - {if (task.shuffleRead.nonEmpty) { + {if (hasShuffleRead(stage)) { - + }} - {if (task.shuffleWrite.nonEmpty) { - - + {if (hasShuffleWrite(stage)) { + + }} - {if (task.bytesSpilled.nonEmpty) { - - + {if (hasBytesSpilled(stage)) { + + }} - {errorMessageCell(task.error)} + {errorMessageCell(task.errorMessage.getOrElse(""))} } + private def accumulatorsInfo(task: TaskData): Seq[Node] = { + task.accumulatorUpdates.map { acc => + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + } + } + + private def metricInfo(task: TaskData)(fn: TaskMetrics => Seq[Node]): Seq[Node] = { + task.taskMetrics.map(fn).getOrElse(Nil) + } + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default @@ -1333,6 +949,36 @@ private[ui] class TaskPagedTable( private object ApiHelper { + + private val COLUMN_TO_INDEX = Map( + "ID" -> null.asInstanceOf[String], + "Index" -> TaskIndexNames.TASK_INDEX, + "Attempt" -> TaskIndexNames.ATTEMPT, + "Status" -> TaskIndexNames.STATUS, + "Locality Level" -> TaskIndexNames.LOCALITY, + "Executor ID / Host" -> TaskIndexNames.EXECUTOR, + "Launch Time" -> TaskIndexNames.LAUNCH_TIME, + "Duration" -> TaskIndexNames.DURATION, + "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, + "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, + "GC Time" -> TaskIndexNames.GC_TIME, + "Result Serialization Time" -> TaskIndexNames.SER_TIME, + "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, + "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, + "Accumulators" -> TaskIndexNames.ACCUMULATORS, + "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, + "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, + "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, + "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, + "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, + "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, + "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, + "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, + "Errors" -> TaskIndexNames.ERROR) + + def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 def hasOutput(stageData: StageData): Boolean = stageData.outputBytes > 0 @@ -1349,4 +995,11 @@ private object ApiHelper { metrics.localBytesRead + metrics.remoteBytesRead } + def indexName(sortColumn: String): Option[String] = { + COLUMN_TO_INDEX.get(sortColumn) match { + case Some(v) => Option(v) + case _ => throw new IllegalArgumentException(s"Invalid sort column: $sortColumn") + } + } + } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index 
f8e27703c0def..5c42ac1d87f4c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 6.0, 53.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index a28bda16a956e..e6b705989cc97 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 4.0, 4.0, 6.0, 7.0, 9.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index ede3eaed1d1d2..788f28cf7b365 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 4.0, 6.0, 13.0, 40.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b8c84e24c2c3f..ca66b6b9db890 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -213,45 +213,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s1Tasks.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.taskId === task.taskId) + assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptNumber) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - - val runtime = Array[AnyRef](stages.head.stageId: JInteger, - stages.head.attemptNumber: JInteger, - -1L: JLong) - assert(Arrays.equals(wrapper.runtime, runtime)) - - assert(wrapper.info.index === task.index) - assert(wrapper.info.attempt 
=== task.attemptNumber) - assert(wrapper.info.launchTime === new Date(task.launchTime)) - assert(wrapper.info.executorId === task.executorId) - assert(wrapper.info.host === task.host) - assert(wrapper.info.status === task.status) - assert(wrapper.info.taskLocality === task.taskLocality.toString()) - assert(wrapper.info.speculative === task.speculative) + assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(wrapper.index === task.index) + assert(wrapper.attempt === task.attemptNumber) + assert(wrapper.launchTime === task.launchTime) + assert(wrapper.executorId === task.executorId) + assert(wrapper.host === task.host) + assert(wrapper.status === task.status) + assert(wrapper.taskLocality === task.taskLocality.toString()) + assert(wrapper.speculative === task.speculative) } } - // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. - s1Tasks.foreach { task => - val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), - Some(1L), None, true, false, None) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( - task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) - } + // Send two executor metrics updates. Only update one metric to avoid a lot of boilerplate code. + // The tasks are distributed among the two executors, so the executor-level metrics should + // hold half of the cumulative value of the metric being updated. + Seq(1L, 2L).foreach { value => + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(value), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) + } - check[StageDataWrapper](key(stages.head)) { stage => - assert(stage.info.memoryBytesSpilled === s1Tasks.size) - } + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size * value) + } - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") - .first(key(stages.head)).last(key(stages.head)).asScala.toSeq - assert(execs.size > 0) - execs.foreach { exec => - assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size * value / 2) + } } // Fail one of the tasks, re-start it. @@ -278,13 +275,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](s1Tasks.head.taskId) { task => - assert(task.info.status === s1Tasks.head.status) - assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + assert(task.status === s1Tasks.head.status) + assert(task.errorMessage == Some(TaskResultLost.toErrorString)) } check[TaskDataWrapper](reattempt.taskId) { task => - assert(task.info.index === s1Tasks.head.index) - assert(task.info.attempt === reattempt.attemptNumber) + assert(task.index === s1Tasks.head.index) + assert(task.attempt === reattempt.attemptNumber) } // Kill one task, restart it. 
@@ -306,8 +303,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](killed.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some("killed")) + assert(task.index === killed.index) + assert(task.errorMessage === Some("killed")) } // Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill. @@ -334,8 +331,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](denied.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some(denyReason.toErrorString)) + assert(task.index === killed.index) + assert(task.errorMessage === Some(denyReason.toErrorString)) } // Start a new attempt. @@ -373,10 +370,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { pending.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.errorMessage === None) - assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) - assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) - assert(wrapper.info.duration === Some(task.duration)) + assert(wrapper.errorMessage === None) + assert(wrapper.executorCpuTime === 2L) + assert(wrapper.executorRunTime === 4L) + assert(wrapper.duration === task.duration) } } @@ -894,6 +891,23 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(store.count(classOf[StageDataWrapper]) === 3) assert(store.count(classOf[RDDOperationGraphWrapper]) === 3) + val dropped = stages.drop(1).head + + // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be + // calculcated, we need at least one finished task. + time += 1 + val task = createTasks(1, Array("1")).head + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + + time += 1 + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + "taskType", Success, task, null)) + + new AppStatusStore(store) + .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) + stages.drop(1).foreach { s => time += 1 s.completionTime = Some(time) @@ -905,6 +919,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { intercept[NoSuchElementException] { store.read(classOf[StageDataWrapper], Array(2, 0)) } + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 0) val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3") time += 1 diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala new file mode 100644 index 0000000000000..92f90f3d96ddf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.TaskMetricDistributions +import org.apache.spark.util.Distribution +import org.apache.spark.util.kvstore._ + +class AppStatusStoreSuite extends SparkFunSuite { + + private val uiQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0) + private val stageId = 1 + private val attemptId = 1 + + test("quantile calculation: 1 task") { + compareQuantiles(1, uiQuantiles) + } + + test("quantile calculation: few tasks") { + compareQuantiles(4, uiQuantiles) + } + + test("quantile calculation: more tasks") { + compareQuantiles(100, uiQuantiles) + } + + test("quantile calculation: lots of tasks") { + compareQuantiles(4096, uiQuantiles) + } + + test("quantile calculation: custom quantiles") { + compareQuantiles(4096, Array(0.01, 0.33, 0.5, 0.42, 0.69, 0.99)) + } + + test("quantile cache") { + val store = new InMemoryStore() + (0 until 4096).foreach { i => store.write(newTaskData(i)) } + + val appStore = new AppStatusStore(store) + + appStore.taskSummary(stageId, attemptId, Array(0.13d)) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "13")) + } + + appStore.taskSummary(stageId, attemptId, Array(0.25d)) + val d1 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + + // Add a new task to force the cached quantile to be evicted, and make sure it's updated. 
+ store.write(newTaskData(4096)) + appStore.taskSummary(stageId, attemptId, Array(0.25d, 0.50d, 0.73d)) + + val d2 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + assert(d1.taskCount != d2.taskCount) + + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "50")) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "73")) + } + + assert(store.count(classOf[CachedQuantile]) === 2) + } + + private def compareQuantiles(count: Int, quantiles: Array[Double]): Unit = { + val store = new InMemoryStore() + val values = (0 until count).map { i => + val task = newTaskData(i) + store.write(task) + i.toDouble + }.toArray + + val summary = new AppStatusStore(store).taskSummary(stageId, attemptId, quantiles).get + val dist = new Distribution(values, 0, values.length).getQuantiles(quantiles.sorted) + + dist.zip(summary.executorRunTime).foreach { case (expected, actual) => + assert(expected === actual) + } + } + + private def newTaskData(i: Int): TaskDataWrapper = { + new TaskDataWrapper( + i, i, i, i, i, i, i.toString, i.toString, i.toString, i.toString, false, Nil, None, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, stageId, attemptId) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 661d0d48d2f37..0aeddf730cd35 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.config._ import org.apache.spark.ui.jobs.{StagePage, StagesTab} class StagePageSuite extends SparkFunSuite with LocalSparkContext { @@ -35,15 +36,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 test("peak execution memory should displayed") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"" * 5)) } @@ -52,7 +51,8 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. 
*/ - private def renderStagePage(conf: SparkConf): Seq[Node] = { + private def renderStagePage(): Seq[Node] = { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) val statusStore = AppStatusStore.createLiveStore(conf) val listener = statusStore.listener.get diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7bdd3fac773a3..e2fa5754afaee 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -93,7 +93,7 @@ This file is divided into 3 sections: - + From 0552c36e02434c60dad82024334d291f6008b822 Mon Sep 17 00:00:00 2001 From: wuyi5 Date: Thu, 11 Jan 2018 22:17:15 +0900 Subject: [PATCH 619/712] [SPARK-22967][TESTS] Fix VersionSuite's unit tests by change Windows path into URI path ## What changes were proposed in this pull request? Two unit test will fail due to Windows format path: 1.test(s"$version: read avro file containing decimal") ``` org.apache.hadoop.hive.ql.metadata.HiveException: MetaException(message:java.lang.IllegalArgumentException: Can not create a Path from an empty string); ``` 2.test(s"$version: SPARK-17920: Insert into/overwrite avro table") ``` Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; org.apache.spark.sql.AnalysisException: Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; ``` This pr fix these two unit test by change Windows path into URI path. ## How was this patch tested? Existed. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: wuyi5 Closes #20199 from Ngone51/SPARK-22967. --- .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index ff90e9dda5f7c..e64389e56b5a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -811,7 +811,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: read avro file containing decimal") { val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val location = new File(url.getFile) + val location = new File(url.getFile).toURI.toString val tableName = "tab1" val avroSchema = @@ -851,6 +851,8 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: SPARK-17920: Insert into/overwrite avro table") { + // skipped because it's failed in the condition on Windows + assume(!(Utils.isWindows && version == "0.12")) withTempDir { dir => val avroSchema = """ @@ -875,10 +877,10 @@ class VersionsSuite extends SparkFunSuite with Logging { val writer = new PrintWriter(schemaFile) writer.write(avroSchema) writer.close() - val schemaPath = schemaFile.getCanonicalPath + val schemaPath = schemaFile.toURI.toString val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile).getCanonicalPath + val srcLocation = new File(url.getFile).toURI.toString val destTableName = "tab1" val srcTableName = "tab2" From 76892bcf2c08efd7e9c5b16d377e623d82fe695e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 21:32:36 +0800 Subject: [PATCH 620/712] [SPARK-23000][TEST-HADOOP2.6] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite ## What changes were proposed in this pull 
request? The Spark 2.3 branch still failed due to the flaky test suite `DataSourceWithHiveMetastoreCatalogSuite `. https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ Although https://github.com/apache/spark/pull/20207 is unable to reproduce it in Spark 2.3, it sounds like the current DB of Spark's Catalog is changed based on the following stacktrace. Thus, we just need to reset it. ``` [info] DataSourceWithHiveMetastoreCatalogSuite: 02:40:39.486 ERROR org.apache.hadoop.hive.ql.parse.CalcitePlanner: org.apache.hadoop.hive.ql.parse.SemanticException: Line 1:14 Table not found 't' at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1594) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1545) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.genResolvedParseTree(SemanticAnalyzer.java:10077) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.analyzeInternal(SemanticAnalyzer.java:10128) at org.apache.hadoop.hive.ql.parse.CalcitePlanner.analyzeInternal(CalcitePlanner.java:209) at org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.analyze(BaseSemanticAnalyzer.java:227) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:424) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:308) at org.apache.hadoop.hive.ql.Driver.compileInternal(Driver.java:1122) at org.apache.hadoop.hive.ql.Driver.runInternal(Driver.java:1170) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1059) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1049) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:694) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1$$anonfun$apply$mcV$sp$3.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:185) at org.apache.spark.sql.test.SQLTestUtilsBase$class.withTable(SQLTestUtils.scala:273) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:139) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:186) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at 
org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:183) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:289) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:196) at org.scalatest.FunSuite.runTest(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:396) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:384) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:384) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:379) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:461) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:229) at org.scalatest.FunSuite.runTests(FunSuite.scala:1560) at org.scalatest.Suite$class.run(Suite.scala:1147) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.SuperEngine.runImpl(Engine.scala:521) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:233) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:213) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:210) at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:31) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:314) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:480) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20218 from gatorsmile/testFixAgain. --- .../org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index cf4ce83124d88..ba9b944e4a055 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -148,6 +148,7 @@ class DataSourceWithHiveMetastoreCatalogSuite override def beforeAll(): Unit = { super.beforeAll() + sparkSession.sessionState.catalog.reset() sparkSession.metadataHive.reset() } From b46e58b74c82dac37b7b92284ea3714919c5a886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 22:33:42 +0900 Subject: [PATCH 621/712] [SPARK-19732][FOLLOW-UP] Document behavior changes made in na.fill and fillna ## What changes were proposed in this pull request? 
https://github.com/apache/spark/pull/18164 introduces the behavior changes. We need to document it. ## How was this patch tested? N/A Author: gatorsmile Closes #20234 from gatorsmile/docBehaviorChange. --- docs/sql-programming-guide.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 72f79d6909ecc..258c769ff593b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1788,12 +1788,10 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - - - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - - - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). + - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. ## Upgrading From Spark SQL 2.1 to 2.2 From 6d230dccf65300651f989392159d84bfaf08f18f Mon Sep 17 00:00:00 2001 From: FanDonglai Date: Thu, 11 Jan 2018 09:06:40 -0600 Subject: [PATCH 622/712] Update PageRank.scala ## What changes were proposed in this pull request? 
Hi, according to the code below, "if (id == src) (0.0, Double.NegativeInfinity) else (0.0, 0.0)", I think the comment may be wrong ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: FanDonglai Closes #20220 from ddna1021/master. --- .../src/main/scala/org/apache/spark/graphx/lib/PageRank.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index fd7b7f7c1c487..ebd65e8320e5c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -303,7 +303,7 @@ object PageRank extends Logging { val src: VertexId = srcId.getOrElse(-1L) // Initialize the pagerankGraph with each edge attribute - // having weight 1/outDegree and each vertex with attribute 1.0. + // having weight 1/outDegree and each vertex with attribute 0. val pagerankGraph: Graph[(Double, Double), Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { From 0b2eefb674151a0af64806728b38d9410da552ec Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 10:37:35 -0800 Subject: [PATCH 623/712] [SPARK-22994][K8S] Use a single image for all Spark containers. This change allows a user to submit a Spark application on kubernetes while having to provide only a single image, instead of one image for each type of container. The image's entry point now takes an extra argument that identifies the process that is being started. The configuration still allows the user to provide different images for each container type if they so desire. On top of that, the entry point was simplified a bit to share more code; mainly, the same env variable is used to propagate the user-defined classpath to the different containers. Aside from being modified to match the new behavior, the 'build-push-docker-images.sh' script was renamed to 'docker-image-tool.sh' to more closely match its purpose; the old name was a little awkward and now also not entirely correct, since there is a single image. It was also moved to 'bin' since it's not necessarily an admin tool. Docs have been updated to match the new behavior. Tested locally with minikube. Author: Marcelo Vanzin Closes #20192 from vanzin/SPARK-22994.
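Note: the following is an illustrative, self-contained sketch of the image-resolution order this change introduces, where a container-type-specific image setting, if present, takes precedence over the single shared `spark.kubernetes.container.image` value. The object and method names are hypothetical and this is not Spark's actual ConfigEntry/ConfigBuilder machinery; in the real code the fallback is expressed with `.fallbackConf(CONTAINER_IMAGE)` in `Config.scala`, as shown in the diff below.

```
// Hypothetical sketch of the image-resolution order: a container-type-specific
// setting wins; otherwise the single shared image is used.
object ImageResolutionSketch {
  def resolveImage(conf: Map[String, String], containerType: String): Option[String] = {
    val specificKey = s"spark.kubernetes.$containerType.container.image"
    conf.get(specificKey).orElse(conf.get("spark.kubernetes.container.image"))
  }

  def main(args: Array[String]): Unit = {
    // Example image names only.
    val shared = Map("spark.kubernetes.container.image" -> "myrepo/spark:v2.3.0")
    // With only the shared setting, driver and executor resolve to the same image.
    assert(resolveImage(shared, "driver").contains("myrepo/spark:v2.3.0"))
    assert(resolveImage(shared, "executor").contains("myrepo/spark:v2.3.0"))

    // A type-specific setting still takes precedence when provided.
    val custom = shared + ("spark.kubernetes.executor.container.image" -> "myrepo/spark-exec:v2.3.0")
    assert(resolveImage(custom, "executor").contains("myrepo/spark-exec:v2.3.0"))
  }
}
```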
--- .../docker-image-tool.sh | 68 ++++++------- docs/running-on-kubernetes.md | 58 +++++------ .../org/apache/spark/deploy/k8s/Config.scala | 17 ++-- .../apache/spark/deploy/k8s/Constants.scala | 3 +- .../deploy/k8s/InitContainerBootstrap.scala | 1 + .../steps/BasicDriverConfigurationStep.scala | 3 +- .../cluster/k8s/ExecutorPodFactory.scala | 3 +- .../DriverConfigOrchestratorSuite.scala | 12 +-- .../BasicDriverConfigurationStepSuite.scala | 4 +- ...InitContainerConfigOrchestratorSuite.scala | 4 +- .../cluster/k8s/ExecutorPodFactorySuite.scala | 4 +- .../src/main/dockerfiles/driver/Dockerfile | 35 ------- .../src/main/dockerfiles/executor/Dockerfile | 35 ------- .../dockerfiles/init-container/Dockerfile | 24 ----- .../main/dockerfiles/spark-base/entrypoint.sh | 37 ------- .../{spark-base => spark}/Dockerfile | 10 +- .../src/main/dockerfiles/spark/entrypoint.sh | 97 +++++++++++++++++++ 17 files changed, 189 insertions(+), 226 deletions(-) rename sbin/build-push-docker-images.sh => bin/docker-image-tool.sh (63%) delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile delete mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh rename resource-managers/kubernetes/docker/src/main/dockerfiles/{spark-base => spark}/Dockerfile (87%) create mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh diff --git a/sbin/build-push-docker-images.sh b/bin/docker-image-tool.sh similarity index 63% rename from sbin/build-push-docker-images.sh rename to bin/docker-image-tool.sh index b9532597419a5..071406336d1b1 100755 --- a/sbin/build-push-docker-images.sh +++ b/bin/docker-image-tool.sh @@ -24,29 +24,11 @@ function error { exit 1 } -# Detect whether this is a git clone or a Spark distribution and adjust paths -# accordingly. if [ -z "${SPARK_HOME}" ]; then SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi . "${SPARK_HOME}/bin/load-spark-env.sh" -if [ -f "$SPARK_HOME/RELEASE" ]; then - IMG_PATH="kubernetes/dockerfiles" - SPARK_JARS="jars" -else - IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" - SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -if [ ! -d "$IMG_PATH" ]; then - error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." -fi - -declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ - [spark-executor]="$IMG_PATH/executor/Dockerfile" \ - [spark-init]="$IMG_PATH/init-container/Dockerfile" ) - function image_ref { local image="$1" local add_repo="${2:-1}" @@ -60,35 +42,49 @@ function image_ref { } function build { - docker build \ - --build-arg "spark_jars=$SPARK_JARS" \ - --build-arg "img_path=$IMG_PATH" \ - -t spark-base \ - -f "$IMG_PATH/spark-base/Dockerfile" . - for image in "${!path[@]}"; do - docker build -t "$(image_ref $image)" -f ${path[$image]} . - done + local BUILD_ARGS + local IMG_PATH + + if [ ! -f "$SPARK_HOME/RELEASE" ]; then + # Set image build arguments accordingly if this is a source repo and not a distribution archive. 
+ IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles + BUILD_ARGS=( + --build-arg + img_path=$IMG_PATH + --build-arg + spark_jars=assembly/target/scala-$SPARK_SCALA_VERSION/jars + ) + else + # Not passed as an argument to docker, but used to validate the Spark directory. + IMG_PATH="kubernetes/dockerfiles" + fi + + if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." + fi + + docker build "${BUILD_ARGS[@]}" \ + -t $(image_ref spark) \ + -f "$IMG_PATH/spark/Dockerfile" . } function push { - for image in "${!path[@]}"; do - docker push "$(image_ref $image)" - done + docker push "$(image_ref spark)" } function usage { cat < -t my-tag build - ./sbin/build-push-docker-images.sh -r -t my-tag push - -Docker files are under the `kubernetes/dockerfiles/` directory and can be customized further before -building using the supplied script, or manually. + ./bin/docker-image-tool.sh -r -t my-tag build + ./bin/docker-image-tool.sh -r -t my-tag push ## Cluster Mode @@ -79,8 +76,7 @@ $ bin/spark-submit \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ + --conf spark.kubernetes.container.image= \ local:///path/to/examples.jar ``` @@ -126,13 +122,7 @@ Those dependencies can be added to the classpath by referencing them with `local ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. This requires users to specify the container -image for the init-container using the configuration property `spark.kubernetes.initContainer.image`. For example, users -simply add the following option to the `spark-submit` command to specify the init-container image: - -``` ---conf spark.kubernetes.initContainer.image= -``` +the dependencies so the driver and executor containers can use them locally. The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and `spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., @@ -147,9 +137,7 @@ $ bin/spark-submit \ --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ - --conf spark.kubernetes.initContainer.image= + --conf spark.kubernetes.container.image= \ https://path/to/examples.jar ``` @@ -322,21 +310,27 @@ specific to Spark on Kubernetes. - + + + + + + - + @@ -643,9 +637,9 @@ specific to Spark on Kubernetes. 
- + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index e5d79d9a9d9da..471196ac0e3f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -29,17 +29,23 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("default") + val CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.container.image") + .doc("Container image to use for Spark containers. Individual container types " + + "(e.g. driver or executor) can also be configured to use different images if desired, " + + "by setting the container type-specific image name.") + .stringConf + .createOptional + val DRIVER_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.driver.container.image") .doc("Container image to use for the driver.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val EXECUTOR_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.executor.container.image") .doc("Container image to use for the executors.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val CONTAINER_IMAGE_PULL_POLICY = ConfigBuilder("spark.kubernetes.container.image.pullPolicy") @@ -148,8 +154,7 @@ private[spark] object Config extends Logging { val INIT_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.initContainer.image") .doc("Image for the driver and executor's init-container for downloading dependencies.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val INIT_CONTAINER_MOUNT_TIMEOUT = ConfigBuilder("spark.kubernetes.mountDependencies.timeout") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 111cb2a3b75e5..9411956996843 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -60,10 +60,9 @@ private[spark] object Constants { val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" - val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" + val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala index dfeccf9e2bd1c..f6a57dfe00171 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala @@ -77,6 +77,7 @@ private[spark] class InitContainerBootstrap( .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) .endVolumeMount() .addToVolumeMounts(sharedVolumeMounts: _*) + .addToArgs("init") 
.addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index eca46b84c6066..164e2e5594778 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -66,7 +66,7 @@ private[spark] class BasicDriverConfigurationStep( override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => new EnvVarBuilder() - .withName(ENV_SUBMIT_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(classPath) .build() } @@ -133,6 +133,7 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits("memory", driverMemoryLimitQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() + .addToArgs("driver") .build() val baseDriverPod = new PodBuilder(driverSpec.driverPod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index bcacb3934d36a..141bd2827e7c5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -128,7 +128,7 @@ private[spark] class ExecutorPodFactory( .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() - .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(cp) .build() } @@ -181,6 +181,7 @@ private[spark] class ExecutorPodFactory( .endResources() .addAllToEnv(executorEnv.asJava) .withPorts(requiredPorts.asJava) + .addToArgs("executor") .build() val executorPod = new PodBuilder() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index f193b1f4d3664..65274c6f50e01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -34,8 +34,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, @@ -55,8 +54,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { } test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, 
DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, @@ -75,8 +73,8 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with an init-container.") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE, IC_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( @@ -98,7 +96,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with driver secrets to mount") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index 8ee629ac8ddc1..b136f2c02ffba 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -47,7 +47,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(DRIVER_CONTAINER_IMAGE, "spark-driver:latest") + .set(CONTAINER_IMAGE, "spark-driver:latest") .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") @@ -79,7 +79,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .asScala .map(env => (env.getName, env.getValue)) .toMap - assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") + assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala index 20f2e5bc15df3..09b42e4484d86 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala @@ -40,7 +40,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including basic configuration step") { val sparkConf = new SparkConf(true) - 
.set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) val orchestrator = new InitContainerConfigOrchestrator( @@ -59,7 +59,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including step to mount user-specified secrets") { val sparkConf = new SparkConf(false) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7cfbe54c95390..a3c615be031d2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -54,7 +54,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef baseConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(EXECUTOR_CONTAINER_IMAGE, executorImage) + .set(CONTAINER_IMAGE, executorImage) } test("basic executor pod has reasonable defaults") { @@ -107,7 +107,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkEnv(executor, Map("SPARK_JAVA_OPT_0" -> "foo=bar", - "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", + ENV_CLASSPATH -> "bar=baz", "qux" -> "quux")) checkOwnerReferences(executor, driverPodUid) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile deleted file mode 100644 index 45fbcd9cd0deb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . 
- -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile deleted file mode 100644 index 0f806cf7e148e..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." 
.; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile deleted file mode 100644 index 047056ab2633b..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# If this docker file is being used in the context of building your images from a Spark distribution, the docker build -# command should be invoked from the top level directory of the Spark distribution. E.g.: -# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . - -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh deleted file mode 100755 index 82559889f4beb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -# echo commands to the terminal output -set -ex - -# Check whether there is a passwd entry for the container UID -myuid=$(id -u) -mygid=$(id -g) -uidentry=$(getent passwd $myuid) - -# If there is no passwd entry for the container UID, attempt to create one -if [ -z "$uidentry" ] ; then - if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd - else - echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" - fi -fi - -# Execute the container CMD under tini for better hygiene -/sbin/tini -s -- "$@" diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile similarity index 87% rename from resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile rename to resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index da1d6b9e161cc..491b7cf692478 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,15 +17,15 @@ FROM openjdk:8-alpine -ARG spark_jars -ARG img_path +ARG spark_jars=jars +ARG img_path=kubernetes/dockerfiles # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-base:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile . RUN set -ex && \ apk upgrade --no-cache && \ @@ -41,7 +41,9 @@ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY ${img_path}/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh new file mode 100755 index 0000000000000..0c28c75857871 --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# echo commands to the terminal output +set -ex + +# Check whether there is a passwd entry for the container UID +myuid=$(id -u) +mygid=$(id -g) +uidentry=$(getent passwd $myuid) + +# If there is no passwd entry for the container UID, attempt to create one +if [ -z "$uidentry" ] ; then + if [ -w /etc/passwd ] ; then + echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd + else + echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" + fi +fi + +SPARK_K8S_CMD="$1" +if [ -z "$SPARK_K8S_CMD" ]; then + echo "No command to execute has been provided." 1>&2 + exit 1 +fi +shift 1 + +SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" +env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" +fi +if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then + cp -R "$SPARK_MOUNTED_FILES_DIR/." . +fi + +case "$SPARK_K8S_CMD" in + driver) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_DRIVER_JAVA_OPTS[@]}" + -cp "$SPARK_CLASSPATH" + -Xms$SPARK_DRIVER_MEMORY + -Xmx$SPARK_DRIVER_MEMORY + -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS + $SPARK_DRIVER_CLASS + $SPARK_DRIVER_ARGS + ) + ;; + + executor) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + -Xms$SPARK_EXECUTOR_MEMORY + -Xmx$SPARK_EXECUTOR_MEMORY + -cp "$SPARK_CLASSPATH" + org.apache.spark.executor.CoarseGrainedExecutorBackend + --driver-url $SPARK_DRIVER_URL + --executor-id $SPARK_EXECUTOR_ID + --cores $SPARK_EXECUTOR_CORES + --app-id $SPARK_APPLICATION_ID + --hostname $SPARK_EXECUTOR_POD_IP + ) + ;; + + init) + CMD=( + "$SPARK_HOME/bin/spark-class" + "org.apache.spark.deploy.k8s.SparkPodInitContainer" + "$@" + ) + ;; + + *) + echo "Unknown command: $SPARK_K8S_CMD" 1>&2 + exit 1 +esac + +# Execute the container CMD under tini for better hygiene +exec /sbin/tini -s -- "${CMD[@]}" From 6f7aaed805070d29dcba32e04ca7a1f581fa54b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 11 Jan 2018 10:52:12 -0800 Subject: [PATCH 624/712] [SPARK-22908] Add kafka source and sink for continuous processing. ## What changes were proposed in this pull request? Add kafka source and sink for continuous processing. This involves two small changes to the execution engine: * Bring data reader close() into the normal data reader thread to avoid thread safety issues. * Fix up the semantics of the RECONFIGURING StreamExecution state. State updates are now atomic, and we don't have to deal with swallowing an exception. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20096 from jose-torres/continuous-kafka. 
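
To make the first bullet concrete, here is a minimal, self-contained sketch of the threading contract it describes. This is illustrative only and is not the Spark API: `ReaderThreadSketch`, `SimpleReader`, and the record source are made-up names, and the real change lives in the `ContinuousDataSourceRDDIter` diff below. The point is simply that `next()`, `get()`, and `close()` all run on the single data reader thread, so the underlying non-thread-safe consumer is never touched from two threads at once.

```scala
// Illustrative sketch only: names here are hypothetical and not part of Spark.
object ReaderThreadSketch {
  trait DataReader[T] extends AutoCloseable {
    def next(): Boolean
    def get(): T
  }

  // Stands in for a per-partition reader wrapping a non-thread-safe consumer.
  final class SimpleReader(records: Iterator[String]) extends DataReader[String] {
    private var current: String = _
    override def next(): Boolean =
      if (records.hasNext) { current = records.next(); true } else false
    override def get(): String = current
    override def close(): Unit =
      println(s"closed on ${Thread.currentThread().getName}")
  }

  def main(args: Array[String]): Unit = {
    val reader = new SimpleReader(Iterator("a", "b", "c"))
    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        // next(), get() and close() all execute here, on the data reader thread,
        // so no second thread ever calls into the underlying reader.
        try {
          while (reader.next()) println(reader.get())
        } finally {
          reader.close()
        }
      }
    }, "data-reader-thread")
    worker.start()
    worker.join()
  }
}
```

The second bullet (atomic transitions into and out of the RECONFIGURING state) is a state-machine change in StreamExecution/ContinuousExecution and is best read directly from the diff below rather than sketched here.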
--- .../sql/kafka010/KafkaContinuousReader.scala | 232 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 +++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ++++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 64 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 +++++++++-------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1531 insertions(+), 383 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..928379544758c --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. 
+ private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. 
+ */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs = Long.MaxValue, + failOnDataLoss) + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. 
+ */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. + checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. 
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..27da76068a66f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. 
*/ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. 
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. 
+ val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. 
Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, 
kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..dfc97b1c38bb5 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. 
+ */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. 
Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' 
not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( 
+ input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. + .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. 
+ r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..e713e6695d2bd --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. 
+ override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..d66908f86ccc7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. 
* @@ -82,7 +94,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -97,16 +109,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +151,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 
101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +395,51 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - 
testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) + .option("startingOffsets", s"earliest") .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock + .load() - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. 
+ assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() @@ -328,7 +451,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -422,65 +545,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -540,34 +604,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - test("delete a topic when a Spark job is running") { 
KafkaSourceSuite.collectedData.clear() @@ -629,8 +665,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -676,6 +710,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -706,6 +744,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -723,47 +762,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. 
- assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -800,9 +798,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -843,9 +839,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -977,20 +971,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1004,6 +986,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. 
val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. 
+ runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..e700aa4f9aea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -201,6 +200,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. 
This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. 
+ */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. 
*/ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 186bf8fb2e9ff8a80f3f6bcb5f2a0327fa79a1c9 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 11 Jan 2018 13:57:15 -0800 Subject: [PATCH 625/712] [SPARK-23046][ML][SPARKR] Have RFormula include VectorSizeHint in pipeline ## What changes were proposed in this pull request? Including VectorSizeHint in RFormula piplelines will allow them to be applied to streaming dataframes. ## How was this patch tested? Unit tests. Author: Bago Amirbekian Closes #20238 from MrBago/rFormulaVectorSize. 
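Editor's illustration (not part of the patch): a minimal Scala sketch of the user-facing effect, assuming a live `spark` session; the column names and data are invented for the example. Because RFormula now inserts a VectorSizeHint stage for vector-typed terms, the fitted model carries the feature size and can be applied to a streaming DataFrame, where that size cannot be discovered by looking at the data.

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.linalg.Vectors

// Hypothetical training data with a vector-typed column "b".
val training = spark.createDataFrame(Seq(
  (1, 4, Vectors.dense(0.0, 0.0, 4.0)),
  (2, 5, Vectors.dense(1.0, 0.0, 5.0))
)).toDF("id", "a", "b")

// Fitting records the size of "b" through the VectorSizeHint stage this patch inserts,
// so transform() no longer needs to inspect a first row to discover it.
val model = new RFormula().setFormula("id ~ a + b").fit(training)
model.transform(training).select("features", "label").show()

// Assuming streamingDf is a streaming DataFrame with the same schema, the fitted
// model can now be applied to it as well:
// model.transform(streamingDf)
```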
--- R/pkg/R/mllib_utils.R | 1 + .../apache/spark/ml/feature/RFormula.scala | 18 +++++++-- .../spark/ml/feature/RFormulaSuite.scala | 37 ++++++++++++++++--- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..23dda42c325be 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,3 +130,4 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } + diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 7da3339f8b487..f384ffbf578bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol} import org.apache.spark.ml.util._ @@ -210,8 +210,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // First we index each string column referenced by the input terms. val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => - dataset.schema(term) match { - case column if column.dataType == StringType => + dataset.schema(term).dataType match { + case _: StringType => val indexCol = tmpColumn("stridx") encoderStages += new StringIndexer() .setInputCol(term) @@ -220,6 +220,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(dataset.schema(term)) + val size = if (group.size < 0) { + dataset.select(term).first().getAs[Vector](0).size + } else { + group.size + } + encoderStages += new VectorSizeHint(uid) + .setHandleInvalid("optimistic") + .setInputCol(term) + .setSize(size) + (term, term) case _ => (term, term) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90ec6dfa..f3f4b5a3d0233 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.types.DoubleType -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ 
@@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result3.collect() === expected3.collect()) assert(result4.collect() === expected4.collect()) } + + test("Use Vectors as inputs to formula.") { + val original = Seq( + (1, 4, Vectors.dense(0.0, 0.0, 4.0)), + (2, 4, Vectors.dense(1.0, 0.0, 4.0)), + (3, 5, Vectors.dense(1.0, 0.0, 5.0)), + (4, 5, Vectors.dense(0.0, 1.0, 5.0)) + ).toDF("id", "a", "b") + val formula = new RFormula().setFormula("id ~ a + b") + val (first +: rest) = Seq("id", "a", "b", "features", "label") + testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + + val group = new AttributeGroup("b", 3) + val vectorColWithMetadata = original("b").as("b", group.toMetadata()) + val dfWithMetadata = original.withColumn("b", vectorColWithMetadata) + val model = formula.fit(dfWithMetadata) + // model should work even when applied to dataframe without metadata. + testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + } } From b5042d75c2faa5f15bc1e160d75f06dfdd6eea37 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 11 Jan 2018 16:20:30 -0800 Subject: [PATCH 626/712] [SPARK-23008][ML] OnehotEncoderEstimator python API ## What changes were proposed in this pull request? OnehotEncoderEstimator python API. ## How was this patch tested? doctest Author: WeichenXu Closes #20209 from WeichenXu123/ohe_py. --- python/pyspark/ml/feature.py | 113 ++++++++++++++++++ .../ml/param/_shared_params_code_gen.py | 1 + python/pyspark/ml/param/shared.py | 23 ++++ 3 files changed, 137 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 13bf95cce40be..b963e45dd7cff 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -45,6 +45,7 @@ 'NGram', 'Normalizer', 'OneHotEncoder', + 'OneHotEncoderEstimator', 'OneHotEncoderModel', 'PCA', 'PCAModel', 'PolynomialExpansion', 'QuantileDiscretizer', @@ -1641,6 +1642,118 @@ def getDropLast(self): return self.getOrDefault(self.dropLast) +@inherit_doc +class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + A one-hot encoder that maps a column of category indices to a column of binary vectors, with + at most a single one-value per row that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to an output vector of + `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via `dropLast`), + because it makes the vector entries sum up to one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + + Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories. + The output vectors are sparse. + + When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + vector. + + Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output + cols come in pairs, specified by the order in the arrays, and each pair is treated + independently. 
+ + See `StringIndexer` for converting categorical values into category indices + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) + >>> ohe = OneHotEncoderEstimator(inputCols=["input"], outputCols=["output"]) + >>> model = ohe.fit(df) + >>> model.transform(df).head().output + SparseVector(2, {0: 1.0}) + >>> ohePath = temp_path + "/oheEstimator" + >>> ohe.save(ohePath) + >>> loadedOHE = OneHotEncoderEstimator.load(ohePath) + >>> loadedOHE.getInputCols() == ohe.getInputCols() + True + >>> modelPath = temp_path + "/ohe-model" + >>> model.save(modelPath) + >>> loadedModel = OneHotEncoderModel.load(modelPath) + >>> loadedModel.categorySizes == model.categorySizes + True + + .. versionadded:: 2.3.0 + """ + + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " + + "transform(). Options are 'keep' (invalid data presented as an extra " + + "categorical feature) or error (throw an error). Note that this Param " + + "is only used during transform; during fitting, invalid data will " + + "result in an error.", + typeConverter=TypeConverters.toString) + + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", + typeConverter=TypeConverters.toBoolean) + + @keyword_only + def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + """ + super(OneHotEncoderEstimator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.OneHotEncoderEstimator", self.uid) + self._setDefault(handleInvalid="error", dropLast=True) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.3.0") + def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + Sets params for this OneHotEncoderEstimator. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.3.0") + def setDropLast(self, value): + """ + Sets the value of :py:attr:`dropLast`. + """ + return self._set(dropLast=value) + + @since("2.3.0") + def getDropLast(self): + """ + Gets the value of dropLast or its default value. + """ + return self.getOrDefault(self.dropLast) + + def _create_model(self, java_model): + return OneHotEncoderModel(java_model) + + +class OneHotEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`OneHotEncoderEstimator`. + + .. versionadded:: 2.3.0 + """ + + @property + @since("2.3.0") + def categorySizes(self): + """ + Original number of categories for each feature being encoded. + The array contains one value for each input column, in order. 
+ """ + return self._call_java("categorySizes") + + @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 1d0f60acc6983..db951d81de1e7 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -119,6 +119,7 @@ def get$Name(self): ("inputCol", "input column name.", None, "TypeConverters.toString"), ("inputCols", "input column names.", None, "TypeConverters.toListString"), ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("outputCols", "output column names.", None, "TypeConverters.toListString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 813f7a59f3fd1..474c38764e5a1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -256,6 +256,29 @@ def getOutputCol(self): return self.getOrDefault(self.outputCol) +class HasOutputCols(Params): + """ + Mixin for param outputCols: output column names. + """ + + outputCols = Param(Params._dummy(), "outputCols", "output column names.", typeConverter=TypeConverters.toListString) + + def __init__(self): + super(HasOutputCols, self).__init__() + + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + def getOutputCols(self): + """ + Gets the value of outputCols or its default value. + """ + return self.getOrDefault(self.outputCols) + + class HasNumFeatures(Params): """ Mixin for param numFeatures: number of features. From cbe7c6fbf9dc2fc422b93b3644c40d449a869eea Mon Sep 17 00:00:00 2001 From: ho3rexqj Date: Fri, 12 Jan 2018 15:27:00 +0800 Subject: [PATCH 627/712] [SPARK-22986][CORE] Use a cache to avoid instantiating multiple instances of broadcast variable values When resources happen to be constrained on an executor the first time a broadcast variable is instantiated it is persisted to disk by the BlockManager. Consequently, every subsequent call to TorrentBroadcast::readBroadcastBlock from other instances of that broadcast variable spawns another instance of the underlying value. That is, broadcast variables are spawned once per executor **unless** memory is constrained, in which case every instance of a broadcast variable is provided with a unique copy of the underlying value. This patch fixes the above by explicitly caching the underlying values using weak references in a ReferenceMap. Author: ho3rexqj Closes #20183 from ho3rexqj/fix/cache-broadcast-values. 
--- .../spark/broadcast/BroadcastManager.scala | 6 ++ .../spark/broadcast/TorrentBroadcast.scala | 72 +++++++++++-------- .../spark/broadcast/BroadcastSuite.scala | 34 +++++++++ 3 files changed, 83 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2e..8d7a4a353a792 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag +import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging @@ -52,6 +54,10 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) + private[broadcast] val cachedValues = { + new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK) + } + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7aecd3c9668ea..e125095cf4777 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -206,36 +206,50 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { - setConf(SparkEnv.get.conf) - val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId) match { - case Some(blockResult) => - if (blockResult.data.hasNext) { - val x = blockResult.data.next().asInstanceOf[T] - releaseLock(broadcastId) - x - } else { - throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") - } - case None => - logInfo("Started reading broadcast variable " + id) - val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks() - logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - - try { - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. 
- val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + + Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { + setConf(SparkEnv.get.conf) + val blockManager = SparkEnv.get.blockManager + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + + if (x != null) { + broadcastCache.put(broadcastId, x) + } + + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") } - obj - } finally { - blocks.foreach(_.dispose()) - } + case None => + logInfo("Started reading broadcast variable " + id) + val startTimeMs = System.currentTimeMillis() + val blocks = readBlocks() + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) + + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + + if (obj != null) { + broadcastCache.put(broadcastId, obj) + } + + obj + } finally { + blocks.foreach(_.dispose()) + } + } } } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 159629825c677..9ad2e9a5e74ac 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -153,6 +153,40 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(broadcast.value.sum === 10) } + test("One broadcast value instance per executor") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + + test("One broadcast value instance per executor when memory is constrained") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * From a7d98d53ceaf69cabaecc6c9113f17438c4e61f6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 12 Jan 2018 11:27:02 +0200 Subject: [PATCH 628/712] [SPARK-23008][ML][FOLLOW-UP] mark OneHotEncoder python API deprecated ## What changes were proposed in this pull request? mark OneHotEncoder python API deprecated ## How was this patch tested? 
N/A Author: WeichenXu Closes #20241 from WeichenXu123/mark_ohe_deprecated. --- python/pyspark/ml/feature.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b963e45dd7cff..eb79b193103e2 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1578,6 +1578,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories. The output vectors are sparse. + .. note:: Deprecated in 2.3.0. :py:class:`OneHotEncoderEstimator` will be renamed to + :py:class:`OneHotEncoder` and this :py:class:`OneHotEncoder` will be removed in 3.0.0. + .. seealso:: :py:class:`StringIndexer` for converting categorical values into From 505086806997b4331d4a8c2fc5e08345d869a23c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 18:04:44 +0800 Subject: [PATCH 629/712] [SPARK-23025][SQL] Support Null type in scala reflection ## What changes were proposed in this pull request? Add support for `Null` type in the `schemaFor` method for Scala reflection. ## How was this patch tested? Added UT Author: Marco Gaido Closes #20219 from mgaido91/SPARK-23025. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++++ .../apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 9 +++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 3 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 65040f1af4b04..9a4bf0075a178 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection { private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { + case t if t <:< definitions.NullTpe => NullType case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType @@ -712,6 +713,9 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { tpe.dealias match { + // this must be the first case, since all objects in scala are instances of Null, therefore + // Null type would wrongly match the first of them, which is Option as of now + case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 23e866cdf4917..8c3db48a01f12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -356,4 +356,13 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } + + test("SPARK-23025: schemaFor should support Null type") { + val schema = schemaFor[(Int, Null)] + assert(schema === Schema( + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", NullType, nullable = true))), + nullable = true)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d535896723bd5..54893c184642b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1441,6 +1441,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) } } + + test("SPARK-23025: Add support for null type in scala reflection") { + val data = Seq(("a", null)) + checkDataset(data.toDS(), data: _*) + } } case class SingleData(id: Int) From f5300fbbe370af3741560f67bfb5ae6f0b0f7bb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20Beaup=C3=A8re?= Date: Fri, 12 Jan 2018 08:29:46 -0600 Subject: [PATCH 630/712] Update rdd-programming-guide.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Small typing correction - double word ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Matthias Beaupère Closes #20212 from matthiasbe/patch-1. --- docs/rdd-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 29af159510e46..2e29aef7f21a2 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -91,7 +91,7 @@ so C libraries like NumPy can be used. It also works with PyPy 2.3+. Python 2.6 support was removed in Spark 2.2.0. 
-Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including including it in your setup.py as: +Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including it in your setup.py as: {% highlight python %} install_requires=[ From 651f76153f5e9b185aaf593161d40cabe7994fea Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 13 Jan 2018 00:37:59 +0800 Subject: [PATCH 631/712] [SPARK-23028] Bump master branch version to 2.4.0-SNAPSHOT ## What changes were proposed in this pull request? This patch bumps the master branch version to `2.4.0-SNAPSHOT`. ## How was this patch tested? N/A Author: gatorsmile Closes #20222 from gatorsmile/bump24. --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- dev/run-tests-jenkins.py | 4 ++-- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 5 +++++ python/pyspark/version.py | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 43 files changed, 49 insertions(+), 44 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 6d46c31906260..855eb5bf77f16 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.3.0 +Version: 2.4.0 Title: R Frontend for Apache Spark Description: Provides an R Frontend for Apache Spark. 
Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/assembly/pom.xml b/assembly/pom.xml index b3b4239771bc3..a207dae5a74ff 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index cf93d41cd77cf..8c148359c3029 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 18cbdadd224ab..8ca7733507f1b 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 9968480ab7658..05335df61a664 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index ec2db6e5bb88c..564e6583c909e 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2d59c71cc3757..2f04abe8c7e88 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f7e586ee777e1..ba127408e1c59 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a3772a2620088..1527854730394 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 0a5bd958fc9c5..9258a856028a0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 914eb93622d51..3960a0de62530 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,8 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 300m) - tests_timeout = "250m" + # must be less than the timeout configured on Jenkins (currently 350m) + tests_timeout = "300m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. diff --git a/docs/_config.yml b/docs/_config.yml index dcc211204d766..095fadb93fe5d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. 
-SPARK_VERSION: 2.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.3.0 +SPARK_VERSION: 2.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.4.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.8" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 1791dbaad775e..868110b8e35ef 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 485b562dce990..431339d412194 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 71016bc645ca7..7cd1ec4c9c09a 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 12630840e79dc..f810aa80e8780 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 87a09642405a7..498e88f665eb5 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index d6f97316b326a..a742b8d6dbddb 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 0c9f0aa765a39..16bbc6db641ca 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 6eb7ba5f0092d..3b124b2a69d50 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 786349474389b..41bc8b3e3ee1f 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 849c8b465f99e..6d1c4789f382d 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 48783d65826aa..37c7d1e604ec5 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 40a751a652fa9..4915893965595 100644 --- 
a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 36d555066b181..027157e53d511 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index cb30e4a4af4bc..fbe77fcb958d5 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index aa36dd4774d86..8e424b1c50236 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index e9b46c4cf0ffa..912eb6b6d2a08 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 043d13609fd26..53286fe93478d 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a906c9e02cd4c..f07d7f24fd312 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 1b37164376460..d14594aa4ccb0 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b452f35c5ec1..32eb31f495979 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.4.x + lazy val v24excludes = v23excludes ++ Seq( + ) + // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( // [SPARK-22897] Expose stageAttemptId in TaskContext @@ -1082,6 +1086,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.4") => v24excludes case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 12dd53b9d2902..b9c2c4ced71d5 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2.3.0.dev0" +__version__ = "2.4.0.dev0" diff --git a/repl/pom.xml b/repl/pom.xml index 1cb0098d0eca3..6f4a863c48bc7 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 7d35aea8a4142..a62f271273465 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 70d0c1750b14e..3995d0afeb5f4 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 43a7ce95bd3de..37e25ceecb883 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9e2ced30407d4..839b929abd3cb 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 93010c606cf45..744daa6079779 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3135a8a275dae..9f247f9224c75 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 66fad85ea0263..c55ba32fa458c 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index fea882ad11230..4497e53b65984 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 37427e8da62d8..242219e29f50f 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml From 7bd14cfd40500a0b6462cda647bdbb686a430328 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 12 Jan 2018 10:18:42 -0800 Subject: [PATCH 632/712] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up the java-lint errors (for v2.3.0-rc1 tag). Hopefully, this will be the final one. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[85] (sizes) LineLength: Line is longer than 100 characters (found 101). [ERROR] src/main/java/org/apache/spark/launcher/InProcessAppHandle.java:[20,8] (imports) UnusedImports: Unused import - java.io.IOException. 
[ERROR] src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java:[41,9] (modifier) ModifierOrder: 'private' modifier out of order with the JLS suggestions. [ERROR] src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java:[464] (sizes) LineLength: Line is longer than 100 characters (found 102). ``` ## How was this patch tested? Manual. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` Author: Dongjoon Hyun Closes #20242 from dongjoon-hyun/fix_lint_java_2.3_rc1. --- .../org/apache/spark/unsafe/memory/HeapMemoryAllocator.java | 3 ++- .../java/org/apache/spark/launcher/InProcessAppHandle.java | 1 - .../spark/sql/execution/datasources/orc/OrcColumnVector.java | 2 +- .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 3acfe3696cb1e..a9603c1aba051 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -82,7 +82,8 @@ public void free(MemoryBlock memory) { "page has already been freed"; assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : - "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + + "free()"; final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 0d6a73a3da3ed..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.io.IOException; import java.lang.reflect.Method; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index f94c55d860304..b6e792274da11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -38,7 +38,7 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto private BytesColumnVector bytesData; private DecimalColumnVector decimalData; private TimestampColumnVector timestampData; - final private boolean isTimestamp; + private final boolean isTimestamp; private int batchSize; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4f8a31f185724..69a2904f5f3fe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -461,7 +461,8 @@ public void testCircularReferenceBean() { public void testUDF() { UserDefinedFunction foo = udf((Integer 
i, String s) -> i.toString() + s, DataTypes.StringType); Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); - String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)) + .toArray(String[]::new); String[] expected = spark.table("testData").collectAsList().stream() .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); Assert.assertArrayEquals(expected, result); From 54277398afbde92a38ba2802f4a7a3e5910533de Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 11:25:37 -0800 Subject: [PATCH 633/712] [SPARK-22975][SS] MetricsReporter should not throw exception when there was no progress reported ## What changes were proposed in this pull request? `MetricsReporter ` assumes that there has been some progress for the query, ie. `lastProgress` is not null. If this is not true, as it might happen in particular conditions, a `NullPointerException` can be thrown. The PR checks whether there is a `lastProgress` and if this is not true, it returns a default value for the metrics. ## How was this patch tested? added UT Author: Marco Gaido Closes #20189 from mgaido91/SPARK-22975. --- .../execution/streaming/MetricsReporter.scala | 21 ++++++++--------- .../sql/streaming/StreamingQuerySuite.scala | 23 +++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index b84e6ce64c611..66b11ecddf233 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} - -import scala.collection.mutable - import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} -import org.apache.spark.util.Clock +import org.apache.spark.sql.streaming.StreamingQueryProgress /** * Serves metrics from a [[org.apache.spark.sql.streaming.StreamingQuery]] to @@ -39,14 +35,17 @@ class MetricsReporter( // Metric names should not have . 
in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group - registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) - registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) - - private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0) + registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) + registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + + private def registerGauge[T]( + name: String, + f: StreamingQueryProgress => T, + default: T): Unit = { synchronized { metricRegistry.register(name, new Gauge[T] { - override def getValue: T = f() + override def getValue: T = Option(stream.lastProgress).map(f).getOrElse(default) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2fa4595dab376..76201c63a2701 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -424,6 +424,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22975: MetricsReporter defaults when there was no progress reported") { + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + BlockingSource.latch = new CountDownLatch(1) + withTempDir { tempDir => + val sq = spark.readStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .load() + .writeStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .option("checkpointLocation", tempDir.toString) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + + val gauges = sq.streamMetrics.metricRegistry.getGauges + assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + sq.stop() + } + } + } + test("input row calculation with mixed batch and streaming sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") From 55dbfbca37ce4c05f83180777ba3d4fe2d96a02e Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 12 Jan 2018 15:00:00 -0800 Subject: [PATCH 634/712] Revert "[SPARK-22908] Add kafka source and sink for continuous processing." This reverts commit 6f7aaed805070d29dcba32e04ca7a1f581fa54b9. 
--- .../sql/kafka010/KafkaContinuousReader.scala | 232 --------- .../sql/kafka010/KafkaContinuousWriter.scala | 119 ----- .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +--- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 +-- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ------------------ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ---- .../sql/kafka010/KafkaContinuousTest.scala | 64 --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 ++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 +-- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 383 insertions(+), 1531 deletions(-) delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala deleted file mode 100644 index 928379544758c..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.kafka010 - -import java.{util => ju} - -import org.apache.kafka.clients.consumer.ConsumerRecord -import org.apache.kafka.common.TopicPartition - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String - -/** - * A [[ContinuousReader]] for data from kafka. - * - * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be - * read by per-task consumers generated later. - * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which - * are not Kafka consumer params. - * @param metadataPath Path to a directory this reader can use for writing metadata. - * @param initialOffsets The Kafka offsets to start reading data at. - * @param failOnDataLoss Flag indicating whether reading should fail in data loss - * scenarios, where some offsets after the specified initial ones can't be - * properly read. - */ -class KafkaContinuousReader( - offsetReader: KafkaOffsetReader, - kafkaParams: ju.Map[String, Object], - sourceOptions: Map[String, String], - metadataPath: String, - initialOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext - - // Initialized when creating read tasks. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. 
- private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets - } - } - - override def getStartOffset(): Offset = offset - - override def deserializeOffset(json: String): Offset = { - KafkaSourceOffset(JsonUtils.partitionOffsets(json)) - } - - override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - - startOffsets.toSeq.map { - case (topicPartition, start) => - KafkaContinuousReadTask( - topicPartition, start, kafkaParams, failOnDataLoss) - .asInstanceOf[ReadTask[UnsafeRow]] - }.asJava - } - - /** Stop this source and free any resources it has allocated. */ - def stop(): Unit = synchronized { - offsetReader.close() - } - - override def commit(end: Offset): Unit = {} - - override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { - val mergedMap = offsets.map { - case KafkaSourcePartitionOffset(p, o) => Map(p -> o) - }.reduce(_ ++ _) - KafkaSourceOffset(mergedMap) - } - - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions - } - - override def toString(): String = s"KafkaSource[$offsetReader]" - - /** - * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. - * Otherwise, just log a warning. - */ - private def reportDataLoss(message: String): Unit = { - if (failOnDataLoss) { - throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") - } else { - logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") - } - } -} - -/** - * A read task for continuous Kafka processing. This will be serialized and transformed into a - * full reader on executors. - * - * @param topicPartition The (topic, partition) pair this task is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. 
- */ -case class KafkaContinuousReadTask( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) - } -} - -/** - * A per-task data reader for continuous Kafka processing. - * - * @param topicPartition The (topic, partition) pair this data reader is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. - */ -class KafkaContinuousDataReader( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) - - private var nextKafkaOffset = startOffset - private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ - - override def next(): Boolean = { - var r: ConsumerRecord[Array[Byte], Array[Byte]] = null - while (r == null) { - r = consumer.get( - nextKafkaOffset, - untilOffset = Long.MaxValue, - pollTimeoutMs = Long.MaxValue, - failOnDataLoss) - } - nextKafkaOffset = r.offset + 1 - currentRecord = r - true - } - - override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow - } - - override def getOffset(): KafkaSourcePartitionOffset = { - KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) - } - - override def close(): Unit = { - consumer.close() - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala deleted file mode 100644 index 9843f469c5b25..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} -import scala.collection.JavaConverters._ - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} -import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{BinaryType, StringType, StructType} - -/** - * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we - * don't need to really send one. - */ -case object KafkaWriterCommitMessage extends WriterCommitMessage - -/** - * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. - * @param topic The topic this writer is responsible for. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -class KafkaContinuousWriter( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends ContinuousWriter with SupportsWriteInternalRow { - - validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - - override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = - KafkaContinuousWriterFactory(topic, producerParams, schema) - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} -} - -/** - * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate - * the per-task data writers. - * @param topic The topic that should be written to. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -case class KafkaContinuousWriterFactory( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { - - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) - } -} - -/** - * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to - * process incoming rows. - * - * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred - * from a `topic` field in the incoming data. - * @param producerParams Parameters to use for the Kafka producer. - * @param inputSchema The attributes in the input data. 
- */ -class KafkaContinuousDataWriter( - targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) - extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { - import scala.collection.JavaConverters._ - - private lazy val producer = CachedKafkaProducer.getOrCreate( - new java.util.HashMap[String, Object](producerParams.asJava)) - - def write(row: InternalRow): Unit = { - checkForErrors() - sendRow(row, producer) - } - - def commit(): WriterCommitMessage = { - // Send is asynchronous, but we can't commit until all rows are actually in Kafka. - // This requires flushing and then checking that no callbacks produced errors. - // We also check for errors before to fail as soon as possible - the check is cheap. - checkForErrors() - producer.flush() - checkForErrors() - KafkaWriterCommitMessage - } - - def abort(): Unit = {} - - def close(): Unit = { - checkForErrors() - if (producer != null) { - producer.flush() - checkForErrors() - CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) - } - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..3e65949a6fd1b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,14 +117,10 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. - * - * @param partitionOffsets the specific offsets to resolve - * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long], - reportDataLoss: String => Unit): KafkaSourceOffset = { - val fetched = runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = + runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -149,19 +145,6 @@ private[kafka010] class KafkaOffsetReader( } } - partitionOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (fetched(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(fetched) - } - /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. 
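For reference, a minimal sketch of where the startingOffsets verification lands after this change: KafkaOffsetReader.fetchSpecificOffsets (hunk above) now only resolves -1/-2 to concrete positions and returns a plain Map[TopicPartition, Long], while KafkaSource (hunk below) compares the resolved offsets against the requested ones and reports data loss. The sketch assumes the members shown in the surrounding hunks (kafkaReader, reportDataLoss, KafkaOffsetRangeLimit); the explicit return type is added here for clarity only and is not part of the patch.

  // Sketch: verification moves from the offset reader into the source.
  private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]): KafkaSourceOffset = {
    // Resolve -1 (latest) / -2 (earliest) to concrete offsets via the reader.
    val resolved = kafkaReader.fetchSpecificOffsets(specificOffsets)
    specificOffsets.foreach {
      case (tp, off)
          if off != KafkaOffsetRangeLimit.LATEST && off != KafkaOffsetRangeLimit.EARLIEST =>
        // A mismatch means the consumer reset past the requested start offset.
        if (resolved(tp) != off) {
          reportDataLoss(
            s"startingOffsets for $tp was $off but consumer reset to ${resolved(tp)}")
        }
      case _ => // no real way to check that beginning or end is reasonable
    }
    KafkaSourceOffset(resolved)
  }
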
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 27da76068a66f..e9cff04ba5f2e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) + case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,6 +138,21 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } + private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { + val result = kafkaReader.fetchSpecificOffsets(specificOffsets) + specificOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (result(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(result) + } + private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..b5da415b3097e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,22 +20,17 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. 
*/ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } -private[kafka010] -case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) - extends PartitionOffset - /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3914370a96595..3cb4d8cad12cc 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, Optional, UUID} +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -27,12 +27,9 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,8 +43,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with ContinuousWriteSupport - with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -106,43 +101,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - override def createContinuousReader( - schema: Optional[StructType], - metadataPath: String, - options: DataSourceV2Options): KafkaContinuousReader = { - val parameters = options.asMap().asScala.toMap - validateStreamOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. 
- val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap - - val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, - STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) - - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - - new KafkaContinuousReader( - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - metadataPath, - startingStreamOffsets, - failOnDataLoss(caseInsensitiveParams)) - } - /** * Returns a new base relation with the given parameters. * @@ -223,22 +181,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { - import scala.collection.JavaConverters._ - - val spark = SparkSession.getActiveSession.get - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. - val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - KafkaWriter.validateQuery( - schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } - Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -488,27 +450,4 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } - - private[kafka010] def kafkaParamsForProducer( - parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } - - if 
(caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index baa60febf661d..6fd333e2f43ba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,8 +33,10 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { + topic: Option[String]) { // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -44,7 +46,23 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - sendRow(currentRow, producer) + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) } } @@ -56,49 +74,8 @@ private[kafka010] class KafkaWriteTask( producer = null } } -} - -private[kafka010] abstract class KafkaRowWriter( - inputSchema: Seq[Attribute], topic: Option[String]) { - - // used to synchronize with Kafka callbacks - @volatile protected var failedWrite: Exception = _ - protected val projection = createProjection - - private val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - /** - * Send the specified row to the producer, with a callback that will save any exception - * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before - * assuming the row is in Kafka. 
- */ - protected def sendRow( - row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { - val projectedRow = projection(row) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - producer.send(record, callback) - } - - protected def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } - - private def createProjection = { + private def createProjection: UnsafeProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -135,5 +112,11 @@ private[kafka010] abstract class KafkaRowWriter( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } + + private def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..5e9ae35b3f008 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,9 +43,10 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - schema: Seq[Attribute], + queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { + val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -83,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(schema, kafkaParameters, topic) + validateQuery(queryExecution, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala deleted file mode 100644 index dfc97b1c38bb5..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ /dev/null @@ -1,474 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Locale -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.kafka.clients.producer.ProducerConfig -import org.apache.kafka.common.serialization.ByteArraySerializer -import org.scalatest.time.SpanSugar._ -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{BinaryType, DataType} -import org.apache.spark.util.Utils - -/** - * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. - * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have - * to duplicate all the code. - */ -class KafkaContinuousSinkSuite extends KafkaContinuousTest { - import testImplicits._ - - override val streamingTimeout = 30.seconds - - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils( - withBrokerProps = Map("auto.create.topics.enable" -> "false")) - testUtils.setup() - } - - override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null - } - super.afterAll() - } - - test("streaming - write to kafka with topic field") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = None, - withOutputMode = Some(OutputMode.Append))( - withSelectExpr = s"'$topic' as topic", "value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - write w/o topic field, with topic option") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = 
Some(OutputMode.Append()))() - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - topic field and topic option") { - /* The purpose of this test is to ensure that the topic option - * overrides the topic field. We begin by writing some data that - * includes a topic field and value (e.g., 'foo') along with a topic - * option. Then when we read from the topic specified in the option - * we should see the data i.e., the data was written to the topic - * option, and not to the topic in the data e.g., foo - */ - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = Some(OutputMode.Append()))( - withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("null topic attribute") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "CAST(null as STRING) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getCause.getCause.getMessage - .toLowerCase(Locale.ROOT) - .contains("null topic present in the data.")) - } - - test("streaming - write data with bad schema") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - 
testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "value as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage - .toLowerCase(Locale.ROOT) - .contains("topic option required when no 'topic' attribute is present")) - - try { - /* No value field */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value as key" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "required attribute 'value' not found")) - } - - test("streaming - write data with valid schema but wrong types") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - .selectExpr("CAST(value as STRING) value") - val topic = newTopic() - testUtils.createTopic(topic) - - var writer: StreamingQuery = null - var ex: Exception = null - try { - /* topic field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"CAST('1' as INT) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - - try { - /* value field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) - - try { - ex = intercept[StreamingQueryException] { - /* key field wrong type */ - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) - } - - test("streaming - write to non-existing topic") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - } - throw 
writer.exception.get - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) - } - - test("streaming - exception on config serializer") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - testUtils.sendMessages(inputTopic, Array("0")) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .load() - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.key.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - } finally { - writer.stop() - } - - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.value.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) - } finally { - writer.stop() - } - } - - test("generic - write big data with small producer buffer") { - /* This test ensures that we understand the semantics of Kafka when - * is comes to blocking on a call to send when the send buffer is full. - * This test will configure the smallest possible producer buffer and - * indicate that we should block when it is full. Thus, no exception should - * be thrown in the case of a full buffer. - */ - val topic = newTopic() - testUtils.createTopic(topic, 1) - val options = new java.util.HashMap[String, String] - options.put("bootstrap.servers", testUtils.brokerAddress) - options.put("buffer.memory", "16384") // min buffer size - options.put("block.on.buffer.full", "true") - options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - val inputSchema = Seq(AttributeReference("value", BinaryType)()) - val data = new Array[Byte](15000) // large value - val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) - try { - val fieldTypes: Array[DataType] = Array(BinaryType) - val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) - row.update(0, data) - val iter = Seq.fill(1000)(converter.apply(row)).iterator - iter.foreach(writeTask.write(_)) - writeTask.commit() - } finally { - writeTask.close() - } - } - - private def createKafkaReader(topic: String): DataFrame = { - spark.read - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("startingOffsets", "earliest") - .option("endingOffsets", "latest") - .option("subscribe", topic) - .load() - } - - private def createKafkaWriter( - input: DataFrame, - withTopic: Option[String] = None, - withOutputMode: Option[OutputMode] = None, - withOptions: Map[String, String] = Map[String, String]()) - (withSelectExpr: String*): StreamingQuery = { - var stream: DataStreamWriter[Row] = null - val checkpointDir = Utils.createTempDir() - var df = input.toDF() - if (withSelectExpr.length > 0) { - df = df.selectExpr(withSelectExpr: _*) - } - stream = df.writeStream - .format("kafka") - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - // 
We need to reduce blocking time to efficiently test non-existent partition behavior. - .option("kafka.max.block.ms", "1000") - .trigger(Trigger.Continuous(1000)) - .queryName("kafkaStream") - withTopic.foreach(stream.option("topic", _)) - withOutputMode.foreach(stream.outputMode(_)) - withOptions.foreach(opt => stream.option(opt._1, opt._2)) - stream.start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala deleted file mode 100644 index b3dade414f625..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} - -// Run tests in KafkaSourceSuiteBase in continuous execution mode. 
-class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest - -class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { - import testImplicits._ - - override val brokerProps = Map("auto.create.topics.enable" -> "false") - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Execute { query => - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists { r => - // Ensure the new topic is present and the old topic is gone. - r.knownPartitions.exists(_.topic == topic2) - }, - s"query never reconfigured to new topic $topic2") - } - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } -} - -class KafkaContinuousSourceStressForDontFailOnDataLossSuite - extends KafkaSourceStressForDontFailOnDataLossSuite { - override protected def startStream(ds: Dataset[Int]) = { - ds.writeStream - .format("memory") - .queryName("memory") - .start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala deleted file mode 100644 index e713e6695d2bd..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.kafka010 - -import org.apache.spark.SparkContext -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.sql.test.TestSparkSession - -// Trait to configure StreamTest for kafka continuous execution tests. -trait KafkaContinuousTest extends KafkaSourceTest { - override val defaultTrigger = Trigger.Continuous(1000) - override val defaultUseV2Sink = true - - // We need more than the default local[2] to be able to schedule all partitions simultaneously. - override protected def createSparkSession = new TestSparkSession( - new SparkContext( - "local[10]", - "continuous-stream-test-sql-context", - sparkConf.set("spark.sql.testkey", "true"))) - - // In addition to setting the partitions in Kafka, we have to wait until the query has - // reconfigured to the new count so the test framework can hook in properly. - override protected def setTopicPartitions( - topic: String, newCount: Int, query: StreamExecution) = { - testUtils.addPartitions(topic, newCount) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists(_.knownPartitions.size == newCount), - s"query never reconfigured to $newCount partitions") - } - } - - test("ensure continuous stream is being used") { - val query = spark.readStream - .format("rate") - .option("numPartitions", "1") - .option("rowsPerSecond", "1") - .load() - - testStream(query)( - Execute(q => assert(q.isInstanceOf[ContinuousExecution])) - ) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index d66908f86ccc7..2034b9be07f24 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,14 +34,11 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} +import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -52,11 +49,9 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds - protected val brokerProps = Map[String, Object]() - override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils(brokerProps) + testUtils = new KafkaTestUtils 
testUtils.setup() } @@ -64,25 +59,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null + super.afterAll() } - super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, + // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, // we don't know which data should be fetched when `startingOffsets` is latest. - q match { - case c: ContinuousExecution => c.awaitEpoch(0) - case m: MicroBatchExecution => m.processAllAvailable() - } + q.processAllAvailable() true } - protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { - testUtils.addPartitions(topic, newCount) - } - /** * Add data to Kafka. * @@ -94,7 +82,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -109,18 +97,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } + // Read all topics again in case some topics are delete. + val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => + source.asInstanceOf[KafkaSource] + } if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -151,158 +137,14 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } - - private val topicId = new AtomicInteger(0) - protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { - - import testImplicits._ - - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) - - val reader = spark - .readStream - .format("kafka") - 
.option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) - .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } - - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) - } - - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") +class KafkaSourceSuite extends KafkaSourceTest { - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) + import testImplicits._ - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - 
testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } + private val topicId = new AtomicInteger(0) testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -395,51 +237,86 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { } } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() + test("(de)serialization of initial offsets") { val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) + testUtils.createTopic(topic, partitions = 64) - val kafka = spark + val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) } -} -class KafkaSourceSuiteBase extends KafkaSourceTest { + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) - import testImplicits._ + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), 
+ StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } test("cannot stop Kafka stream") { val topic = newTopic() @@ -451,7 +328,7 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topic.*") + .option("subscribePattern", s"topic-.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -545,6 +422,65 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("starting offset is latest by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("0")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + val kafka = reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + val mapped = kafka.map(_.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(1, 2, 3) // should not have 0 + ) + } + test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -604,6 +540,34 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, 
String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + test("delete a topic when a Spark job is running") { KafkaSourceSuite.collectedData.clear() @@ -665,6 +629,8 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -710,10 +676,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, - Execute { q => - // wait to reach the last offset in every partition - q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) - }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -744,7 +706,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") - .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -762,6 +723,47 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { query.stop() } + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. 
+ assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() + } + private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -798,7 +800,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -839,7 +843,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -971,23 +977,6 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - protected def startStream(ds: Dataset[Int]) = { - ds.writeStream.foreach(new ForeachWriter[Int] { - - override def open(partitionId: Long, version: Long): Boolean = { - true - } - - override def process(value: Int): Unit = { - // Slow down the processing speed so that messages may be aged out. - Thread.sleep(Random.nextInt(500)) - } - - override def close(errorOrNull: Throwable): Unit = { - } - }).start() - } - test("stress test for failOnDataLoss=false") { val reader = spark .readStream @@ -1001,7 +990,20 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val query = startStream(kafka.map(kv => kv._2.toInt)) + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = { + true + } + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. + Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = { + } + }).start() val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b714a46b5f786..e8d683a578f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,9 +191,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. 
val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -211,30 +208,23 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => null // fall back to v1 + case _ => + throw new AnalysisException(s"$cls does not support data reading.") } - if (reader == null) { - loadV1Source(paths: _*) - } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) - } + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } else { - loadV1Source(paths: _*) + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) } } - private def loadV1Source(paths: String*) = { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) - } - /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 97f12ff625c42..3304f368e1050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,24 +255,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. - case _ => saveToV1Source() + case _ => throw new AnalysisException(s"$cls does not support data writing.") } } else { - saveToV1Source() - } - } - - private def saveToV1Source(): Unit = { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + // Code path for data source v1. 
+ runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..f0bdf84bb7a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,11 +81,9 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - if (!writer.isInstanceOf[ContinuousWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -419,17 +418,11 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - if (sources == null) { - // sources might not be initialized yet - false - } else { - val source = sources(sourceIndex) - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset - } + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset } while (notDone) { @@ -443,7 +436,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") + logDebug(s"Unblocked at $newOffset for $source") } /** A flag to indicate that a batch has completed with no new data available. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index e700aa4f9aea7..d79e4bd65f563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,6 +77,7 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { + reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -200,8 +201,6 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. - } finally { - reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.UUID import java.util.concurrent.TimeUnit -import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -54,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -80,17 +78,15 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - val stateUpdate = new UnaryOperator[State] { - override def apply(s: State) = s match { - // If we ended the query to reconfigure, reset the state to active. - case RECONFIGURING => ACTIVE - case _ => s - } - } - do { - runContinuous(sparkSessionForStream) - } while (state.updateAndGet(stateUpdate) == ACTIVE) + try { + runContinuous(sparkSessionForStream) + } catch { + case _: InterruptedException if state.get().equals(RECONFIGURING) => + // swallow exception and run again + state.set(ACTIVE) + } + } while (state.get() == ACTIVE) } /** @@ -124,16 +120,12 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + // Forcibly align commit and offset logs by slicing off any spurious offset logs from + // a previous run. 
We can't allow commits to an epoch that a previous run reached but + // this run has not. + offsetLog.purgeAfter(latestEpochId) + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -149,7 +141,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -234,11 +225,13 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + if (reader.needsReconfiguration()) { + state.set(RECONFIGURING) stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() + // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -266,7 +259,6 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { - epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -281,22 +273,17 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - val oldOffset = synchronized { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - offsetLog.get(epoch - 1) + if (partitionOffsets.contains(null)) { + // If any offset is null, that means the corresponding partition hasn't seen any data yet, so + // there's nothing meaningful to add to the offset log. } - - // If offset hasn't changed since last epoch, there's been no new data. - if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { - noNewData = true - } - - awaitProgressLock.lock() - try { - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + synchronized { + if (queryExecutionThread.isAlive) { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + } else { + return + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..98017c3ac6a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,15 +39,6 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage -/** - * The RpcEndpoint stop() will wait to clear out the message queue before terminating the - * object. This can lead to a race condition where the query restarts at epoch n, a new - * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. 
- * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous - * message to stop any writes to the ContinuousExecution object. - */ -private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage - // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -125,8 +116,6 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private var queryWritesStopped: Boolean = false - private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -158,16 +147,12 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + partitionCommits.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { - // If we just drop these messages, we won't do any writes to the query. The lame duck tasks - // won't shed errors or anything. - case _ if queryWritesStopped => () - case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -203,9 +188,5 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) - - case StopContinuousExecutionWrites => - queryWritesStopped = true - context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..db588ae282f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -280,29 +279,18 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) - } - + val dataSource = + DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - sink, + dataSource.createSink(outputMode), outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,8 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -81,9 +80,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } - protected val defaultTrigger = Trigger.ProcessingTime(0) - protected val defaultUseV2Sink = false - /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -193,7 +189,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. 
*/ case class StartStream( - trigger: Trigger = defaultTrigger, + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -280,7 +276,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -407,11 +403,18 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") + // Get the map of source index to the current source objects + val indexToSource = currentStream + .logicalPlan + .collect { case StreamingExecutionRelation(s, _) => s } + .zipWithIndex + .map(_.swap) + .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(sourceIndex, offset) + currentStream.awaitOffset(indexToSource(sourceIndex), offset) } } @@ -470,12 +473,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } - case _ => - } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -603,10 +600,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { - case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r - } + .collect { case StreamingExecutionRelation(s, _) => s } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -619,13 +613,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } }.getOrElse { throw new IllegalArgumentException( - "Could not find index of the source to which data was added") + "Could find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From cd9f49a2aed3799964976ead06080a0f7044a0c3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 13 Jan 2018 16:13:44 +0900 Subject: [PATCH 635/712] [SPARK-22980][PYTHON][SQL] Clarify the length of each series is of each batch within scalar Pandas UDF ## What changes were proposed in this pull request? This PR proposes to add a note that saying the length of a scalar Pandas UDF's `Series` is not of the whole input column but of the batch. We are fine for a group map UDF because the usage is different from our typical UDF but scalar UDFs might cause confusion with the normal UDF. 
For example, please consider this example: ```python from pyspark.sql.functions import pandas_udf, col, lit df = spark.range(1) f = pandas_udf(lambda x, y: len(x) + y, LongType()) df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 1| +------------------+ ``` ```python from pyspark.sql.functions import udf, col, lit df = spark.range(1) f = udf(lambda x, y: len(x) + y, "long") df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 4| +------------------+ ``` ## How was this patch tested? Manually built the doc and checked the output. Author: hyukjinkwon Closes #20237 from HyukjinKwon/SPARK-22980. --- python/pyspark/sql/functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 733e32bd825b0..e1ad6590554cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2184,6 +2184,11 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 8| JOHN DOE| 22| +----------+--------------+------------+ + .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input + column, but is the length of an internal batch used for each call to the function. + Therefore, this can be used, for example, to ensure the length of each returned + `pandas.Series`, and can not be used as the column length. + 2. GROUP_MAP A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` From 628a1ca5a4d14397a90e9e96a7e03e8f63531b20 Mon Sep 17 00:00:00 2001 From: shimamoto Date: Sat, 13 Jan 2018 09:40:00 -0600 Subject: [PATCH 636/712] [SPARK-23043][BUILD] Upgrade json4s to 3.5.3 ## What changes were proposed in this pull request? Spark still use a few years old version 3.2.11. This change is to upgrade json4s to 3.5.3. Note that this change does not include the Jackson update because the Jackson version referenced in json4s 3.5.3 is 2.8.4, which has a security vulnerability ([see](https://issues.apache.org/jira/browse/SPARK-20433)). ## How was this patch tested? Existing unit tests and build. Author: shimamoto Closes #20233 from shimamoto/upgrade-json4s. 
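For context on the test changes that follow: with json4s 3.5.x, applying `\` directly to a `JArray` no longer yields a value that the 3.2.11-era assertions could extract directly, so the suites below now pick an array element first (e.g. `.children.head`) or search recursively with `\\`. A minimal sketch of the post-upgrade lookup style (illustrative only; assumes json4s 3.5.x with the jackson backend on the classpath):

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods._

implicit val formats: Formats = DefaultFormats

// A payload shaped like the REST API responses checked in the suites below.
val apps = parse("""[{"id": "app-1", "attempts": [{"completed": false}]}]""")

// Select the array element first, then navigate into it.
val id = (apps.children.head \ "id").extract[String]    // "app-1"

// Or search every level for the field name.
val completed = (apps \\ "completed").extract[Boolean]  // false
```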
--- .../deploy/history/HistoryServerSuite.scala | 2 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 19 ++++++++++--------- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- pom.xml | 13 +++++++------ 5 files changed, 26 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3738f85da5831..87778dda0e2c8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -486,7 +486,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers json match { case JNothing => Seq() case apps: JArray => - apps.filter(app => { + apps.children.filter(app => { (app \ "attempts") match { case attempts: JArray => val state = (attempts.children.head \ "completed").asInstanceOf[JBool] diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 326546787ab6c..ed51fc445fdfb 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -131,7 +131,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val storageJson = getJson(ui, "storage/rdd") storageJson.children.length should be (1) - (storageJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) + (storageJson.children.head \ "storageLevel").extract[String] should be ( + StorageLevels.DISK_ONLY.description) val rddJson = getJson(ui, "storage/rdd/0") (rddJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) @@ -150,7 +151,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedStorageJson = getJson(ui, "storage/rdd") updatedStorageJson.children.length should be (1) - (updatedStorageJson \ "storageLevel").extract[String] should be ( + (updatedStorageJson.children.head \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( @@ -204,7 +205,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } val stageJson = getJson(sc.ui.get, "stages") stageJson.children.length should be (1) - (stageJson \ "status").extract[String] should be (StageStatus.FAILED.name()) + (stageJson.children.head \ "status").extract[String] should be (StageStatus.FAILED.name()) // Regression test for SPARK-2105 class NotSerializable @@ -325,11 +326,11 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") } val jobJson = getJson(sc.ui.get, "jobs") - (jobJson \ "numTasks").extract[Int]should be (2) - (jobJson \ "numCompletedTasks").extract[Int] should be (3) - (jobJson \ "numFailedTasks").extract[Int] should be (1) - (jobJson \ "numCompletedStages").extract[Int] should be (2) - (jobJson \ "numFailedStages").extract[Int] should be (1) + (jobJson \\ "numTasks").extract[Int]should be (2) + (jobJson \\ "numCompletedTasks").extract[Int] should be (3) + (jobJson \\ "numFailedTasks").extract[Int] should be (1) + (jobJson \\ "numCompletedStages").extract[Int] should be (2) + (jobJson \\ 
"numFailedStages").extract[Int] should be (1) val stageJson = getJson(sc.ui.get, "stages") for { @@ -656,7 +657,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) - val attempts = (appListJsonAst \ "attempts").children + val attempts = (appListJsonAst.children.head \ "attempts").children attempts.size should be (1) (attempts(0) \ "completed").extract[Boolean] should be (false) parseDate(attempts(0) \ "startTime") should be (sc.startTime) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index a7fce2ede0ea5..2a298769be44c 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -167,7 +168,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 94b2e98d85e74..abee326f283ab 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -168,7 +169,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/pom.xml b/pom.xml index d14594aa4ccb0..666d5d7169a15 100644 --- a/pom.xml +++ b/pom.xml @@ -705,7 +705,13 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.11 + 3.5.3 + + + com.fasterxml.jackson.core + * + + org.scala-lang @@ -732,11 +738,6 @@ scala-parser-combinators_${scala.binary.version} 1.0.4 - - org.scala-lang - scalap - ${scala.version} - jline From fc6fe8a1d0f161c4788f3db94de49a8669ba3bcc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 13 Jan 2018 10:01:44 -0600 Subject: [PATCH 637/712] [SPARK-22870][CORE] Dynamic allocation should allow 0 idle time ## What changes were proposed in this pull request? This pr to make `0` as a valid value for `spark.dynamicAllocation.executorIdleTimeout`. For details, see the jira description: https://issues.apache.org/jira/browse/SPARK-22870. ## How was this patch tested? N/A Author: Yuming Wang Author: Yuming Wang Closes #20080 from wangyum/SPARK-22870. 
--- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 2e00dc8b49dd5..6c59038f2a6c1 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -195,8 +195,11 @@ private[spark] class ExecutorAllocationManager( throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeoutS <= 0) { - throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + if (executorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be >= 0!") + } + if (cachedExecutorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.cachedExecutorIdleTimeout must be >= 0!") } // Require external shuffle service for dynamic allocation // Otherwise, we may lose shuffle files when killing executors From bd4a21b4820c4ebaf750131574a6b2eeea36907e Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 14 Jan 2018 02:28:57 +0800 Subject: [PATCH 638/712] [SPARK-23036][SQL][TEST] Add withGlobalTempView for testing ## What changes were proposed in this pull request? Add withGlobalTempView when create global temp view, like withTempView and withView. And correct some improper usage. Please see jira. There are other similar place like that. I will fix it if community need. Please confirm it. ## How was this patch tested? no new test. Author: xubo245 <601450868@qq.com> Closes #20228 from xubo245/DropTempView. --- .../sql/execution/GlobalTempViewSuite.scala | 55 ++++++++----------- .../spark/sql/execution/SQLViewSuite.scala | 34 +++++++----- .../apache/spark/sql/test/SQLTestUtils.scala | 21 +++++-- 3 files changed, 59 insertions(+), 51 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index cc943e0356f2a..dcc6fa6403f31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -36,7 +36,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("basic semantic") { val expectedErrorMsg = "not found" - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") // If there is no database in table name, we should try local temp view first, if not found, @@ -79,19 +79,15 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // We can also use Dataset API to replace global temp view Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) - } finally { - spark.catalog.dropGlobalTempView("src") } } test("global temp view is shared among all sessions") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, 2)) val newSession = spark.newSession() checkAnswer(newSession.table(s"$globalTempDB.src"), Row(1, 2)) - } finally { - spark.catalog.dropGlobalTempView("src") } } @@ -105,27 +101,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE GLOBAL TEMP VIEW 
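The new helper mirrors `withTempView`: it runs the test body and then drops the named global temporary views, even if an assertion fails partway through. A usage sketch (assumes a suite such as `QueryTest with SharedSQLContext` mixing in `SQLTestUtils`, so `spark`, `sql` and `checkAnswer` are in scope; the global temp database is `global_temp` by default):

```scala
import org.apache.spark.sql.Row

test("global temp view is dropped even if the body fails") {
  withGlobalTempView("src") {
    sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS id, 'a' AS name")
    checkAnswer(spark.table("global_temp.src"), Row(1, "a"))
  }
  // After the block, global_temp.src no longer exists.
}
```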
USING") { withTempPath { path => - try { + withGlobalTempView("src") { Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) sql(s"CREATE GLOBAL TEMP VIEW src USING parquet OPTIONS (PATH '${path.toURI}')") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) sql(s"INSERT INTO $globalTempDB.src SELECT 2, 'b'") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a") :: Row(2, "b") :: Nil) - } finally { - spark.catalog.dropGlobalTempView("src") } } } test("CREATE TABLE LIKE should work for global temp view") { - try { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") - val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) - assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) - } finally { - spark.catalog.dropGlobalTempView("src") - sql("DROP TABLE default.cloned") + withTable("cloned") { + withGlobalTempView("src") { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") + val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) + assert(tableMeta.schema == new StructType() + .add("a", "int", false).add("b", "string", false)) + } } } @@ -146,26 +140,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { } test("should lookup global temp view if and only if global temp db is specified") { - try { - sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") - sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") + withTempView("same_name") { + withGlobalTempView("same_name") { + sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") + sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") - checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) + checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) - // we never lookup global temp views if database is not specified in table name - spark.catalog.dropTempView("same_name") - intercept[AnalysisException](sql("SELECT * FROM same_name")) + // we never lookup global temp views if database is not specified in table name + spark.catalog.dropTempView("same_name") + intercept[AnalysisException](sql("SELECT * FROM same_name")) - // Use qualified name to lookup a global temp view. - checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) - } finally { - spark.catalog.dropTempView("same_name") - spark.catalog.dropGlobalTempView("same_name") + // Use qualified name to lookup a global temp view. 
+ checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) + } } } test("public Catalog should recognize global temp view") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") assert(spark.catalog.tableExists(globalTempDB, "src")) @@ -175,8 +168,6 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { description = null, tableType = "TEMPORARY", isTemporary = true).toString) - } finally { - spark.catalog.dropGlobalTempView("src") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 08a4a21b20f61..8c55758cfe38d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -69,21 +69,25 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a permanent view on a temp view") { - withView("jtv1", "temp_jtv1", "global_temp_jtv1") { - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") - var e = intercept[AnalysisException] { - sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `jtv1` by " + - "referencing a temporary view `temp_jtv1`")) - - val globalTempDB = spark.sharedState.globalTempViewManager.database - sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") - e = intercept[AnalysisException] { - sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + - s"a temporary view `global_temp`.`global_temp_jtv1`")) + withView("jtv1") { + withTempView("temp_jtv1") { + withGlobalTempView("global_temp_jtv1") { + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") + var e = intercept[AnalysisException] { + sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `jtv1` by " + + "referencing a temporary view `temp_jtv1`")) + + val globalTempDB = spark.sharedState.globalTempViewManager.database + sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") + e = intercept[AnalysisException] { + sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + + s"a temporary view `global_temp`.`global_temp_jtv1`")) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 904f9f2ad0b22..bc4a120f7042f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -254,13 +254,26 @@ private[sql] trait SQLTestUtilsBase } /** - * Drops temporary table `tableName` after calling `f`. + * Drops temporary view `viewNames` after calling `f`. */ - protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + protected def withTempView(viewNames: String*)(f: => Unit): Unit = { try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. 
- try tableNames.foreach(spark.catalog.dropTempView) catch { + // temp views that never got created. + try viewNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops global temporary view `viewNames` after calling `f`. + */ + protected def withGlobalTempView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // global temp views that never got created. + try viewNames.foreach(spark.catalog.dropGlobalTempView) catch { case _: NoSuchTableException => } } From ba891ec993c616dc4249fc786c56ea82ed04a827 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 14 Jan 2018 02:36:32 +0800 Subject: [PATCH 639/712] [SPARK-22790][SQL] add a configurable factor to describe HadoopFsRelation's size ## What changes were proposed in this pull request? as per discussion in https://github.com/apache/spark/pull/19864#discussion_r156847927 the current HadoopFsRelation is purely based on the underlying file size which is not accurate and makes the execution vulnerable to errors like OOM Users can enable CBO with the functionalities in https://github.com/apache/spark/pull/19864 to avoid this issue This JIRA proposes to add a configurable factor to sizeInBytes method in HadoopFsRelation class so that users can mitigate this problem without CBO ## How was this patch tested? Existing tests Author: CodingCat Author: Nan Zhu Closes #20072 from CodingCat/SPARK-22790. --- .../apache/spark/sql/internal/SQLConf.scala | 13 +++++- .../datasources/HadoopFsRelation.scala | 6 ++- .../datasources/HadoopFsRelationSuite.scala | 41 +++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 36e802a9faa6f..6746fbcaf2483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -249,7 +249,7 @@ object SQLConf { val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") .internal() .doc("When true, the query optimizer will infer and propagate data constraints in the query " + - "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "plan to optimize them. 
Constraint propagation can sometimes be computationally expensive " + "for certain kinds of query plans (such as those with a large number of predicates and " + "aliases) which might negatively impact overall runtime.") .booleanConf @@ -263,6 +263,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val FILE_COMRESSION_FACTOR = buildConf("spark.sql.sources.fileCompressionFactor") + .internal() + .doc("When estimating the output data size of a table scan, multiply the file size with this " + + "factor as the estimated data size, in case the data is compressed in the file and lead to" + + " a heavily underestimated result.") + .doubleConf + .checkValue(_ > 0, "the value of fileDataSizeFactor must be larger than 0") + .createWithDefault(1.0) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -1255,6 +1264,8 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) + def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 89d8a85a9cbd2..6b34638529770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -82,7 +82,11 @@ case class HadoopFsRelation( } } - override def sizeInBytes: Long = location.sizeInBytes + override def sizeInBytes: Long = { + val compressionFactor = sqlContext.conf.fileCompressionFactor + (location.sizeInBytes * compressionFactor).toLong + } + override def inputFiles: Array[String] = location.inputFiles } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index caf03885e3873..c1f2c18d1417d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources import java.io.{File, FilenameFilter} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.test.SharedSQLContext class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { @@ -39,4 +40,44 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } + + test("SPARK-22790: spark.sql.sources.compressionFactor takes effect") { + import testImplicits._ + Seq(1.0, 0.5).foreach { compressionFactor => + withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, + "spark.sql.autoBroadcastJoinThreshold" -> "400") { + withTempPath { workDir => + // the file size is 740 bytes + val workDirPath = workDir.getAbsolutePath + val data1 = Seq(100, 200, 300, 400).toDF("count") + data1.write.parquet(workDirPath + "/data1") + val 
df1FromFile = spark.read.parquet(workDirPath + "/data1") + val data2 = Seq(100, 200, 300, 400).toDF("count") + data2.write.parquet(workDirPath + "/data2") + val df2FromFile = spark.read.parquet(workDirPath + "/data2") + val joinedDF = df1FromFile.join(df2FromFile, Seq("count")) + if (compressionFactor == 0.5) { + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.nonEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.isEmpty) + } else { + // compressionFactor is 1.0 + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.isEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.nonEmpty) + } + } + } + } + } } From 0066d6f6fa604817468471832968d4339f71c5cb Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 05:39:38 +0800 Subject: [PATCH 640/712] [SPARK-21213][SQL][FOLLOWUP] Use compatible types for comparisons in compareAndGetNewStats ## What changes were proposed in this pull request? This pr fixed code to compare values in `compareAndGetNewStats`. The test below fails in the current master; ``` val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) val newStats5 = CommandUtils.compareAndGetNewStats( Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) assert(newStats5.isEmpty) ``` ## How was this patch tested? Added some tests in `CommandUtilsSuite`. Author: Takeshi Yamamuro Closes #20245 from maropu/SPARK-21213-FOLLOWUP. --- .../sql/execution/command/CommandUtils.scala | 4 +- .../execution/command/CommandUtilsSuite.scala | 56 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 1a0d67fc71fbc..c27048626c8eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -116,8 +116,8 @@ object CommandUtils extends Logging { oldStats: Option[CatalogStatistics], newTotalSize: BigInt, newRowCount: Option[BigInt]): Option[CatalogStatistics] = { - val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L) - val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + val oldTotalSize = oldStats.map(_.sizeInBytes).getOrElse(BigInt(-1)) + val oldRowCount = oldStats.flatMap(_.rowCount).getOrElse(BigInt(-1)) var newStats: Option[CatalogStatistics] = None if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala new file mode 100644 index 0000000000000..f3e15189a6418 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics + +class CommandUtilsSuite extends SparkFunSuite { + + test("Check if compareAndGetNewStats returns correct results") { + val oldStats1 = CatalogStatistics(sizeInBytes = 10, rowCount = Some(100)) + val newStats1 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 10, newRowCount = Some(100)) + assert(newStats1.isEmpty) + val newStats2 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = None) + assert(newStats2.isEmpty) + val newStats3 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 20, newRowCount = Some(-1)) + assert(newStats3.isDefined) + newStats3.foreach { stat => + assert(stat.sizeInBytes === 20) + assert(stat.rowCount.isEmpty) + } + val newStats4 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = Some(200)) + assert(newStats4.isDefined) + newStats4.foreach { stat => + assert(stat.sizeInBytes === 10) + assert(stat.rowCount.isDefined && stat.rowCount.get === 200) + } + } + + test("Check if compareAndGetNewStats can handle large values") { + // Tests for large values + val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) + val newStats5 = CommandUtils.compareAndGetNewStats( + Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) + assert(newStats5.isEmpty) + } +} From afae8f2bc82597593595af68d1aa2d802210ea8b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 14 Jan 2018 11:26:49 +0900 Subject: [PATCH 641/712] [SPARK-22959][PYTHON] Configuration to select the modules for daemon and worker in PySpark ## What changes were proposed in this pull request? We are now forced to use `pyspark/daemon.py` and `pyspark/worker.py` in PySpark. This doesn't allow a custom modification for it (well, maybe we can still do this in a super hacky way though, for example, setting Python executable that has the custom modification). Because of this, for example, it's sometimes hard to debug what happens inside Python worker processes. This is actually related with [SPARK-7721](https://issues.apache.org/jira/browse/SPARK-7721) too as somehow Coverage is unable to detect the coverage from `os.fork`. If we have some custom fixes to force the coverage, it works fine. This is also related with [SPARK-20368](https://issues.apache.org/jira/browse/SPARK-20368). This JIRA describes Sentry support which (roughly) needs some changes within worker side. With this configuration advanced users will be able to do a lot of pluggable workarounds and we can meet such potential needs in the future. 
As an example, let's say if I configure the module `coverage_daemon` and had `coverage_daemon.py` in the python path: ```python import os from pyspark import daemon if "COVERAGE_PROCESS_START" in os.environ: from pyspark.worker import main def _cov_wrapped(*args, **kwargs): import coverage cov = coverage.coverage( config_file=os.environ["COVERAGE_PROCESS_START"]) cov.start() try: main(*args, **kwargs) finally: cov.stop() cov.save() daemon.worker_main = _cov_wrapped if __name__ == '__main__': daemon.manager() ``` I can track the coverages in worker side too. More importantly, we can leave the main code intact but allow some workarounds. ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20151 from HyukjinKwon/configuration-daemon-worker. --- .../api/python/PythonWorkerFactory.scala | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index f53c6178047f5..30976ac752a8a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -34,10 +34,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String import PythonWorkerFactory._ - // Because forking processes from Java is expensive, we prefer to launch a single Python daemon - // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently - // only works on UNIX-based systems now because it uses signals for child management, so we can - // also fall back to launching workers (pyspark/worker.py) directly. + // Because forking processes from Java is expensive, we prefer to launch a single Python daemon, + // pyspark/daemon.py (by default) and tell it to fork new workers for our tasks. This daemon + // currently only works on UNIX-based systems now because it uses signals for child management, + // so we can also fall back to launching workers, pyspark/worker.py (by default) directly. val useDaemon = { val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) @@ -45,6 +45,28 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled } + // WARN: Both configurations, 'spark.python.daemon.module' and 'spark.python.worker.module' are + // for very advanced users and they are experimental. This should be considered + // as expert-only option, and shouldn't be used before knowing what it means exactly. + + // This configuration indicates the module to run the daemon to execute its Python workers. + val daemonModule = SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value => + logInfo( + s"Python daemon module in PySpark is set to [$value] in 'spark.python.daemon.module', " + + "using this to start the daemon up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is enabled and the platform is not Windows.") + value + }.getOrElse("pyspark.daemon") + + // This configuration indicates the module to run each Python worker. + val workerModule = SparkEnv.get.conf.getOption("spark.python.worker.module").map { value => + logInfo( + s"Python worker module in PySpark is set to [$value] in 'spark.python.worker.module', " + + "using this to start the worker up. 
Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is disabled or the platform is Windows.") + value + }.getOrElse("pyspark.worker") + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -74,8 +96,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself - * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. + * Connect to a worker launched through pyspark/daemon.py (by default), which forks python + * processes itself to avoid the high cost of forking from Java. This currently only works + * on UNIX-based systems. */ private def createThroughDaemon(): Socket = { @@ -108,7 +131,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Launch a worker by executing worker.py directly and telling it to connect to us. + * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ private def createSimpleWorker(): Socket = { var serverSocket: ServerSocket = null @@ -116,7 +139,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -159,7 +182,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) From c3548d11c3c57e8f2c6ebd9d2d6a3924ddcd3cba Mon Sep 17 00:00:00 2001 From: foxish Date: Sat, 13 Jan 2018 21:34:28 -0800 Subject: [PATCH 642/712] [SPARK-23063][K8S] K8s changes for publishing scripts (and a couple of other misses) ## What changes were proposed in this pull request? Including the `-Pkubernetes` flag in a few places it was missed. ## How was this patch tested? checkstyle, mima through manual tests. Author: foxish Closes #20256 from foxish/SPARK-23063. 
--- dev/create-release/release-build.sh | 4 ++-- dev/create-release/releaseutils.py | 2 ++ dev/deps/spark-deps-hadoop-2.6 | 11 +++++++++++ dev/deps/spark-deps-hadoop-2.7 | 11 +++++++++++ dev/lint-java | 2 +- dev/mima | 2 +- dev/scalastyle | 1 + dev/test-dependencies.sh | 2 +- 8 files changed, 30 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c71137468054f..a3579f21fc539 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -92,9 +92,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 730138195e5fe..32f6cbb29f0be 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -185,6 +185,8 @@ def get_commits(tag): "graphx": "GraphX", "input/output": CORE_COMPONENT, "java api": "Java API", + "k8s": "Kubernetes", + "kubernetes": "Kubernetes", "mesos": "Mesos", "ml": "MLlib", "mllib": "MLlib", diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2a298769be44c..48e54568e6fc6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -131,10 +135,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -147,6 +154,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -171,6 +180,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -186,5 +196,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 
b/dev/deps/spark-deps-hadoop-2.7 index abee326f283ab..1807a77900e52 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -132,10 +136,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -148,6 +155,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -172,6 +181,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -187,5 +197,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/lint-java b/dev/lint-java index c2e80538ef2a5..1f0b0c8379ed0 100755 --- a/dev/lint-java +++ b/dev/lint-java @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" -ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pkubernetes -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) if test ! 
-z "$ERRORS"; then echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/mima b/dev/mima index 1e3ca9700bc07..cd2694ff4d3de 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index 89ecc8abd6f8c..b8053df05fa2b 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -24,6 +24,7 @@ ERRORS=$(echo -e "q\n" \ -Pkinesis-asl \ -Pmesos \ -Pkafka-0-8 \ + -Pkubernetes \ -Pyarn \ -Pflume \ -Phive \ diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 58b295d4f6e00..3bf7618e1ea96 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 From 7a3d0aad2b89aef54f7dd580397302e9ff984d9d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 13 Jan 2018 23:26:12 -0800 Subject: [PATCH 643/712] [SPARK-23038][TEST] Update docker/spark-test (JDK/OS) ## What changes were proposed in this pull request? This PR aims to update the followings in `docker/spark-test`. - JDK7 -> JDK8 Spark 2.2+ supports JDK8 only. - Ubuntu 12.04.5 LTS(precise) -> Ubuntu 16.04.3 LTS(xeniel) The end of life of `precise` was April 28, 2017. ## How was this patch tested? Manual. * Master ``` $ cd external/docker $ ./build $ export SPARK_HOME=... $ docker run -v $SPARK_HOME:/opt/spark spark-test-master CONTAINER_IP=172.17.0.3 ... 18/01/11 06:50:25 INFO MasterWebUI: Bound MasterWebUI to 172.17.0.3, and started at http://172.17.0.3:8080 18/01/11 06:50:25 INFO Utils: Successfully started service on port 6066. 18/01/11 06:50:25 INFO StandaloneRestServer: Started REST server for submitting applications on port 6066 18/01/11 06:50:25 INFO Master: I have been elected leader! New state: ALIVE ``` * Slave ``` $ docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://172.17.0.3:7077 CONTAINER_IP=172.17.0.4 ... 18/01/11 06:51:54 INFO Worker: Successfully registered with master spark://172.17.0.3:7077 ``` After slave starts, master will show ``` 18/01/11 06:51:54 INFO Master: Registering worker 172.17.0.4:8888 with 4 cores, 1024.0 MB RAM ``` Author: Dongjoon Hyun Closes #20230 from dongjoon-hyun/SPARK-23038. --- external/docker/spark-test/base/Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index 5a95a9387c310..c70cd71367679 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -15,14 +15,14 @@ # limitations under the License. 
# -FROM ubuntu:precise +FROM ubuntu:xenial # Upgrade package index -# install a few other useful packages plus Open Jdk 7 +# install a few other useful packages plus Open Jdk 8 # Remove unneeded /var/lib/apt/lists/* after install to reduce the # docker image size (by ~30MB) RUN apt-get update && \ - apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.11.8 From 66738d29c59871b29d26fc3756772b95ef536248 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 14 Jan 2018 19:43:10 +0900 Subject: [PATCH 644/712] [SPARK-23069][DOCS][SPARKR] fix R doc for describe missing text ## What changes were proposed in this pull request? fix doc truncated ## How was this patch tested? manually Author: Felix Cheung Closes #20263 from felixcheung/r23docfix. --- R/pkg/R/DataFrame.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 9956f7eda91e6..6caa125e1e14a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3054,10 +3054,10 @@ setMethod("describe", #' \item stddev #' \item min #' \item max -#' \item arbitrary approximate percentiles specified as a percentage (eg, "75%") +#' \item arbitrary approximate percentiles specified as a percentage (eg, "75\%") #' } #' If no statistics are given, this function computes count, mean, stddev, min, -#' approximate quartiles (percentiles at 25%, 50%, and 75%), and max. +#' approximate quartiles (percentiles at 25\%, 50\%, and 75\%), and max. #' This function is meant for exploratory data analysis, as we make no guarantee about the #' backward compatibility of the schema of the resulting Dataset. If you want to #' programmatically compute summary statistics, use the \code{agg} function instead. @@ -4019,9 +4019,9 @@ setMethod("broadcast", #' #' Spark will use this watermark for several purposes: #' \itemize{ -#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' \item To know when a given time window aggregation can be finalized and thus can be emitted #' when using output modes that do not allow updates. -#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' \item To minimize the amount of state that we need to keep for on-going aggregations. #' } #' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across #' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost From 990f05c80347c6eec2ee06823cff587c9ea90b49 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 22:26:21 +0800 Subject: [PATCH 645/712] [SPARK-23021][SQL] AnalysisBarrier should override innerChildren to print correct explain output ## What changes were proposed in this pull request? 
`AnalysisBarrier` in the current master cuts off explain results for parsed logical plans; ``` scala> Seq((1, 1)).toDF("a", "b").groupBy("a").count().sample(0.1).explain(true) == Parsed Logical Plan == Sample 0.0, 0.1, false, -7661439431999668039 +- AnalysisBarrier Aggregate [a#5], [a#5, count(1) AS count#14L] ``` To fix this, `AnalysisBarrier` needs to override `innerChildren` and this pr changed the output to; ``` == Parsed Logical Plan == Sample 0.0, 0.1, false, -5086223488015741426 +- AnalysisBarrier +- Aggregate [a#5], [a#5, count(1) AS count#14L] +- Project [_1#2 AS a#5, _2#3 AS b#6] +- LocalRelation [_1#2, _2#3] ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20247 from maropu/SPARK-23021-2. --- .../plans/logical/basicLogicalOperators.scala | 1 + .../sql/hive/execution/HiveExplainSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 95e099c340af1..a4fca790dd086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -903,6 +903,7 @@ case class Deduplicate( * This analysis barrier will be removed at the end of analysis stage. */ case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override protected def innerChildren: Seq[LogicalPlan] = Seq(child) override def output: Seq[Attribute] = child.output override def isStreaming: Boolean = child.isStreaming override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index dfabf1ec2a22a..a4273de5fe260 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -171,4 +171,21 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } + + test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { + val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + df.explain(true) + } + assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( + s"""== Parsed Logical Plan == + |GlobalLimit 1 + |+- LocalLimit 1 + | +- AnalysisBarrier + | +- Aggregate [a#0], [a#0, count(1) AS count#0L] + | +- Project [_1#0 AS a#0, _2#0 AS b#0] + | +- LocalRelation [_1#0, _2#0] + |""".stripMargin)) + } } From 60eeecd7760aee6ce2fd207c83ae40054eadaf83 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Sun, 14 Jan 2018 08:32:35 -0600 Subject: [PATCH 646/712] [SPARK-23051][CORE] Fix for broken job description in Spark UI ## What changes were proposed in this pull request? In 2.2, Spark UI displayed the stage description if the job description was not set. This functionality was broken, the GUI has shown no description in this case. In addition, the code uses jobName and jobDescription instead of stageName and stageDescription when JobTableRowData is created. 
In this PR the logic producing values for the job rows was modified to find the latest stage attempt for the job and use that as a fallback if job description was missing. StageName and stageDescription are also set using values from stage and jobName/description is used only as a fallback. ## How was this patch tested? Manual testing of the UI, using the code in the bug report. Author: Sandor Murakozi Closes #20251 from smurakozi/SPARK-23051. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 37e3b3b304a63..ff916bb6a5759 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -65,12 +65,10 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We }.map { job => val jobId = job.jobId val status = job.status - val displayJobDescription = - if (job.description.isEmpty) { - job.name - } else { - UIUtils.makeDescription(job.description.get, "", plainText = true).text - } + val jobDescription = store.lastStageAttempt(job.stageIds.max).description + val displayJobDescription = jobDescription + .map(UIUtils.makeDescription(_, "", plainText = true).text) + .getOrElse("") val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -429,20 +427,23 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(jobData.description.getOrElse(""), - basePath, plainText = false) + val lastStageAttempt = store.lastStageAttempt(jobData.stageIds.max) + val lastStageDescription = lastStageAttempt.description.getOrElse("") + + val formattedJobDescription = + UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - jobData.name, - jobData.description.getOrElse(jobData.name), + lastStageAttempt.name, + lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, - jobDescription, + formattedJobDescription, detailUrl ) } From 42a1a15d739890bdfbb367ef94198b19e98ffcb7 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 15 Jan 2018 02:02:49 +0800 Subject: [PATCH 647/712] [SPARK-22999][SQL] show databases like command' can remove the like keyword ## What changes were proposed in this pull request? SHOW DATABASES (LIKE pattern = STRING)? Can be like the back increase? When using this command, LIKE keyword can be removed. You can refer to the SHOW TABLES command, SHOW TABLES 'test *' and SHOW TABELS like 'test *' can be used. Similarly SHOW DATABASES 'test *' and SHOW DATABASES like 'test *' can be used. ## How was this patch tested? unit tests manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20194 from guoxiaolongzte/SPARK-22999. 
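To make the new syntax concrete, the two statements below should now be equivalent; this is only a sketch reusing the database name from the test added below:

```python
# LIKE is now optional in SHOW DATABASES, mirroring SHOW TABLES.
spark.sql("CREATE DATABASE IF NOT EXISTS showdb1a")
spark.sql("SHOW DATABASES LIKE '*db1A'").show()
spark.sql("SHOW DATABASES '*db1A'").show()   # same result without the keyword
```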
--- .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6daf01d98426c..39d5e4ed56628 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -141,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 591510c1d8283..2b4b7c137428a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -991,6 +991,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("SHOW DATABASES LIKE '*db1A'"), Row("showdb1a") :: Nil) + checkAnswer( + sql("SHOW DATABASES '*db1A'"), + Row("showdb1a") :: Nil) + checkAnswer( sql("SHOW DATABASES LIKE 'showdb1A'"), Row("showdb1a") :: Nil) From b98ffa4d6dabaf787177d3f14b200fc4b118c7ce Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 10:55:21 +0800 Subject: [PATCH 648/712] [SPARK-23054][SQL] Fix incorrect results of casting UserDefinedType to String ## What changes were proposed in this pull request? This pr fixed the issue when casting `UserDefinedType`s into strings; ``` >>> from pyspark.ml.classification import MultilayerPerceptronClassifier >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])), (1.0, Vectors.dense([0.0, 1.0]))], ["label", "features"]) >>> df.selectExpr("CAST(features AS STRING)").show(truncate = False) +-------------------------------------------+ |features | +-------------------------------------------+ |[6,1,0,0,2800000020,2,0,0,0] | |[6,1,0,0,2800000020,2,0,0,3ff0000000000000]| +-------------------------------------------+ ``` The root cause is that `Cast` handles input data as `UserDefinedType.sqlType`(this is underlying storage type), so we should pass data into `UserDefinedType.deserialize` then `toString`. This pr modified the result into; ``` +---------+ |features | +---------+ |[0.0,0.0]| |[0.0,1.0]| +---------+ ``` ## How was this patch tested? Added tests in `UserDefinedTypeSuite `. Author: Takeshi Yamamuro Closes #20246 from maropu/SPARK-23054. 
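The DataFrame API goes through the same `Cast` expression, so a programmatic cast benefits from the fix as well; a small sketch reusing the `df` from the repro above:

```python
from pyspark.sql.functions import col

# Previously printed the underlying storage layout of the UDT; now prints the
# deserialized vector, e.g. [0.0,1.0].
df.select(col("features").cast("string")).show(truncate=False)
```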
--- .../spark/sql/catalyst/expressions/Cast.scala | 7 +++++++ .../apache/spark/sql/UserDefinedTypeSuite.scala | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f21aa1e9e3135..a95ebe301b9d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case udt: UserDefinedType[_] => + buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -836,6 +838,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case udt: UserDefinedType[_] => + val udtRef = ctx.addReferenceObj("udt", udt) + (c, evPrim, evNull) => { + s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a08433ba794d9..cc8b600efa46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ @@ -44,6 +44,8 @@ object UDT { case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) case _ => false } + + override def toString: String = data.mkString("(", ", ", ")") } private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { @@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest + with ExpressionEvalHelper { import testImplicits._ private lazy val pointsRDD = Seq( @@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT pointsRDD.except(pointsRDD2), Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) } + + test("SPARK-23054 Cast UserDefinedType to string") { + val udt = new UDT.MyDenseVectorUDT() + val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val data = udt.serialize(vector) + val ret = Cast(Literal(data, udt), StringType, None) + checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)") + } } From 9a96bfc8bf021cb4b6c62fac6ce1bcf87affcd43 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 15 Jan 2018 12:06:56 +0800 Subject: [PATCH 
649/712] [SPARK-23049][SQL] `spark.sql.files.ignoreCorruptFiles` should work for ORC files ## What changes were proposed in this pull request? When `spark.sql.files.ignoreCorruptFiles=true`, we should ignore corrupted ORC files. ## How was this patch tested? Pass the Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #20240 from dongjoon-hyun/SPARK-23049. --- .../execution/datasources/orc/OrcUtils.scala | 29 ++++++++---- .../datasources/orc/OrcQuerySuite.scala | 47 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 23 +++++++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 8 +++- .../spark/sql/hive/orc/OrcFileOperator.scala | 28 +++++++++-- 5 files changed, 117 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 13a23996f4ade..460194ba61c8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -50,23 +51,35 @@ object OrcUtils extends Logging { paths } - def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + def readSchema(file: Path, conf: Configuration, ignoreCorruptFiles: Boolean) + : Option[TypeDescription] = { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - None - } else { - Some(schema) + try { + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } catch { + case e: org.apache.orc.FileFormatException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $file", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $file", e) + } } } def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. 
- files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index e00e057a18cc6..f58c331f33ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} @@ -531,6 +532,52 @@ abstract class OrcQueryTest extends OrcTest { val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath) assert(df.count() == 20) } + + test("Enabling/disabling ignoreCorruptFiles") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val m1 = intercept[SparkException] { + testIgnoreCorruptFiles() + }.getMessage + assert(m1.contains("Could not read footer for file")) + val m2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + }.getMessage + assert(m2.contains("Malformed ORC file")) + } + } } class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 4c8c9ef6e0432..6ad88ed997ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -320,14 +320,27 @@ class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedSQLContext new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - checkAnswer( - df, - Seq(Row(0), Row(1))) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) } } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { @@ -335,6 +348,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext testIgnoreCorruptFiles() } assert(exception.getMessage().contains("is not a Parquet file")) + val exception2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + } + assert(exception2.getMessage().contains("is not a Parquet file")) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 95741c7b30289..237ed9bc05988 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -59,9 +59,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles OrcFileOperator.readSchema( files.map(_.getPath.toString), - Some(sparkSession.sessionState.newHadoopConf()) + Some(sparkSession.sessionState.newHadoopConf()), + ignoreCorruptFiles ) } @@ -129,6 +131,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -138,7 +141,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. 
- val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty + val isEmptyFile = + OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf), ignoreCorruptFiles).isEmpty if (isEmptyFile) { Iterator.empty } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 5a3fcd7a759c0..80e44ca504356 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive.orc +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -46,7 +49,10 @@ private[hive] object OrcFileOperator extends Logging { * create the result reader from that file. If no such file is found, it returns `None`. * @todo Needs to consider all files when schema evolution is taken into account. */ - def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = { + def getFileReader(basePath: String, + config: Option[Configuration] = None, + ignoreCorruptFiles: Boolean = false) + : Option[Reader] = { def isWithNonEmptySchema(path: Path, reader: Reader): Boolean = { reader.getObjectInspector match { case oi: StructObjectInspector if oi.getAllStructFieldRefs.size() == 0 => @@ -65,16 +71,28 @@ private[hive] object OrcFileOperator extends Logging { } listOrcFiles(basePath, conf).iterator.map { path => - path -> OrcFile.createReader(fs, path) + val reader = try { + Some(OrcFile.createReader(fs, path)) + } catch { + case e: IOException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $path", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $path", e) + } + } + path -> reader }.collectFirst { - case (path, reader) if isWithNonEmptySchema(path, reader) => reader + case (path, Some(reader)) if isWithNonEmptySchema(path, reader) => reader } } - def readSchema(paths: Seq[String], conf: Option[Configuration]): Option[StructType] = { + def readSchema(paths: Seq[String], conf: Option[Configuration], ignoreCorruptFiles: Boolean) + : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. - paths.flatMap(getFileReader(_, conf)).headOption.map { reader => + paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") From b59808385cfe24ce768e5b3098b9034e64b99a5a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 16:26:52 +0800 Subject: [PATCH 650/712] [SPARK-23023][SQL] Cast field data to strings in showString ## What changes were proposed in this pull request? 
The current `Datset.showString` prints rows thru `RowEncoder` deserializers like; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------------------------------------------+ |a | +------------------------------------------------------------+ |[WrappedArray(1, 2), WrappedArray(3), WrappedArray(4, 5, 6)]| +------------------------------------------------------------+ ``` This result is incorrect because the correct one is; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------+ |a | +------------------------+ |[[1, 2], [3], [4, 5, 6]]| +------------------------+ ``` So, this pr fixed code in `showString` to cast field data to strings before printing. ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20214 from maropu/SPARK-23023. --- python/pyspark/sql/functions.py | 32 +++++++++---------- .../scala/org/apache/spark/sql/Dataset.scala | 21 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 28 ++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e1ad6590554cf..f7b3f29764040 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1849,14 +1849,14 @@ def explode_outer(col): +---+----------+----+-----+ >>> df.select("id", "a_map", explode_outer("an_array")).show() - +---+-------------+----+ - | id| a_map| col| - +---+-------------+----+ - | 1|Map(x -> 1.0)| foo| - | 1|Map(x -> 1.0)| bar| - | 2| Map()|null| - | 3| null|null| - +---+-------------+----+ + +---+----------+----+ + | id| a_map| col| + +---+----------+----+ + | 1|[x -> 1.0]| foo| + | 1|[x -> 1.0]| bar| + | 2| []|null| + | 3| null|null| + +---+----------+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.explode_outer(_to_java_column(col)) @@ -1881,14 +1881,14 @@ def posexplode_outer(col): | 3| null|null|null| null| +---+----------+----+----+-----+ >>> df.select("id", "a_map", posexplode_outer("an_array")).show() - +---+-------------+----+----+ - | id| a_map| pos| col| - +---+-------------+----+----+ - | 1|Map(x -> 1.0)| 0| foo| - | 1|Map(x -> 1.0)| 1| bar| - | 2| Map()|null|null| - | 3| null|null|null| - +---+-------------+----+----+ + +---+----------+----+----+ + | id| a_map| pos| col| + +---+----------+----+----+ + | 1|[x -> 1.0]| 0| foo| + | 1|[x -> 1.0]| 1| bar| + | 2| []|null|null| + | 3| null|null|null| + +---+----------+----+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 77e571272920a..34f0ab5aa6699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,13 +237,20 @@ class Dataset[T] private[sql]( private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val takeResult = toDF().take(numRows + 1) + val newDf = toDF() + val castCols = newDf.logicalPlan.output.map { col => + // Since binary types in top-level schema fields have a specific format to print, + // so we do not cast them to strings here. 
+ if (col.dataType == BinaryType) { + Column(col) + } else { + Column(col).cast(StringType) + } + } + val takeResult = newDf.select(castCols: _*).take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - lazy val timeZone = - DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) - // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." @@ -252,12 +259,6 @@ class Dataset[T] private[sql]( val str = cell match { case null => "null" case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case d: Date => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case ts: Timestamp => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone) case _ => cell.toString } if (truncate > 0 && str.length > truncate) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5e4c1a6a484fb..33707080c1301 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1255,6 +1255,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } + test("SPARK-23023 Cast rows to strings in showString") { + val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a") + assert(df1.showString(10) === + s"""+------------+ + || a| + |+------------+ + ||[1, 2, 3, 4]| + |+------------+ + |""".stripMargin) + val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a") + assert(df2.showString(10) === + s"""+----------------+ + || a| + |+----------------+ + ||[1 -> a, 2 -> b]| + |+----------------+ + |""".stripMargin) + val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") + assert(df3.showString(10) === + s"""+------+---+ + || a| b| + |+------+---+ + ||[1, a]| 0| + ||[2, b]| 0| + |+------+---+ + |""".stripMargin) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 54893c184642b..49c59cf695dc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -958,12 +958,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ).toDS() val expected = - """+-------+ - || f| - |+-------+ - ||[foo,1]| - ||[bar,2]| - |+-------+ + """+--------+ + || f| + |+--------+ + ||[foo, 1]| + ||[bar, 2]| + |+--------+ |""".stripMargin checkShowString(ds, expected) From a38c887ac093d7cf343d807515147d87ca931ce7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 15 Jan 2018 07:49:34 -0600 Subject: [PATCH 651/712] [SPARK-19550][BUILD][FOLLOW-UP] Remove MaxPermSize for sql module ## What changes were proposed in this pull request? Remove `MaxPermSize` for `sql` module ## How was this patch tested? Manually tested. Author: Yuming Wang Closes #20268 from wangyum/SPARK-19550-MaxPermSize. 
--- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 839b929abd3cb..7d23637e28342 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -134,7 +134,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 744daa6079779..ef41837f89d68 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -195,7 +195,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} From bd08a9e7af4137bddca638e627ad2ae531bce20f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 15 Jan 2018 22:32:38 +0800 Subject: [PATCH 652/712] [SPARK-23070] Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 ## What changes were proposed in this pull request? Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 and add the missing exclusions to `v23excludes` in `MimaExcludes`. No item can be un-excluded in `v23excludes`. ## How was this patch tested? The existing tests. Author: gatorsmile Closes #20264 from gatorsmile/bump22. --- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 2ef0e7b40d940..adde213e361f0 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,7 +88,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "2.0.0" + val previousSparkVersion = "2.2.0" val project = projectRef.project val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 32eb31f495979..d35c50e1d00fe 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -102,7 +102,40 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-21728][CORE] Allow SparkSubmit to use Logging + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"), + + // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"), + + // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"), + + // [SPARK-20643][CORE] Add listener implementation to collect app state + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"), + + // [SPARK-20648][CORE] Port JobsTab and 
StageTab to the new UI backend + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"), + + // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json + // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), + + // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"), + + // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), + + // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") ) // Exclude rules for 2.2.x From 6c81fe227a6233f5d9665d2efadf8a1cf09f700d Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Mon, 15 Jan 2018 23:13:15 +0800 Subject: [PATCH 653/712] [SPARK-23035][SQL] Fix improper information of TempTableAlreadyExistsException ## What changes were proposed in this pull request? Problem: it throw TempTableAlreadyExistsException and output "Temporary table '$table' already exists" when we create temp view by using org.apache.spark.sql.catalyst.catalog.GlobalTempViewManager#create, it's improper. So fix improper information about TempTableAlreadyExistsException when create temp view: change "Temporary table" to "Temporary view" ## How was this patch tested? test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") Author: xubo245 <601450868@qq.com> Closes #20227 from xubo245/fixDeprecated. 
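As a quick illustration of the user-facing difference (a sketch, assuming a plain SparkSession):

```python
spark.range(1).createTempView("v")
# Registering the same name again without replacing it now fails with
# "Temporary view 'v' already exists" rather than "Temporary table 'v' already exists".
spark.range(1).createTempView("v")
```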
--- .../analysis/AlreadyExistException.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 6 +- .../spark/sql/execution/SQLViewSuite.scala | 2 +- .../sql/execution/command/DDLSuite.scala | 75 ++++++++++++++++++- 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 57f7a80bedc6c..6d587abd8fd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -31,7 +31,7 @@ class TableAlreadyExistsException(db: String, table: String) extends AnalysisException(s"Table or view '$table' already exists in database '$db'") class TempTableAlreadyExistsException(table: String) - extends AnalysisException(s"Temporary table '$table' already exists") + extends AnalysisException(s"Temporary view '$table' already exists") class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) extends AnalysisException( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 95c87ffa20cb7..6abab0073cca3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -279,7 +279,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } - test("create temp table") { + test("create temp view") { withBasicCatalog { catalog => val tempTable1 = Range(1, 10, 1, 10) val tempTable2 = Range(1, 20, 2, 10) @@ -288,11 +288,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { assert(catalog.getTempView("tbl1") == Option(tempTable1)) assert(catalog.getTempView("tbl2") == Option(tempTable2)) assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists + // Temporary view already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } - // Temporary table already exists but we override it + // Temporary view already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) assert(catalog.getTempView("tbl1") == Option(tempTable2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 8c55758cfe38d..14082197ba0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -293,7 +293,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") } - assert(e.message.contains("Temporary table") && e.message.contains("already exists")) + assert(e.message.contains("Temporary view") && e.message.contains("already exists")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2b4b7c137428a..6ca21b5aa1595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -835,6 +835,31 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") { + withTempView("view1") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO default.tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`default`.`tab2`': " + + "cannot specify database name 'default' in the destination table")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("view1"))) + } + } + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") @@ -883,6 +908,42 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") { + withTempView("view1", "view2") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + sql( + """ + |CREATE TEMPORARY VIEW view2 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO view2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`view2`': destination table already exists")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == + Seq(TableIdentifier("view1"), TableIdentifier("view2"))) + } + } + test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -1728,12 +1789,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("block creating duplicate temp table") { - withView("t_temp") { + withTempView("t_temp") { sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") val e = intercept[TempTableAlreadyExistsException] { sql("CREATE TEMPORARY TABLE t_temp (c3 int, c4 string) USING JSON") }.getMessage - assert(e.contains("Temporary table 't_temp' already exists")) + assert(e.contains("Temporary view 't_temp' already exists")) + } + } + + test("block creating duplicate temp view") { + withTempView("t_temp") { + sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") + val e = intercept[TempTableAlreadyExistsException] { + sql("CREATE TEMPORARY VIEW t_temp (c3 int, c4 string) USING JSON") + }.getMessage + assert(e.contains("Temporary view 't_temp' already exists")) } } From 8ab2d7ea99b2cff8b54b2cb3a1dbf7580845986a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 11:47:42 +0900 Subject: [PATCH 654/712] [SPARK-23080][SQL] Improve error message for built-in functions ## What changes were proposed in this pull request? When a user puts the wrong number of parameters in a function, an AnalysisException is thrown. If the function is a UDF, he user is told how many parameters the function expected and how many he/she put. If the function, instead, is a built-in one, no information about the number of parameters expected and the actual one is provided. 
This can help in some cases, to debug the errors (eg. bad quotes escaping may lead to a different number of parameters than expected, etc. etc.) The PR adds the information about the number of parameters passed and the expected one, analogously to what happens for UDF. ## How was this patch tested? modified existing UT + manual test Author: Marco Gaido Closes #20271 from mgaido91/SPARK-23080. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 10 +++++++++- .../resources/sql-tests/results/json-functions.sql.out | 4 ++-- .../src/test/scala/org/apache/spark/sql/UDFSuite.scala | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5ddb39822617d..747016beb06e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -526,7 +526,15 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - throw new AnalysisException(s"Invalid number of arguments for function $name") + val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val expectedNumberOfParameters = if (validParametersCount.length == 1) { + validParametersCount.head.toString + } else { + validParametersCount.init.mkString("one of ", ", ", " and ") + + validParametersCount.last + } + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: $expectedNumberOfParameters; Found: ${params.length}") } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index d9dc728a18e8d..581dddc89d0bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 -- !query 22 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index db37be68e42e6..af6a10b425b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -80,7 +80,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function substr")) + assert(e.getMessage.contains("Invalid number of arguments for function substr. 
Expected:")) } test("error reporting for incorrect number of arguments - udf") { @@ -89,7 +89,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function foo")) + assert(e.getMessage.contains("Invalid number of arguments for function foo. Expected:")) } test("error reporting for undefined functions") { From c7572b79da0a29e502890d7618eaf805a1c9f474 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 11:20:18 +0800 Subject: [PATCH 655/712] [SPARK-23000] Use fully qualified table names in HiveMetastoreCatalogSuite ## What changes were proposed in this pull request? In another attempt to fix DataSourceWithHiveMetastoreCatalogSuite, this patch uses qualified table names (`default.t`) in the individual tests. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20273 from sameeragarwal/flaky-test. --- .../sql/hive/HiveMetastoreCatalogSuite.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ba9b944e4a055..83b4c862e2546 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -166,13 +166,13 @@ class DataSourceWithHiveMetastoreCatalogSuite )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { - withTable("t") { + withTable("default.t") { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { testDF .write .mode(SaveMode.Overwrite) .format(provider) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) @@ -187,14 +187,15 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1.1\t1", "2.1\t2")) } } test(s"Persist non-partitioned $provider relation into metastore as external table") { withTempPath { dir => - withTable("t") { + withTable("default.t") { val path = dir.getCanonicalFile withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { @@ -203,7 +204,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .mode(SaveMode.Overwrite) .format(provider) .option("path", path.toString) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = @@ -219,8 +220,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === Seq("1.1\t1", "2.1\t2")) } } @@ -228,9 +229,9 @@ class DataSourceWithHiveMetastoreCatalogSuite test(s"Persist non-partitioned $provider relation into metastore 
as managed table using CTAS") { withTempPath { dir => - withTable("t") { + withTable("default.t") { sql( - s"""CREATE TABLE t USING $provider + s"""CREATE TABLE default.t USING $provider |OPTIONS (path '${dir.toURI}') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) @@ -248,8 +249,9 @@ DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) - checkAnswer(table("t"), Row(1, "val_1")) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + checkAnswer(table("default.t"), Row(1, "val_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1\tval_1")) } } } From 07ae39d0ec1f03b1c73259373a8bb599694c7860 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 15 Jan 2018 22:01:14 -0800 Subject: [PATCH 656/712] [SPARK-22956][SS] Bug fix for 2 streams union failover scenario ## What changes were proposed in this pull request? This problem was reported by yanlin-Lynn, ivoson and LiangchangZ. Thanks! When we union 2 streams from Kafka or other sources, and one of them has no continuous data coming in while a task restarts at the same time, this will cause an `IllegalStateException`. This is mainly caused by the code in [MicroBatchExecution](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala#L190): when one stream has no continuous data, its committedOffset is the same as its availableOffset during `populateStartOffsets`, and `currentPartitionOffsets` is not properly handled in KafkaSource. Maybe we should also consider this scenario in other Sources. ## How was this patch tested? Added a UT in KafkaSourceSuite.scala Author: Yuanjian Li Closes #20150 from xuanyuanking/SPARK-22956.
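To make the failover scenario concrete, here is a minimal, hypothetical PySpark sketch of the kind of query that hits this code path; the broker address, topic names and checkpoint location are placeholders, and actually reproducing the exception also requires that one topic stays idle while the query is stopped and restarted from its checkpoint:

```python
# Hypothetical sketch of the union-of-two-streams query described above.
# "quiet_topic" receives no new data; restarting the query from its checkpoint
# is the situation that previously led to an IllegalStateException.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("union-failover-sketch").getOrCreate()

def kafka_stream(topic):
    # One Kafka topic as a streaming DataFrame (broker address is a placeholder).
    return (spark.readStream
            .format("kafka")
            .option("kafka.bootstrap.servers", "localhost:9092")
            .option("subscribe", topic)
            .option("startingOffsets", "earliest")
            .load()
            .selectExpr("CAST(value AS STRING) AS value"))

union_df = kafka_stream("busy_topic").union(kafka_stream("quiet_topic"))

query = (union_df.writeStream
         .format("console")
         .option("checkpointLocation", "/tmp/union-failover-checkpoint")
         .start())
query.awaitTermination()
```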
--- .../spark/sql/kafka010/KafkaSource.scala | 13 ++-- .../spark/sql/kafka010/KafkaSourceSuite.scala | 65 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 6 +- .../sql/execution/streaming/memory.scala | 6 ++ 4 files changed, 81 insertions(+), 9 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..864a92b8f813f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -223,6 +223,14 @@ private[kafka010] class KafkaSource( logInfo(s"GetBatch called with start = $start, end = $end") val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + if (start.isDefined && start.get == end) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } val fromPartitionOffsets = start match { case Some(prevBatchEndOffset) => KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) @@ -305,11 +313,6 @@ private[kafka010] class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) - // On recovery, getBatch will get called before getOffset - if (currentPartitionOffsets.isEmpty) { - currentPartitionOffsets = Some(untilPartitionOffsets) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..a0f5695fc485c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -318,6 +318,71 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { + def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, range.map(_.toString).toArray, Some(0)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 5) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + + reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(k => k.toInt) + } + + val df1 = getSpecificDF(0 to 9) + val df2 = getSpecificDF(100 to 199) + + val kafka = df1.union(df2) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(kafka)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((0 to 4) ++ (100 to 104): _*), + AdvanceManualClock(100), + 
waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((5 to 9) ++ (105 to 109): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smaller topic empty, 5 from bigger one + CheckLastBatch(110 to 114: _*), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(115 to 119: _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(120 to 124: _*) + ) + } + test("cannot stop Kafka stream") { val topic = newTopic() testUtils.createTopic(topic, partitions = 5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 42240eeb58d4b..70407f0580f97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -208,10 +208,8 @@ class MicroBatchExecution( * batch will be executed before getOffset is called again. */ availableOffsets.foreach { case (source: Source, end: Offset) => - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } + val start = committedOffsets.get(source) + source.getBatch(start, end) case nonV1Tuple => // The V2 API does not have the same edge case requiring getBatch to be called // here, so we do nothing here. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 3041d4d703cb4..509a69dd922fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -119,9 +119,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) val newBlocks = synchronized { val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") batches.slice(sliceStart, sliceEnd) } + if (newBlocks.isEmpty) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) newBlocks From 66217dac4f8952a9923625908ad3dcb030763c81 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 15 Jan 2018 22:40:44 -0800 Subject: [PATCH 657/712] [SPARK-23020][CORE] Fix races in launcher code, test. The race in the code is because the handle might update its state to the wrong state if the connection handling thread is still processing incoming data; so the handle needs to wait for the connection to finish up before checking the final state. The race in the test is because when waiting for a handle to reach a final state, the waitFor() method needs to wait until all handle state is updated (which also includes waiting for the connection thread above to finish). Otherwise, waitFor() may return too early, which would cause a bunch of different races (like the listener not being yet notified of the state change, or being in the middle of being notified, or the handle not being properly disposed and causing postChecks() to assert). 
On top of that I found, by code inspection, a couple of potential races that could make a handle end up in the wrong state when being killed. Tested by running the existing unit tests a lot (and not seeing the errors I was seeing before). Author: Marcelo Vanzin Closes #20223 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 49 ++++++++++++------- .../spark/launcher/AbstractAppHandle.java | 22 +++++++-- .../spark/launcher/ChildProcAppHandle.java | 18 ++++--- .../spark/launcher/InProcessAppHandle.java | 17 ++++--- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 ++++++++++++++--- .../org/apache/spark/launcher/BaseSuite.java | 42 +++++++++++++--- .../spark/launcher/LauncherServerSuite.java | 20 +++----- 8 files changed, 156 insertions(+), 72 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..a042375c6ae91 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -31,6 +32,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -137,7 +139,9 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - TimeUnit.MILLISECONDS.sleep(500); + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -146,26 +150,35 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. 
+ List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d4..daf0972f824dd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private boolean disposed; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,8 +70,7 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; + if (!isDisposed()) { if (connection != null) { try { connection.close(); @@ -79,7 +78,7 @@ public synchronized void disconnect() { // no-op. } } - server.unregister(this); + dispose(); } } @@ -95,6 +94,21 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + */ + synchronized void dispose() { + if (!isDisposed()) { + // Unregister first to make sure that the connection with the app has been really + // terminated. + server.unregister(this); + if (!getState().isFinal()) { + setState(State.LOST); + } + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750e..2b99461652e1f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -118,6 +118,8 @@ void monitorChild() { if (newState != null) { setState(newState, true); } + + disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index acd64c962604f..f04263cb74a58 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. 
- if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e26053..fd6f229b2349c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { + public synchronized void close() throws IOException { if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f4..660c4443b20b9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,33 @@ void unregister(AbstractAppHandle handle) { break; } } + + // If there is a live connection for this handle, we need to wait for it to finish before + // returning, otherwise there might be a race between the connection thread processing + // buffered data and the handle cleaning up after itself, leading to potentially the wrong + // state being reported for the handle. + ServerConnection conn = null; + synchronized (clients) { + for (ServerConnection c : clients) { + if (c.handle == handle) { + conn = c; + break; + } + } + } + + if (conn != null) { + synchronized (conn) { + if (conn.isOpen()) { + try { + conn.wait(); + } catch (InterruptedException ie) { + // Ignore. 
+ } + } + } + } + unref(); } @@ -288,7 +315,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -338,16 +365,21 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } - super.close(); + + synchronized (this) { + super.close(); + notifyAll(); + } + if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); - } - handle.disconnect(); + handle.dispose(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d4..3722a59d9438e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -47,19 +48,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. 
+ */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9b..75c1af0c71e2a 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -197,28 +199,20 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { From b85eb946ac298e711dad25db0d04eee41d7fd236 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 16 Jan 2018 20:20:33 +0900 Subject: [PATCH 658/712] [SPARK-22978][PYSPARK] Register Vectorized UDFs for SQL Statement ## What changes were proposed in this pull request? Register Vectorized UDFs for SQL Statement. For example, ```Python >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... >>> _ = spark.udf.register("add_one", add_one) >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20171 from gatorsmile/supportVectorizedUDF. 
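Beyond the pandas UDF example in the description above, the registration path also tightens its argument checking; the following minimal sketch (assuming an active SparkSession named `spark`) illustrates the behaviour this patch introduces for an already-wrapped UDF:

```python
# Sketch of the tightened argument checking in spark.udf.register / registerFunction.
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

slen = udf(lambda s: len(s), IntegerType())

# OK: the registered UDF keeps the return type of the wrapped UDF (integer).
spark.udf.register("slen", slen)
spark.sql("SELECT slen('test')").collect()  # [Row(slen(test)=4)]

# Passing an explicit returnType together with a UserDefinedFunction now raises TypeError.
try:
    spark.udf.register("slen2", slen, IntegerType())
except TypeError as e:
    print(e)  # "Invalid returnType: None is expected when f is a UserDefinedFunction ..."
```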
--- python/pyspark/sql/catalog.py | 75 ++++++++++++++++++++++++---------- python/pyspark/sql/context.py | 51 ++++++++++++++++------- python/pyspark/sql/tests.py | 76 ++++++++++++++++++++++++++++++----- 3 files changed, 155 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 156603128d063..35fbe9e669adb 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -226,18 +226,23 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. 
+ :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() @@ -256,27 +261,55 @@ def registerFunction(self, name, f, returnType=StringType()): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ # This is to check whether the input function is a wrapped/native UserDefinedFunction if hasattr(f, 'asNondeterministic'): - udf = UserDefinedFunction(f.func, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF, - deterministic=f.deterministic) + if returnType is not None: + raise TypeError( + "Invalid returnType: None is expected when f is a UserDefinedFunction, " + "but got %s." 
% returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f else: - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - self._jsparkSession.udf().registerPython(name, udf._judf) - return udf._wrapped() + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b8d86cc098e94..85479095af594 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -174,18 +174,23 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. 
+ :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() @@ -204,15 +209,31 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = sqlContext.udf.register("slen", slen) + >>> sqlContext.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP + >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) @@ -575,7 +596,7 @@ class UDFRegistration(object): def __init__(self, sqlContext): self.sqlContext = sqlContext - def register(self, name, f, returnType=StringType()): + def register(self, name, f, returnType=None): return self.sqlContext.registerFunction(name, f, returnType) def registerJavaFunction(self, name, javaClassName, returnType=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 80a94a91a87b3..8906618666b14 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -380,12 +380,25 @@ def test_udf2(self): self.assertEqual(4, res[0]) def test_udf3(self): - twoargs = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) - self.assertEqual(twoargs.deterministic, True) + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], u'5') + + def test_udf_registration_return_type_none(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) + self.assertEqual(two_args.deterministic, True) [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf_registration_return_type_not_none(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, "Invalid returnType"): + self.spark.catalog.registerFunction( + "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) + def test_nondeterministic_udf(self): # Test that 
nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf @@ -402,12 +415,12 @@ def test_nondeterministic_udf2(self): from pyspark.sql.functions import udf random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() self.assertEqual(random_udf.deterministic, False) - random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) self.assertEqual(random_udf1.deterministic, False) [row] = self.spark.sql("SELECT randInt()").collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf1()).collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf()).collect() self.assertEqual(row[0], 6) # render_doc() reproduces the help() exception without printing output @@ -3691,7 +3704,7 @@ def tearDownClass(cls): ReusedSQLTestCase.tearDownClass() @property - def random_udf(self): + def nondeterministic_vectorized_udf(self): from pyspark.sql.functions import pandas_udf @pandas_udf('double') @@ -3726,6 +3739,21 @@ def test_vectorized_udf_basic(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_register_nondeterministic_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + import random + random_pandas_udf = pandas_udf( + lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() + self.assertEqual(random_pandas_udf.deterministic, False) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + nondeterministic_pandas_udf = self.spark.catalog.registerFunction( + "randomPandasUDF", random_pandas_udf) + self.assertEqual(nondeterministic_pandas_udf.deterministic, False) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() + self.assertEqual(row[0], 7) + def test_vectorized_udf_null_boolean(self): from pyspark.sql.functions import pandas_udf, col data = [(True,), (True,), (None,), (False,)] @@ -4085,14 +4113,14 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) - def test_nondeterministic_udf(self): + def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf, pandas_udf, col @pandas_udf('double') def plus_ten(v): return v + 10 - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() @@ -4100,11 +4128,11 @@ def plus_ten(v): self.assertEqual(random_udf.deterministic, False) self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) - def test_nondeterministic_udf_in_aggregate(self): + def test_nondeterministic_vectorized_udf_in_aggregate(self): from pyspark.sql.functions import pandas_udf, sum df = self.spark.range(10) - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf with QuietTest(self.sc): with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): @@ -4112,6 +4140,23 @@ def test_nondeterministic_udf_in_aggregate(self): with self.assertRaisesRegexp(AnalysisException, 
'nondeterministic'): df.agg(sum(random_udf(df.id))).collect() + def test_register_vectorized_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, col, expr + df = self.spark.range(10).select( + col('id').cast('int').alias('a'), + col('id').cast('int').alias('b')) + original_add = pandas_udf(lambda x, y: x + y, IntegerType()) + self.assertEqual(original_add.deterministic, True) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + new_add = self.spark.catalog.registerFunction("add1", original_add) + res1 = df.select(new_add(col('a'), col('b'))) + res2 = self.spark.sql( + "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t") + expected = df.select(expr('a + b')) + self.assertEquals(expected.collect(), res1.collect()) + self.assertEquals(expected.collect(), res2.collect()) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): @@ -4147,6 +4192,15 @@ def test_simple(self): expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) + def test_register_group_map_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP) + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' + 'SQL_PANDAS_SCALAR_UDF'): + self.spark.catalog.registerFunction("foo_udf", foo_udf) + def test_decorator(self): from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data From 75db14864d2bd9b8e13154226e94d466e3a7e0a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 16 Jan 2018 22:41:30 +0800 Subject: [PATCH 659/712] [SPARK-22392][SQL] data source v2 columnar batch reader ## What changes were proposed in this pull request? a new Data Source V2 interface to allow the data source to return `ColumnarBatch` during the scan. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20153 from cloud-fan/columnar-reader. --- .../sources/v2/reader/DataSourceV2Reader.java | 5 +- .../v2/reader/SupportsScanColumnarBatch.java | 52 ++++++++ .../v2/reader/SupportsScanUnsafeRow.java | 2 +- .../sql/execution/ColumnarBatchScan.scala | 37 +++++- .../sql/execution/DataSourceScanExec.scala | 39 ++---- .../columnar/InMemoryTableScanExec.scala | 101 +++++++++------- .../datasources/v2/DataSourceRDD.scala | 20 ++-- .../datasources/v2/DataSourceV2ScanExec.scala | 72 ++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 112 ++++++++++++++++++ .../execution/WholeStageCodegenSuite.scala | 28 ++--- .../sql/sources/v2/DataSourceV2Suite.scala | 72 ++++++++++- 12 files changed, 400 insertions(+), 144 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 95ee4a8278322..f23c3842bf1b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -38,7 +38,10 @@ * 2. Information Reporting. 
E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. + * Names of these interfaces start with `SupportsScan`. Note that a reader should only + * implement at most one of the special scans, if more than one special scans are implemented, + * only one of them would be respected, according to the priority list from high to low: + * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * * If an exception was throw when applying any of these query optimizations, the action would fail * and no Spark job was submitted. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java new file mode 100644 index 0000000000000..27cf3a77724f0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to output {@link ColumnarBatch} and make the scan faster. + */ +@InterfaceStability.Evolving +public interface SupportsScanColumnarBatch extends DataSourceV2Reader { + @Override + default List> createReadTasks() { + throw new IllegalStateException( + "createReadTasks not supported by default within SupportsScanColumnarBatch."); + } + + /** + * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + */ + List> createBatchReadTasks(); + + /** + * Returns true if the concrete data source reader can read data in batch according to the scan + * properties like required columns, pushes filters, etc. It's possible that the implementation + * can only support some certain columns with certain types. Users can overwrite this method and + * {@link #createReadTasks()} to fallback to normal read path under some conditions. 
+ */ + default boolean enableBatchRead() { + return true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index b90ec880dc85e..2d3ad0eee65ff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -35,7 +35,7 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override default List> createReadTasks() { throw new IllegalStateException( - "createReadTasks should not be called with SupportsScanUnsafeRow."); + "createReadTasks not supported by default within SupportsScanUnsafeRow"); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 5617046e1396e..dd68df9686691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType @@ -25,13 +25,16 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** - * Helper trait for abstracting scan functionality using - * [[ColumnarBatch]]es. + * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { def vectorTypes: Option[Seq[String]] = None + protected def supportsBatch: Boolean = true + + protected def needsUnsafeRowConversion: Boolean = true + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) @@ -71,7 +74,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") + if (supportsBatch) { + produceBatches(ctx, input) + } else { + produceRows(ctx, input) + } + } + private def produceBatches(ctx: CodegenContext, input: String): String = { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") @@ -137,4 +147,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { """.stripMargin } + private def produceRows(ctx: CodegenContext, input: String): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + val row = ctx.freshName("row") + + ctx.INPUT_ROW = row + ctx.currentVars = null + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. 
+ val outputVars = output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + val inputRow = if (needsUnsafeRowConversion) null else row + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, outputVars, inputRow).trim} + | if (shouldStop()) return; + |} + """.stripMargin + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d1ff82c7c06bc..7c7d79c2bbd7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -164,13 +164,15 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - val supportsBatch: Boolean = relation.fileFormat.supportBatch( + override val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { - SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled - } else { - false + override val needsUnsafeRowConversion: Boolean = { + if (relation.fileFormat.isInstanceOf[ParquetSource]) { + SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled + } else { + false + } } override def vectorTypes: Option[Seq[String]] = @@ -346,33 +348,6 @@ case class FileSourceScanExec( override val nodeNamePrefix: String = "File" - override protected def doProduce(ctx: CodegenContext): String = { - if (supportsBatch) { - return super.doProduce(ctx) - } - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val row = ctx.freshName("row") - - ctx.INPUT_ROW = row - ctx.currentVars = null - // Always provide `outputVars`, so that the framework can help us build unsafe row if the input - // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. - val outputVars = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable).genCode(ctx) - } - val inputRow = if (needsUnsafeRowConversion) null else row - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, outputVars, inputRow).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } - /** * Create an RDD for bucketed reads. * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 933b9753faa61..3565ee3af1b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -49,9 +49,9 @@ case class InMemoryTableScanExec( /** * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. 
- * If false, get data from UnsafeRow build from ColumnVector + * If false, get data from UnsafeRow build from CachedBatch */ - override val supportCodegen: Boolean = { + override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields relation.schema.fields.forall(f => f.dataType match { @@ -61,6 +61,8 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + override protected def needsUnsafeRowConversion: Boolean = false + private val columnIndices = attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray @@ -90,14 +92,56 @@ case class InMemoryTableScanExec( columnarBatch } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - assert(supportCodegen) + private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() - // HACK ALERT: This is actually an RDD[ColumnarBatch]. - // We're taking advantage of Scala's type erasure here to pass these batches along. - Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + if (supportsBatch) { + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + } else { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => + // Find the ordinals and data types of the requested columns. + val (requestedColumnIndices, requestedColumnDataTypes) = + attributes.map { a => + relOutput.indexOf(a.exprId) -> a.dataType + }.unzip + + // update SQL metrics + val withMetrics = cachedBatchIterator.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulatorsForTest && columnarIterator.hasNext) { + readPartitions.add(1) + } + columnarIterator + } + } } + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { @@ -185,7 +229,7 @@ case class InMemoryTableScanExec( } } - lazy val enableAccumulators: Boolean = + lazy val enableAccumulatorsForTest: Boolean = sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes @@ -230,43 +274,10 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - // Using these variables here to avoid serialization of entire objects (if referenced directly) - // within the map Partitions closure. 
- val relOutput: AttributeSeq = relation.output - - filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => - // Find the ordinals and data types of the requested columns. - val (requestedColumnIndices, requestedColumnDataTypes) = - attributes.map { a => - relOutput.indexOf(a.exprId) -> a.dataType - }.unzip - - // update SQL metrics - val withMetrics = cachedBatchIterator.map { batch => - if (enableAccumulators) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - - val columnTypes = requestedColumnDataTypes.map { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) - if (enableAccumulators && columnarIterator.hasNext) { - readPartitions.add(1) - } - columnarIterator + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + inputRDD } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5f30be5ed4af1..ac104d7cd0cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.v2.reader.ReadTask -class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) extends Partition with Serializable -class DataSourceRDD( +class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + @transient private val readTasks: java.util.List[ReadTask[T]]) + extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { readTasks.asScala.zipWithIndex.map { @@ -38,10 +38,10 @@ class DataSourceRDD( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[UnsafeRow] { + val iter = new Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +51,7 @@ class DataSourceRDD( valuePrepared } - override def next(): UnsafeRow = { + override def next(): T = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -63,6 +63,6 @@ class DataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 49c506bc560cf..8c64df080242f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,10 +24,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,40 +35,56 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + @transient reader: DataSourceV2Reader) + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def references: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + case _ => + reader.createReadTasks().asScala.map { + new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + }.asJava + } - override protected def doExecute(): RDD[InternalRow] = { - val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() - case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] - }.asJava - } + private lazy val inputRDD: RDD[InternalRow] = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + assert(!reader.isInstanceOf[ContinuousReader], + "continuous stream reader does not support columnar read yet.") + new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetReaderPartitions(readTasks.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .asInstanceOf[RDD[InternalRow]] + + case _ => + new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + } - val inputRDD = reader match { - case _: ContinuousReader => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) 
- new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + override val supportsBatch: Boolean = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => true + case _ => false + } - case _ => - new DataSourceRDD(sparkContext, readTasks) - } + override protected def needsUnsafeRowConversion: Boolean = false - val numOutputRows = longMetric("numOutputRows") - inputRDD.asInstanceOf[RDD[InternalRow]].map { r => - numOutputRows += 1 - r + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..b3f1a1a1aaab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,7 +52,7 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) @@ -132,7 +132,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 0000000000000..44e5146d7c553 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List<ReadTask<ColumnarBatch>> createBatchReadTasks() { + return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); + } + } + + static class JavaBatchReadTask implements ReadTask<ColumnarBatch>, DataReader<ColumnarBatch> { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchReadTask(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader<ColumnarBatch> createDataReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c47..22ca128c27768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -121,31 +121,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { import testImplicits._ - val dsInt = spark.range(3).cache - dsInt.count + val dsInt = spark.range(3).cache() + dsInt.count() val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan - assert(planInt.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child -
.asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined - ) + assert(planInt.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () + }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) // cache for string type is not supported for InMemoryTableScanExec - val dsString = spark.range(3).map(_.toString).cache - dsString.count + val dsString = spark.range(3).map(_.toString).cache() + dsString.count() val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan - assert(planString.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec]).isDefined - ) + assert(planString.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + }.length == 1) assert(dsStringFilter.collect() === Array("1")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ab37e4984bd1f..a89f7c55bf4f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,10 +24,12 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -56,7 +58,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("unsafe row implementation") { + test("unsafe row scan implementation") { Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -67,6 +69,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("columnar batch scan implementation") { + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 90).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + } + } + } + test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { @@ -275,7 +288,7 @@ class UnsafeRowReadTask(start: Int, end: Int) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -300,3 +313,56 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { override def createReader(schema: 
StructType, options: DataSourceV2Options): DataSourceV2Reader = new Reader(schema) } + +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class BatchReadTask(start: Int, end: Int) + extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { + + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch( + new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + + private var current = start + + override def createDataReader(): DataReader[ColumnarBatch] = this + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } + + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = batch.close() +} From 12db365b4faf7a185708648d246fc4a2aae0c2c0 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 16 Jan 2018 11:41:08 -0800 Subject: [PATCH 660/712] [SPARK-16139][TEST] Add logging functionality for leaked threads in tests ## What changes were proposed in this pull request? Lots of our tests don't properly shut down everything they create, and end up leaking lots of threads. For example, `TaskSetManagerSuite` doesn't stop the extra `TaskScheduler` and `DAGScheduler` it creates. There are a couple more instances, e.g. in `DAGSchedulerSuite`. This PR adds the ability to print out the list of improperly stopped threads after a test suite has executed. The format is the following: ``` ===== FINISHED o.a.s.scheduler.DAGSchedulerSuite: 'task end event should have updated accumulators (SPARK-20342)' ===== ... ===== Global thread whitelist loaded with name /thread_whitelist from classpath: rpc-client.*, rpc-server.*, shuffle-client.*, shuffle-server.*' ===== ScalaTest-run: ===== THREADS NOT STOPPED PROPERLY ===== ScalaTest-run: dag-scheduler-event-loop ScalaTest-run: globalEventExecutor-2-5 ScalaTest-run: ===== END OF THREAD DUMP ===== ScalaTest-run: ===== EITHER PUT THREAD NAME INTO THE WHITELIST FILE OR SHUT IT DOWN PROPERLY ===== ``` With the help of this, leaking threads have been identified in TaskSetManagerSuite. My intention is to hunt down and fix such bugs in later PRs. ## How was this patch tested? Manual: executed TaskSetManagerSuite and identified the leaking threads. Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #19893 from gaborgsomogyi/SPARK-16139.
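As an illustration of the mechanism described in the message above, the audit boils down to snapshotting the running thread names before a suite, diffing them afterwards, and ignoring names matched by a whitelist of regexes. The following standalone sketch is the editor's own (the object name and the trimmed-down whitelist are illustrative, not part of the patch); the real implementation is `ThreadAudit.scala` in the diff below.

```scala
import scala.collection.JavaConverters._

object ThreadAuditSketch {
  // illustrative whitelist; the patch ships a larger set of patterns
  private val whitelist =
    Set("rpc-client.*", "rpc-server.*", "shuffle-client.*", "shuffle-server.*")

  // snapshot of all live thread names in the JVM
  private def runningThreadNames(): Set[String] =
    Thread.getAllStackTraces.keySet().asScala.map(_.getName).toSet

  def main(args: Array[String]): Unit = {
    val before = runningThreadNames()
    // ... run a test suite that may leak threads here ...
    val leaked = runningThreadNames().diff(before)
      .filterNot(name => whitelist.exists(pattern => name.matches(pattern)))
    if (leaked.nonEmpty) {
      println(s"===== THREADS NOT STOPPED PROPERLY =====\n${leaked.mkString("\n")}")
    }
  }
}
```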
--- .../org/apache/spark/SparkFunSuite.scala | 34 +++++++ .../scala/org/apache/spark/ThreadAudit.scala | 99 +++++++++++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 7 +- .../apache/spark/sql/SessionStateSuite.scala | 1 + .../sql/sources/DataSourceAnalysisSuite.scala | 1 + .../spark/sql/test/SharedSQLContext.scala | 23 ++++- .../hive/HiveContextCompatibilitySuite.scala | 1 + .../sql/hive/HiveSessionStateSuite.scala | 1 + .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 + .../sql/hive/client/HiveClientSuite.scala | 1 + .../sql/hive/client/HiveVersionSuite.scala | 1 + .../spark/sql/hive/client/VersionsSuite.scala | 2 + .../hive/execution/HiveComparisonTest.scala | 2 + .../hive/orc/OrcHadoopFsRelationSuite.scala | 1 + .../sql/hive/test/TestHiveSingleton.scala | 1 + 15 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ThreadAudit.scala diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 18077c08c9dcc..3af9d82393bc4 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -27,19 +27,53 @@ import org.apache.spark.util.AccumulatorContext /** * Base abstract class for all unit tests in Spark for handling common functionality. + * + * Thread audit happens normally here automatically when a new test suite created. + * The only prerequisite for that is that the test class must extend [[SparkFunSuite]]. + * + * It is possible to override the default thread audit behavior by setting enableAutoThreadAudit + * to false and manually calling the audit methods, if desired. For example: + * + * class MyTestSuite extends SparkFunSuite { + * + * override val enableAutoThreadAudit = false + * + * protected override def beforeAll(): Unit = { + * doThreadPreAudit() + * super.beforeAll() + * } + * + * protected override def afterAll(): Unit = { + * super.afterAll() + * doThreadPostAudit() + * } + * } */ abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll + with ThreadAudit with Logging { // scalastyle:on + protected val enableAutoThreadAudit = true + + protected override def beforeAll(): Unit = { + if (enableAutoThreadAudit) { + doThreadPreAudit() + } + super.beforeAll() + } + protected override def afterAll(): Unit = { try { // Avoid leaking map entries in tests that use accumulators without SparkContext AccumulatorContext.clear() } finally { super.afterAll() + if (enableAutoThreadAudit) { + doThreadPostAudit() + } } } diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala new file mode 100644 index 0000000000000..b3cea9de8f304 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging + +/** + * Thread audit for test suites. + */ +trait ThreadAudit extends Logging { + + val threadWhiteList = Set( + /** + * Netty related internal threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "netty.*", + + /** + * Netty related internal threads. + * A Single-thread singleton EventExecutor inside netty which creates such threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "globalEventExecutor.*", + + /** + * Netty related internal threads. + * Checks if a thread is alive periodically and runs a task when a thread dies. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "threadDeathWatcher.*", + + /** + * During [[SparkContext]] creation [[org.apache.spark.rpc.netty.NettyRpcEnv]] + * creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "rpc-client.*", + "rpc-server.*", + + /** + * During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. 
+ */ + "shuffle-client.*", + "shuffle-server.*" + ) + private var threadNamesSnapshot: Set[String] = Set.empty + + protected def doThreadPreAudit(): Unit = { + threadNamesSnapshot = runningThreadNames() + } + + protected def doThreadPostAudit(): Unit = { + val shortSuiteName = this.getClass.getName.replaceAll("org.apache.spark", "o.a.s") + + if (threadNamesSnapshot.nonEmpty) { + val remainingThreadNames = runningThreadNames().diff(threadNamesSnapshot) + .filterNot { s => threadWhiteList.exists(s.matches(_)) } + if (remainingThreadNames.nonEmpty) { + logWarning(s"\n\n===== POSSIBLE THREAD LEAK IN SUITE $shortSuiteName, " + + s"thread names: ${remainingThreadNames.mkString(", ")} =====\n") + } + } else { + logWarning("\n\n===== THREAD AUDIT POST ACTION CALLED " + + s"WITHOUT PRE ACTION IN SUITE $shortSuiteName =====\n") + } + } + + private def runningThreadNames(): Set[String] = { + Thread.getAllStackTraces.keySet().asScala.map(_.getName).toSet + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2ce81ae27daf6..ca6a7e5db3b17 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -683,7 +683,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val conf = new SparkConf().set("spark.speculation", "true") sc = new SparkContext("local", "test", conf) - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { override def killTask( taskId: Long, @@ -709,6 +709,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val singleTask = new ShuffleMapTask(0, 0, null, new Partition { @@ -754,7 +755,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc.conf.set("spark.speculation", "true") var killTaskCalled = false - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) sched.initialize(new FakeSchedulerBackend() { override def killTask( @@ -789,6 +790,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, @@ -1183,6 +1185,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler.stop() sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index c01666770720c..5d75f5835bf9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -39,6 +39,7 @@ class SessionStateSuite extends SparkFunSuite protected var 
activeSession: SparkSession = _ override def beforeAll(): Unit = { + super.beforeAll() activeSession = SparkSession.builder().master("local").getOrCreate() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 735e07c21373a..e1022e377132c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -33,6 +33,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { private var targetPartitionSchema: StructType = _ override def beforeAll(): Unit = { + super.beforeAll() targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) targetPartitionSchema = new StructType() .add("b", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4d578e21f5494..e6c7648c986ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,4 +17,25 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { + + /** + * Suites extending [[SharedSQLContext]] are sharing resources (eg. SparkSession) in their tests. + * That trait initializes the spark session in its [[beforeAll()]] implementation before the + * automatic thread snapshot is performed, so the audit code could fail to report threads leaked + * by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. 
+ */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + super.afterAll() + doThreadPostAudit() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 8a7423663f28d..a80db765846e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEach { + override protected val enableAutoThreadAudit = false private var sc: SparkContext = null private var hc: HiveContext = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 958ad3e1c3ce8..f7da3c4cbb0aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -30,6 +30,7 @@ class HiveSessionStateSuite extends SessionStateSuite override def beforeAll(): Unit = { // Reuse the singleton session + super.beforeAll() activeSession = spark } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 21b3e281490cf..10204f4694663 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -44,6 +44,8 @@ class HiveSparkSubmitSuite with BeforeAndAfterEach with ResetSystemProperties { + override protected val enableAutoThreadAudit = false + // TODO: rewrite these or mark them as slow tests to be run sparingly override def beforeEach() { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index ce53acef51503..a5dfd89b3a574 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -67,6 +67,7 @@ class HiveClientSuite(version: String) } override def beforeAll() { + super.beforeAll() client = init(true) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index 951ebfad4590e..bb8a4697b0a13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.HiveUtils private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite { + override protected val enableAutoThreadAudit = false protected var client: HiveClient = null protected def buildClient(hadoopConf: Configuration): HiveClient = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index e64389e56b5a1..72536b833481a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -50,6 +50,8 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { + override protected val enableAutoThreadAudit = false + import HiveClientBuilder.buildClient /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index cee82cda4628a..272e6f51f5002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -48,6 +48,8 @@ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} abstract class HiveComparisonTest extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { + override protected val enableAutoThreadAudit = false + /** * Path to the test datasets. We find this by looking up "hive-test-path-helper.txt" file. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index f87162f94c01a..a1f054b8e3f44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ + override protected val enableAutoThreadAudit = false override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index df7988f542b71..d3fff37c3424d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.hive.client.HiveClient trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { + override protected val enableAutoThreadAudit = false protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = From 4371466b3f06ca171b10568e776c9446f7bae6dd Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 16 Jan 2018 12:56:57 -0800 Subject: [PATCH 661/712] [SPARK-23045][ML][SPARKR] Update RFormula to use OneHotEncoderEstimator. ## What changes were proposed in this pull request? RFormula should use VectorSizeHint & OneHotEncoderEstimator in its pipeline to avoid using the deprecated OneHotEncoder & to ensure the model produced can be used in streaming. ## How was this patch tested? Unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #20229 from MrBago/rFormula. 
--- R/pkg/R/mllib_utils.R | 1 - .../apache/spark/ml/feature/RFormula.scala | 20 +++++-- .../spark/ml/feature/RFormulaSuite.scala | 53 +++++++++++-------- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 23dda42c325be..a53c92c2c4815 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,4 +130,3 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f384ffbf578bc..1155ea5fdd85b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) val encoderStages = ArrayBuffer[PipelineStage]() + val oneHotEncodeColumns = ArrayBuffer[(String, String)]() val prefixesToRewrite = mutable.Map[String, String]() val tempColumns = ArrayBuffer[String]() @@ -242,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - var encoder = new OneHotEncoder() - .setInputCol(indexed(term)) - .setOutputCol(encodedCol) // Formula w/o intercept, one of the categories in the first category feature is // being used as reference category, we will not drop any category for that feature. if (!hasIntercept && !keepReferenceCategory) { - encoder = encoder.setDropLast(false) + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(Array(indexed(term))) + .setOutputCols(Array(encodedCol)) + .setDropLast(false) keepReferenceCategory = true + } else { + oneHotEncodeColumns += indexed(term) -> encodedCol } - encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => @@ -265,6 +267,14 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) interactionCol } + if (oneHotEncodeColumns.nonEmpty) { + val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(inputCols) + .setOutputCols(outputCols) + .setDropLast(true) + } + encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index f3f4b5a3d0233..bfe38d32dd77d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ + def testRFormulaTransform[A: Encoder]( + dataframe: DataFrame, + formulaModel: RFormulaModel, + expected: DataFrame): Unit = { + val (first +: rest) = expected.schema.fieldNames.toSeq + val expectedRows = expected.collect() + testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows === expectedRows) + } + } + test("params") { ParamsSuite.checkParams(new RFormula()) } @@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // TODO(ekl) make schema 
comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("features column already exists") { @@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("encodes string terms") { @@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("encodes string terms with string indexer order type") { @@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected(idx).collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } } @@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("formula w/o intercept, we should output reference category when encoding string terms") { @@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result1.schema.toString == resultSchema1.toString) - assert(result1.collect() === expected1.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( @@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result2.schema.toString == resultSchema2.toString) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( @@ -302,7 +313,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), @@ -310,7 +320,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") // 
assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(String, String, Int)](original, model, expected) } test("force to index label even it is numeric type") { @@ -319,7 +329,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = spark.createDataFrame( Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), @@ -327,7 +336,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Double, String, Int)](original, model, expected) } test("attribute generation") { @@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, String)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[SparkException] { formula1.fit(df1).transform(df2).collect() } - val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2) - val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2) + val model1 = formula1.setHandleInvalid("skip").fit(df1) + val model2 = formula1.setHandleInvalid("keep").fit(df1) val expected1 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0), @@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result1.collect() === expected1.collect()) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String)](df2, model1, expected1) + testRFormulaTransform[(Int, String, String)](df2, model2, expected2) // Handle unseen labels. 
val formula2 = new RFormula().setFormula("b ~ a + id") intercept[SparkException] { formula2.fit(df1).transform(df2).collect() } - val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2) - val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2) + val model3 = formula2.setHandleInvalid("skip").fit(df1) + val model4 = formula2.setHandleInvalid("keep").fit(df1) val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), @@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") - assert(result3.collect() === expected3.collect()) - assert(result4.collect() === expected4.collect()) + testRFormulaTransform[(Int, String, String)](df2, model3, expected3) + testRFormulaTransform[(Int, String, String)](df2, model4, expected4) } test("Use Vectors as inputs to formula.") { From 5ae333391bd73331b5b90af71a3de52cdbb24109 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 16 Jan 2018 16:25:10 -0800 Subject: [PATCH 662/712] [SPARK-23044] Error handling for jira assignment ## What changes were proposed in this pull request? * If there is any error while trying to assign the jira, prompt again * Filter out the "Apache Spark" choice * allow arbitrary user ids to be entered ## How was this patch tested? Couldn't really test the error case, just some testing of similar-ish code in python shell. Haven't run a merge yet. Author: Imran Rashid Closes #20236 from squito/SPARK-23044. --- dev/merge_spark_pr.py | 50 +++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 57ca8400b6f3d..6b244d8184b2c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -30,6 +30,7 @@ import re import subprocess import sys +import traceback import urllib2 try: @@ -298,24 +299,37 @@ def choose_jira_assignee(issue, asf_jira): Prompt the user to choose who to assign the issue to in jira, given a list of candidates, including the original reporter and all commentors """ - reporter = issue.fields.reporter - commentors = map(lambda x: x.author, issue.fields.comment.comments) - candidates = set(commentors) - candidates.add(reporter) - candidates = list(candidates) - print("JIRA is unassigned, choose assignee") - for idx, author in enumerate(candidates): - annotations = ["Reporter"] if author == reporter else [] - if author in commentors: - annotations.append("Commentor") - print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - assignee = raw_input("Enter number of user to assign to (blank to leave unassigned):") - if assignee == "": - return None - else: - assignee = candidates[int(assignee)] - asf_jira.assign_issue(issue.key, assignee.key) - return assignee + while True: + try: + reporter = issue.fields.reporter + commentors = map(lambda x: x.author, issue.fields.comment.comments) + candidates = set(commentors) + candidates.add(reporter) + candidates = list(candidates) + print("JIRA is unassigned, choose assignee") + for idx, author in enumerate(candidates): + if author.key == "apachespark": + continue + annotations = ["Reporter"] if author == reporter else [] + if author in commentors: + annotations.append("Commentor") + print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) + raw_assignee = raw_input( + "Enter number of user, or userid, to assign to (blank to leave unassigned):") + if raw_assignee == "": + return None + else: + 
try: + id = int(raw_assignee) + assignee = candidates[id] + except: + # assume it's a user id, and try to assign (might fail, we just prompt again) + assignee = asf_jira.user(raw_assignee) + asf_jira.assign_issue(issue.key, assignee.key) + return assignee + except: + traceback.print_exc() + print("Error assigning JIRA, try again (or leave blank and fix manually)") def resolve_jira_issues(title, merge_branches, comment): From 0c2ba427bc7323729e6ffb34f1f06a97f0bf0c1d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 17 Jan 2018 09:57:30 +0800 Subject: [PATCH 663/712] [SPARK-23095][SQL] Decorrelation of scalar subquery fails with java.util.NoSuchElementException ## What changes were proposed in this pull request? The following SQL involving a correlated scalar subquery fails with a map lookup exception. ``` SQL SELECT t1a FROM t1 WHERE t1a = (SELECT count(*) FROM t2 WHERE t2c = t1c HAVING count(*) >= 1) ``` ``` SQL key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) java.util.NoSuchElementException: key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$.org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$evalSubqueryOnZeroTups(subquery.scala:378) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:430) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:426) ``` In this case, after evaluating the HAVING clause "count(*) >= 1" statically against the binding of the aggregation result on empty input, we determine that this query will not have the count bug. We should simply return an empty value from evalSubqueryOnZeroTups. ## How was this patch tested? A new test was added in the Subquery bucket. Author: Dilip Biswal Closes #20283 from dilipbiswal/scalar-count-defect. --- .../sql/catalyst/optimizer/subquery.scala | 5 +- .../scalar-subquery-predicate.sql | 10 ++++ .../scalar-subquery-predicate.sql.out | 57 ++++++++++++------- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2673bea648d09..709db6d8bec7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -369,13 +369,14 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case ne => (ne.exprId, evalAggOnZeroTups(ne)) }.toMap - case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + case _ => + sys.error(s"Unexpected operator in scalar subquery: $lp") } val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field.
- resultMap(plan.output.head.exprId) + resultMap.getOrElse(plan.output.head.exprId, None) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index fb0d07fbdace7..1661209093fc4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -173,6 +173,16 @@ WHERE t1a = (SELECT max(t2a) HAVING count(*) >= 0) OR t1i > '2014-12-31'; +-- TC 02.03.01 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31'; + -- TC 02.04 -- t1 on the right of an outer join -- can be reduced to inner join diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index 8b29300e71f90..a2b86db3e4f4c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -293,6 +293,21 @@ val1d -- !query 19 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31' +-- !query 19 schema +struct +-- !query 19 output +val1c +val1d + +-- !query 22 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -300,13 +315,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 19 schema +-- !query 22 schema struct --- !query 19 output +-- !query 22 output 7 --- !query 20 +-- !query 23 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -317,14 +332,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 20 schema +-- !query 23 schema struct --- !query 20 output +-- !query 23 output val1b val1c --- !query 21 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -338,14 +353,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 21 schema +-- !query 24 schema struct --- !query 21 output +-- !query 24 output val1b val1c --- !query 22 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -359,9 +374,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 25 schema struct --- !query 22 output +-- !query 25 output val1a val1a val1b @@ -372,7 +387,7 @@ val1d val1d --- !query 23 +-- !query 26 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -386,16 +401,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 26 schema struct --- !query 23 output +-- !query 26 output val1a val1b val1c val1d --- !query 24 +-- !query 27 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -409,13 +424,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 24 schema +-- !query 27 schema struct --- !query 24 output +-- !query 27 output val1a --- !query 25 +-- !query 28 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -423,8 +438,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP 
BY t2c) --- !query 25 schema +-- !query 28 schema struct --- !query 25 output +-- !query 28 output val1b val1c From a9b845ebb5b51eb619cfa7d73b6153024a6a420d Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 17 Jan 2018 10:03:25 +0800 Subject: [PATCH 664/712] [SPARK-22361][SQL][TEST] Add unit test for Window Frames ## What changes were proposed in this pull request? There are already quite a few integration tests using window frames, but the unit tests coverage is not ideal. In this PR the already existing tests are reorganized, extended and where gaps found additional cases added. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20019 from gaborgsomogyi/SPARK-22361. --- .../parser/ExpressionParserSuite.scala | 57 ++- .../sql/DataFrameWindowFramesSuite.scala | 405 ++++++++++++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 243 ----------- 3 files changed, 454 insertions(+), 251 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2b9783a3295c6..cb8a1fecb80a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -249,8 +249,8 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) @@ -263,21 +263,62 @@ class ExpressionParserSuite extends PlanTest { "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", WindowExpression('sum.function('product + 1), WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + } + + test("range/rows window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } - // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = Seq( - ("10 preceding", -Literal(10), CurrentRow), + // No between combinations + ("unbounded preceding", UnboundedPreceding, CurrentRow), ("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("10 preceding", -Literal(10), CurrentRow), + ("3 + 1 preceding", -Add(Literal(3), Literal(1)), CurrentRow), + ("0 preceding", -Literal(0), CurrentRow), + ("current row", CurrentRow, CurrentRow), + ("0 following", Literal(0), CurrentRow), ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), - 
("unbounded preceding", UnboundedPreceding, CurrentRow), + ("10 following", Literal(10), CurrentRow), + ("2147483649 following", Literal(2147483649L), CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + + // Between combinations + ("between unbounded preceding and 5 following", + UnboundedPreceding, Literal(5)), + ("between unbounded preceding and 3 + 1 following", + UnboundedPreceding, Add(Literal(3), Literal(1))), + ("between unbounded preceding and 2147483649 following", + UnboundedPreceding, Literal(2147483649L)), ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), + ("between 2147483648 preceding and current row", -Literal(2147483648L), CurrentRow), ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between 3 + 1 preceding and current row", -Add(Literal(3), Literal(1)), CurrentRow), + ("between 0 preceding and current row", -Literal(0), CurrentRow), + ("between current row and current row", CurrentRow, CurrentRow), + ("between current row and 0 following", CurrentRow, Literal(0)), ("between current row and 5 following", CurrentRow, Literal(5)), - ("between 10 preceding and 5 following", -Literal(10), Literal(5)) + ("between current row and 3 + 1 following", CurrentRow, Add(Literal(3), Literal(1))), + ("between current row and 2147483649 following", CurrentRow, Literal(2147483649L)), + ("between current row and unbounded following", CurrentRow, UnboundedFollowing), + ("between 2147483648 preceding and unbounded following", + -Literal(2147483648L), UnboundedFollowing), + ("between 10 preceding and unbounded following", + -Literal(10), UnboundedFollowing), + ("between 3 + 1 preceding and unbounded following", + -Add(Literal(3), Literal(1)), UnboundedFollowing), + ("between 0 preceding and unbounded following", -Literal(0), UnboundedFollowing), + + // Between partial and full range + ("between 10 preceding and 5 following", -Literal(10), Literal(5)), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing) ) frameTypes.foreach { case (frameTypeSql, frameType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala new file mode 100644 index 0000000000000..0ee9b0edc02b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Window frame testing for DataFrame API. + */ +class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("lead/lag with empty data frame") { + val df = Seq.empty[(Int, String)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + lead("value", 1).over(window), + lag("value", 1).over(window)), + Nil) + } + + test("lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) + } + + test("reverse lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) + } + + test("lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil) + } + + test("reverse lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "1") :: Row(1, "3", null) :: Row(2, null, "2") :: Row(2, "4", null) :: Nil) + } + + test("lead/lag with default value") { + val default = "n/a" + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4"), (2, "5")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 2, default).over(window), + lag("value", 2, default).over(window), + lead("value", -2, default).over(window), + lag("value", -2, default).over(window)), + Row(1, default, default, default, default) :: Row(1, default, default, default, default) :: + Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) :: + Row(2, default, default, default, default) :: Nil) + } + + test("rows/range between with empty data frame") { + val df = Seq.empty[(String, Int)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + first("value").over( + window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + first("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Nil) + } + + test("rows between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", 
+ count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept at most one ORDER BY expression when unbounded") { + val df = Seq((1, 1)).toDF("key", "value") + val window = Window.orderBy($"key", $"value") + + checkAnswer( + df.select( + $"key", + min("key").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq(Row(1, 1)) + ) + + val e1 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e2 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e3 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + } + + test("range between should accept numeric values only when bounded") { + val df = Seq("non_numeric").toDF("value") + val window = Window.orderBy($"value") + + checkAnswer( + df.select( + $"value", + min("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("non_numeric", "non_numeric") :: Nil) + + val e1 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("The data type of the upper bound 'string' " + + "does not match the expected data type")) + + val e2 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + + val e3 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + } + + test("range between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + + def dt(date: String): Date = Date.valueOf(date) + + val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), + (dt("2017-08-03"), "2"), (dt("2017-08-02"), 
"1"), (dt("2020-12-31"), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)) + + checkAnswer( + df2.select( + $"key", + count("key").over(window)), + Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), + Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) + ) + } + + test("range between should accept double values as boundary") { + val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), (3.3D, "2"), (2.02D, "1"), + (100.001D, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(currentRow, lit(2.5D)) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) + ) + } + + test("range between should accept interval values as boundary") { + def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) + + val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), + (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + .rangeBetween(currentRow, lit(CalendarInterval.fromString("interval 23 days 4 hours"))) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), + Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) + ) + } + + test("unbounded rows/range between with aggregation") { + val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + sum("value").over(window. + rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + sum("value").over(window. 
+ rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil) + } + + test("unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(2, 3, 2) :: Row(3, 3, 3) :: Row(1, 4, 1) :: Row(2, 4, 2) :: + Row(4, 4, 4) :: Nil) + } + + test("reverse unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc) + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(3, 2, 3) :: Row(2, 2, 2) :: Row(4, 1, 4) :: Row(2, 1, 2) :: + Row(1, 1, 1) :: Nil) + } + + test("unbounded preceding/following range between with aggregation") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy("value").orderBy("key") + + checkAnswer( + df.select( + $"key", + avg("key").over(window.rangeBetween(Window.unboundedPreceding, 1)) + .as("avg_key1"), + avg("key").over(window.rangeBetween(Window.currentRow, Window.unboundedFollowing)) + .as("avg_key2")), + Row(3, 3.0d, 4.0d) :: Row(5, 4.0d, 5.0d) :: Row(2, 2.0d, 17.0d / 4.0d) :: + Row(4, 11.0d / 3.0d, 5.0d) :: Row(5, 17.0d / 4.0d, 11.0d / 2.0d) :: + Row(6, 17.0d / 4.0d, 6.0d) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. 
+ test("reverse preceding/following range between with aggregation") { + val df = Seq(1, 2, 4, 3, 2, 1).toDF("value") + val window = Window.orderBy($"value".desc) + + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Window.unboundedPreceding, 1)), + sum($"value").over(window.rangeBetween(1, Window.unboundedFollowing))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: + Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 3.0d / 2.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("reverse sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc).rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 1.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 4.0d / 3.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("sliding range between with aggregation") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 7.0d / 4.0d) :: Row(3, 5.0d / 2.0d) :: + Row(2, 2.0d) :: Row(2, 2.0d) :: Nil) + } + + test("reverse sliding range between with aggregation") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window.partitionBy($"category").orderBy($"revenue".desc). 
+ rangeBetween(-2000L, 1000L) + + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 01c988ecc3726..281147835abde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -55,56 +55,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } - test("Window.rowsBetween") { - val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") - // Running (cumulative) sum - checkAnswer( - df.select('key, sum("value").over( - Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), - Row("one", 1) :: Row("two", 3) :: Nil - ) - } - - test("lead") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) - } - - test("lag") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) - } - - test("lead with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) - } - - test("lag with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) - } - test("rank functions in unspecific window") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") df.createOrReplaceTempView("window_table") @@ -136,199 +86,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } - test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) - } - - test("aggregation and range between") { - val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - Seq(Row(4.0d / 3.0d), Row(4.0d 
/ 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), - Row(2.0d), Row(2.0d))) - } - - test("row between should accept integer values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), - Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - - val e = intercept[AnalysisException]( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) - assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) - } - - test("range between should accept int/long values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), - Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), - Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) - ) - - def dt(date: String): Date = Date.valueOf(date) - - val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), - (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) - .toDF("key", "value") - checkAnswer( - df2.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)))), - Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), - Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) - ) - } - - test("range between should accept double values as boundary") { - val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), - (3.3D, "2"), (2.02D, "1"), (100.001D, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, lit(2.5D)))), - Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) - ) - } - - test("range between should accept interval values as boundary") { - def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) - - val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), - (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, - lit(CalendarInterval.fromString("interval 23 days 4 hours"))))), - Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), - Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) - ) - } - - test("aggregation and rows between with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("key").over( - 
Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.currentRow, Window.unboundedFollowing)), - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.currentRow)), - last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), - Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), - Row(4, 4, 4, 4))) - } - - test("aggregation and range between with unbounded") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) - .equalTo("2") - .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) - .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) - .as("avg_key2"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) - .as("avg_key3") - ), - Seq(Row(3, null, 3.0d, 4.0d, 3.0d), - Row(5, false, 4.0d, 5.0d, 5.0d), - Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), - Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), - Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), - Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) - } - - test("reverse sliding range frame") { - val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window. - partitionBy($"category"). - orderBy($"revenue".desc). - rangeBetween(-2000L, 1000L) - checkAnswer( - df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) - } - - // This is here to illustrate the fact that reverse order also reverses offsets. - test("reverse unbounded range frame") { - val df = Seq(1, 2, 4, 3, 2, 1). - map(Tuple1.apply). - toDF("value") - val window = Window.orderBy($"value".desc) - checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Long.MinValue, 1)), - sum($"value").over(window.rangeBetween(1, Long.MaxValue))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: - Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) - } - test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") From 16670578519a7b787b0c63888b7d2873af12d5b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 18:11:27 -0800 Subject: [PATCH 665/712] [SPARK-22908][SS] Roll forward continuous processing Kafka support with fix to continuous Kafka data reader ## What changes were proposed in this pull request? The Kafka reader is now interruptible and can close itself. ## How was this patch tested? I locally ran one of the ContinuousKafkaSourceSuite tests in a tight loop. Before the fix, my machine ran out of open file descriptors a few iterations in; now it works fine. Author: Jose Torres Closes #20253 from jose-torres/fix-data-reader. 
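The fix described above amounts to a different read-loop shape: instead of blocking indefinitely inside the Kafka consumer, the reader polls with a short timeout, re-checks whether the task has been interrupted or completed between polls, and closes its per-task consumer when the task ends. A minimal, self-contained Scala sketch of that loop shape follows; the `poll` function and the `AtomicBoolean` cancellation flag are stand-ins for the cached Kafka consumer and Spark's `TaskContext`, not the actual reader API added in this patch.

```scala
import java.util.concurrent.atomic.AtomicBoolean

object InterruptiblePollSketch {
  // Poll with a short timeout and re-check the cancellation flag between attempts,
  // so a cancelled task exits promptly instead of blocking forever waiting for data.
  def readNext[T](poll: Long => Option[T],
                  cancelled: AtomicBoolean,
                  pollTimeoutMs: Long): Option[T] = {
    var record: Option[T] = None
    while (record.isEmpty && !cancelled.get()) {
      record = poll(pollTimeoutMs)  // short poll, then loop back to the cancellation check
    }
    record
  }

  def main(args: Array[String]): Unit = {
    val cancelled = new AtomicBoolean(false)
    val source = Iterator("a", "b", "c")
    // Prints Some(a); if `cancelled` were set and `source` empty, readNext would return None.
    println(readNext[String](_ => if (source.hasNext) Some(source.next()) else None, cancelled, 512L))
  }
}
```

In the reader below, the poll timeout comes from the `kafkaConsumer.pollTimeoutMs` option (512 ms by default), the cancellation check uses `TaskContext.get().isInterrupted()` and `isCompleted()`, and the uncached per-task consumer is closed in `close()`.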
--- .../sql/kafka010/KafkaContinuousReader.scala | 260 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 ++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 476 ++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 94 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 539 +++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 4 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1628 insertions(+), 416 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..fc977977504f7 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.TimeoutException + +import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetOutOfRangeException} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. 
+ private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. 
+ */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader( + topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false + // Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving + // interrupt points to end the query rather than waiting for new data that might never come. + try { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs, + failOnDataLoss) + } catch { + // We didn't read within the timeout. We're supposed to block indefinitely for new data, so + // swallow and ignore this. + case _: TimeoutException => + + // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, + // or if it's the endpoint of the data range (i.e. the "true" next offset). 
+ case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + val range = consumer.getAvailableOffsetRange() + if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { + // retry + } else { + throw e + } + } + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. 
If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. 
+ checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. 
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 864a92b8f813f..169a5d006fb04 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. 
*/ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. 
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. 
+ val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. 
Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, 
kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..8487a69851237 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. 
+ */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. 
Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + 
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + 
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. 
+ .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. 
+class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. + r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..5a1a14f7a307a --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + // Continuous processing tasks end asynchronously, so test that they actually end. + private val tasksEndedListener = new SparkListener() { + val activeTaskIdCount = new AtomicInteger(0) + + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + activeTaskIdCount.incrementAndGet() + } + + override def onTaskEnd(end: SparkListenerTaskEnd): Unit = { + activeTaskIdCount.decrementAndGet() + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + spark.sparkContext.addSparkListener(tasksEndedListener) + } + + override def afterEach(): Unit = { + eventually(timeout(streamingTimeout)) { + assert(tasksEndedListener.activeTaskIdCount.get() == 0) + } + spark.sparkContext.removeSparkListener(tasksEndedListener) + super.afterEach() + } + + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index a0f5695fc485c..1acff61e11d2a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ 
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. * @@ -82,10 +94,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - if (query.get.isActive) { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + query match { // Make sure no Spark job is running when deleting a topic - query.get.processAllAvailable() + case Some(m: MicroBatchExecution) => m.processAllAvailable() + case _ => } val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap @@ -97,16 +110,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. 
- val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +152,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + 
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +396,94 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") .option("subscribe", topic) + .load() - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. 
So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } - test("maxOffsetsPerTrigger") { + test("delete a topic when a Spark job is running") { + KafkaSourceSuite.collectedData.clear() + val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) .option("subscribe", topic) + // If a topic is deleted and we try to poll data starting from offset 0, + // the Kafka consumer will just block until timeout and return an empty result. + // So set the timeout to 1 second to make this test fast. + .option("kafkaConsumer.pollTimeoutMs", "1000") .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + KafkaSourceSuite.globalTestUtils = testUtils + // The following ForeachWriter will delete the topic before fetching data from Kafka + // in executors. 
+ val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + override def open(partitionId: Long, version: Long): Boolean = { + KafkaSourceSuite.globalTestUtils.deleteTopic(topic) + true } - if (q.exception.isDefined) { - throw q.exception.get + + override def process(value: Int): Unit = { + KafkaSourceSuite.collectedData.add(value) } - true - } - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + override def close(errorOrNull: Throwable): Unit = {} + }).start() + query.processAllAvailable() + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + assert(query.exception.isEmpty) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -393,7 +560,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -487,65 +654,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, 
Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -605,77 +713,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("delete a topic when a Spark job is running") { - KafkaSourceSuite.collectedData.clear() - - val topic = newTopic() - testUtils.createTopic(topic, partitions = 1) - testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribe", topic) - // If a topic is deleted and we try to poll data starting from offset 0, - // the Kafka consumer will just block until timeout and return an empty result. - // So set the timeout to 1 second to make this test fast. - .option("kafkaConsumer.pollTimeoutMs", "1000") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - KafkaSourceSuite.globalTestUtils = testUtils - // The following ForeachWriter will delete the topic before fetching data from Kafka - // in executors. 
- val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { - override def open(partitionId: Long, version: Long): Boolean = { - KafkaSourceSuite.globalTestUtils.deleteTopic(topic) - true - } - - override def process(value: Int): Unit = { - KafkaSourceSuite.collectedData.add(value) - } - - override def close(errorOrNull: Throwable): Unit = {} - }).start() - query.processAllAvailable() - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - assert(query.exception.isEmpty) - } - test("get offsets from case insensitive parameters") { for ((optionKey, optionValue, answer) <- Seq( (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit), @@ -694,8 +731,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -741,6 +776,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -771,10 +810,13 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() - query.processAllAvailable() - val rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + var rows: Array[Row] = Array() + eventually(timeout(streamingTimeout)) { + rows = spark.table("kafkaColumnTypes").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + } val row = rows(0) assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") @@ -788,47 +830,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. 
- assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -865,9 +866,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -908,9 +907,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -1042,20 +1039,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1069,6 +1054,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. 
val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. 
+ runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index b3f1a1a1aaab3..66eb42d4658f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -177,6 +176,7 @@ class DataReaderThread( private[continuous] var failureReason: Throwable = _ override def run(): Unit = { + TaskContext.setTaskContext(context) val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) try { while (!context.isInterrupted && !context.isCompleted()) { @@ -201,6 +201,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. 
Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. 
This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. 
*/ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 50345a2aa59741c511d555edbbad2da9611e7d16 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 22:14:47 -0800 Subject: [PATCH 666/712] Revert "[SPARK-23020][CORE] Fix races in launcher code, test." This reverts commit 66217dac4f8952a9923625908ad3dcb030763c81. 
--- .../spark/launcher/SparkLauncherSuite.java | 49 +++++++------------ .../spark/launcher/AbstractAppHandle.java | 22 ++------- .../spark/launcher/ChildProcAppHandle.java | 18 +++---- .../spark/launcher/InProcessAppHandle.java | 17 +++---- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 +++-------------- .../org/apache/spark/launcher/BaseSuite.java | 42 +++------------- .../spark/launcher/LauncherServerSuite.java | 20 +++++--- 8 files changed, 72 insertions(+), 156 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index a042375c6ae91..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -32,7 +31,6 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; -import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -139,9 +137,7 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { - assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); - }); + TimeUnit.MILLISECONDS.sleep(500); } } @@ -150,35 +146,26 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - synchronized (transitions) { - transitions.add(h.getState()); - } + transitions.add(h.getState()); return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = null; - try { - handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); - } finally { - if (handle != null) { - handle.kill(); - } - } + SparkAppHandle handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. 
+ List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index daf0972f824dd..df1e7316861d4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private volatile boolean disposed; + private boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,7 +70,8 @@ public void stop() { @Override public synchronized void disconnect() { - if (!isDisposed()) { + if (!disposed) { + disposed = true; if (connection != null) { try { connection.close(); @@ -78,7 +79,7 @@ public synchronized void disconnect() { // no-op. } } - dispose(); + server.unregister(this); } } @@ -94,21 +95,6 @@ boolean isDisposed() { return disposed; } - /** - * Mark the handle as disposed, and set it as LOST in case the current state is not final. - */ - synchronized void dispose() { - if (!isDisposed()) { - // Unregister first to make sure that the connection with the app has been really - // terminated. - server.unregister(this); - if (!getState().isFinal()) { - setState(State.LOST); - } - this.disposed = true; - } - } - void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 2b99461652e1f..8b3f427b7750e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,16 +48,14 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - if (!isDisposed()) { - setState(State.KILLED); - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); - } - childProc = null; + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); } + childProc = null; } + setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -96,6 +94,8 @@ void monitorChild() { return; } + disconnect(); + int ec; try { ec = proc.exitValue(); @@ -118,8 +118,6 @@ void monitorChild() { if (newState != null) { setState(newState, true); } - - disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index f04263cb74a58..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,16 +39,15 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - if (!isDisposed()) { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - setState(State.KILLED); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. 
- if (app != null) { - app.interrupt(); - } + LOG.warning("kill() may leave the underlying app running in in-process mode."); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); } + + setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index fd6f229b2349c..b4a8719e26053 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public synchronized void close() throws IOException { + public void close() throws IOException { if (!closed) { - closed = true; - socket.close(); + synchronized (this) { + if (!closed) { + closed = true; + socket.close(); + } + } } } - boolean isOpen() { - return !closed; - } - } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 660c4443b20b9..b8999a1d7a4f4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,33 +217,6 @@ void unregister(AbstractAppHandle handle) { break; } } - - // If there is a live connection for this handle, we need to wait for it to finish before - // returning, otherwise there might be a race between the connection thread processing - // buffered data and the handle cleaning up after itself, leading to potentially the wrong - // state being reported for the handle. - ServerConnection conn = null; - synchronized (clients) { - for (ServerConnection c : clients) { - if (c.handle == handle) { - conn = c; - break; - } - } - } - - if (conn != null) { - synchronized (conn) { - if (conn.isOpen()) { - try { - conn.wait(); - } catch (InterruptedException ie) { - // Ignore. 
- } - } - } - } - unref(); } @@ -315,7 +288,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - volatile AbstractAppHandle handle; + private AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -365,21 +338,16 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { - if (!isOpen()) { - return; - } - synchronized (clients) { clients.remove(this); } - - synchronized (this) { - super.close(); - notifyAll(); - } - + super.close(); if (handle != null) { - handle.dispose(); + if (!handle.getState().isFinal()) { + LOG.log(Level.WARNING, "Lost connection to spark application."); + handle.setState(SparkAppHandle.State.LOST); + } + handle.disconnect(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3722a59d9438e..3e1a90eae98d4 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -48,46 +47,19 @@ public void postChecks() { assertNull(server); } - protected void waitFor(final SparkAppHandle handle) throws Exception { + protected void waitFor(SparkAppHandle handle) throws Exception { + long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); try { - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is not in final state.", handle.getState().isFinal()); - }); + while (!handle.getState().isFinal()) { + assertTrue("Timed out waiting for handle to transition to final state.", + System.nanoTime() < deadline); + TimeUnit.MILLISECONDS.sleep(10); + } } finally { if (!handle.getState().isFinal()) { handle.kill(); } } - - // Wait until the handle has been marked as disposed, to make sure all cleanup tasks - // have been performed. - AbstractAppHandle ahandle = (AbstractAppHandle) handle; - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); - }); - } - - /** - * Call a closure that performs a check every "period" until it succeeds, or the timeout - * elapses. 
- */ - protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { - assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); - long deadline = System.nanoTime() + timeout.toNanos(); - int count = 0; - while (true) { - try { - count++; - check.run(); - return; - } catch (Throwable t) { - if (System.nanoTime() >= deadline) { - String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); - throw new IllegalStateException(msg, t); - } - Thread.sleep(period.toMillis()); - } - } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 75c1af0c71e2a..7e2b09ce25c9b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,14 +23,12 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; -import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -199,20 +197,28 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - final AtomicBoolean helloSent = new AtomicBoolean(); - eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { try { - if (!helloSent.get()) { + if (!helloSent) { client.send(new Hello(secret, "1.4.0")); - helloSent.set(true); + helloSent = true; } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } } - }); + } } private static class TestClient extends LauncherConnection { From a963980a6d2b4bef2c546aa33acf0aa501d2507b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 22:27:28 -0800 Subject: [PATCH 667/712] Fix merge between 07ae39d0ec and 1667057851 ## What changes were proposed in this pull request? The first commit added a new test, and the second refactored the class the test was in. The automatic merge put the test in the wrong place. ## How was this patch tested? - Author: Jose Torres Closes #20289 from jose-torres/fix. 
--- .../apache/spark/sql/kafka010/KafkaSourceSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 1acff61e11d2a..62f6a34a6b67a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -479,11 +479,6 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { // `failOnDataLoss` is `false`, we should not fail the query assert(query.exception.isEmpty) } -} - -class KafkaSourceSuiteBase extends KafkaSourceTest { - - import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -549,6 +544,11 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { CheckLastBatch(120 to 124: _*) ) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() From a0aedb0ded4183cc33b27e369df1cbf862779e26 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 14:32:18 +0800 Subject: [PATCH 668/712] [SPARK-23072][SQL][TEST] Add a Unicode schema test for file-based data sources ## What changes were proposed in this pull request? After [SPARK-20682](https://github.com/apache/spark/pull/19651), Apache Spark 2.3 is able to read ORC files with Unicode schema. Previously, it raises `org.apache.spark.sql.catalyst.parser.ParseException`. This PR adds a Unicode schema test for CSV/JSON/ORC/Parquet file-based data sources. Note that TEXT data source only has [a single column with a fixed name 'value'](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala#L71). ## How was this patch tested? Pass the newly added test case. Author: Dongjoon Hyun Closes #20266 from dongjoon-hyun/SPARK-23072. --- .../spark/sql/FileBasedDataSourceSuite.scala | 81 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ---- .../sql/hive/MetastoreDataSourcesSuite.scala | 14 ---- .../sql/hive/execution/SQLQuerySuite.scala | 8 -- 4 files changed, 81 insertions(+), 38 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala new file mode 100644 index 0000000000000..22fb496bc838e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + + allFileBasedDataSources.foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempPath { dir => + Seq("str").toDS().limit(0).write.format(format).save(dir.getCanonicalPath) + } + } + } + + // `TEXT` data source always has a single column whose name is `value`. + allFileBasedDataSources.filterNot(_ == "text").foreach { format => + test(s"SPARK-23072 Write and read back unicode column names - $format") { + withTempPath { path => + val dir = path.getCanonicalPath + + // scalastyle:off nonascii + val df = Seq("a").toDF("한글") + // scalastyle:on nonascii + + df.write.format(format).option("header", "true").save(dir) + val answerDf = spark.read.format(format).option("header", "true").load(dir) + + assert(df.schema.sameType(answerDf.schema)) + checkAnswer(df, answerDf) + } + } + } + + // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. + // `TEXT` data source always has a single column whose name is `value`. + Seq("orc", "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF().limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } + + allFileBasedDataSources.foreach { format => + test(s"SPARK-22146 read files containing special characters using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val fileContent = spark.read.format(format).load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 96bf65fce9c4a..7c9840a34eaa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2757,20 +2757,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } - - // Only New OrcFileFormat supports this - Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, - "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { - withTempPath { file => - val path = file.getCanonicalPath - val emptyDf = Seq((true, 1, "str")).toDF.limit(0) - emptyDf.write.format(format).save(path) - - val df = spark.read.format(format).load(path) - assert(df.schema.sameType(emptyDf.schema)) - checkAnswer(df, emptyDf) - } - } - } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c8caba83bf365..fade143a1755e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,14 +23,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.execution.command.CreateTableCommand import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf._ @@ -1344,18 +1342,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"SPARK-22146: read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" - withTempDir { dir => - val tmpFile = s"$dir/$nameWithSpecialChars" - spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) - val fileContent = spark.read.format(format).load(tmpFile) - checkAnswer(fileContent, Seq(Row("a"), Row("b"))) - } - } - } - private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 47adc77a52d51..33bcae91fdaf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2159,12 +2159,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } - - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"Writing empty datasets should not fail - $format") { - withTempDir { dir => - Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") - } - } - } } From 1f3d933e0bd2b1e934a233ed699ad39295376e71 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 17 Jan 2018 16:01:41 +0800 Subject: [PATCH 669/712] [SPARK-23062][SQL] Improve EXCEPT documentation ## What changes were proposed in this pull request? Make the default behavior of EXCEPT (i.e. EXCEPT DISTINCT) more explicit in the documentation, and call out the change in behavior from 1.x. Author: Henry Robinson Closes #20254 from henryr/spark-23062. --- R/pkg/R/DataFrame.R | 2 +- python/pyspark/sql/dataframe.py | 3 ++- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6caa125e1e14a..29f3e986eaab6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2853,7 +2853,7 @@ setMethod("intersect", #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame -#' but not in another SparkDataFrame. 
This is equivalent to \code{EXCEPT} in SQL. +#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT DISTINCT} in SQL. #' #' @param x a SparkDataFrame. #' @param y a SparkDataFrame. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 95eca76fa9888..2d5e9b91468cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1364,7 +1364,8 @@ def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. - This is equivalent to `EXCEPT` in SQL. + This is equivalent to `EXCEPT DISTINCT` in SQL. + """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 34f0ab5aa6699..912f411fa3845 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1903,7 +1903,7 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. + * This is equivalent to `EXCEPT DISTINCT` in SQL. * * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. From 0f8a28617a0742d5a99debfbae91222c2e3b5cec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 21:53:36 +0800 Subject: [PATCH 670/712] [SPARK-21783][SQL] Turn on ORC filter push-down by default ## What changes were proposed in this pull request? ORC filter push-down is disabled by default from the beginning, [SPARK-2883](https://github.com/apache/spark/commit/aa31e431fc09f0477f1c2351c6275769a31aca90#diff-41ef65b9ef5b518f77e2a03559893f4dR149 ). Now, Apache Spark starts to depend on Apache ORC 1.4.1. For Apache Spark 2.3, this PR turns on ORC filter push-down by default like Parquet ([SPARK-9207](https://issues.apache.org/jira/browse/SPARK-21783)) as a part of [SPARK-20901](https://issues.apache.org/jira/browse/SPARK-20901), "Feature parity for ORC with Parquet". ## How was this patch tested? Pass the existing tests. Author: Dongjoon Hyun Closes #20265 from dongjoon-hyun/SPARK-21783. 
--- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/FilterPushdownBenchmark.scala | 243 ++++++++++++++++++ 2 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6746fbcaf2483..16fbb0c3e9e21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -410,7 +410,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..c6dd7dadc9d93 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. 
+ */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + conf.set("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("FilterPushdownBenchmark") + .config(conf) + .getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("id", monotonically_increasing_id()) + + val dirORC = dir.getCanonicalPath + "/orc" + val dirParquet = dir.getCanonicalPath + "/parquet" + + df.write.mode("overwrite").orc(dirORC) + df.write.mode("overwrite").parquet(dirParquet) + + spark.read.orc(dirORC).createOrReplaceTempView("orcTable") + spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X + Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X + Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X + Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X + + Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X + Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X + Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X + Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X + + Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + 
----------------------------------------------------------------------------------------------- + Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X + Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X + Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X + Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X + + Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X + Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X + Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X + + Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X + Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X + Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X + Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X + + Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X + Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X + Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X + + Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X + Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X + Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X + Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X + + Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X + Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X + Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X + Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X + + Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X + Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X + Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X + Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X + + Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X + Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X + Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X + Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X + + Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X + Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X + Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X + Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 
767.9 1.2X + + Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X + Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X + Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X + Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X + */ + benchmark.run() + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + val mid = numRows / 2 + + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width) + + Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => + val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"id = $mid", + s"id <=> $mid", + s"$mid <= id AND id <= $mid", + s"${mid - 1} < id AND id < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% rows (id < ${numRows * percent / 100})", + s"id < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + } + } +} From 8598a982b4147abe5f1aae005fea0fd5ae395ac4 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 18 Jan 2018 00:05:26 +0800 Subject: [PATCH 671/712] [SPARK-23079][SQL] Fix query constraints propagation with aliases ## What changes were proposed in this pull request? Previously, PR #19201 fix the problem of non-converging constraints. After that PR #19149 improve the loop and constraints is inferred only once. So the problem of non-converging constraints is gone. However, the case below will fail. ``` spark.range(5).write.saveAsTable("t") val t = spark.read.table("t") val left = t.withColumn("xid", $"id" + lit(1)).as("x") val right = t.withColumnRenamed("id", "xid").as("y") val df = left.join(right, "xid").filter("id = 3").toDF() checkAnswer(df, Row(4, 3)) ``` Because `aliasMap` replace all the aliased child. See the test case in PR for details. This PR is to fix this bug by removing useless code for preventing non-converging constraints. It can be also fixed with #20270, but this is much simpler and clean up the code. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20278 from gengliangwang/FixConstraintSimple. 
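
To make the inference rule concrete: `inferAdditionalConstraints` turns attribute equalities into new predicates, e.g. from `a = 5` and `a = b` it derives `b = 5`. The standalone sketch below (illustrative only — `Expr`, `Attr`, `infer` are made-up names, not Catalyst classes) mimics that rule; once the alias is kept as a plain equality constraint, a filter such as `id = 3` propagates to the aliased column through this generic mechanism instead of the removed alias-elimination code.

```scala
object ConstraintInferenceSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(v: Int) extends Expr
  case class EqualTo(left: Expr, right: Expr) extends Expr
  case class GreaterThan(left: Expr, right: Expr) extends Expr

  // Rewrite every occurrence of `from` with `to` inside a constraint.
  private def replace(e: Expr, from: Attr, to: Attr): Expr = e match {
    case a: Attr if a == from => to
    case EqualTo(l, r)        => EqualTo(replace(l, from, to), replace(r, from, to))
    case GreaterThan(l, r)    => GreaterThan(replace(l, from, to), replace(r, from, to))
    case other                => other
  }

  // For every equality between two attributes, rewrite the remaining constraints
  // in both directions and return only the genuinely new ones.
  def infer(constraints: Set[Expr]): Set[Expr] = {
    var inferred = Set.empty[Expr]
    constraints.foreach {
      case eq @ EqualTo(l: Attr, r: Attr) =>
        val candidates = constraints - eq
        inferred ++= candidates.map(replace(_, l, r))
        inferred ++= candidates.map(replace(_, r, l))
      case _ => // no inference possible from this constraint
    }
    inferred -- constraints
  }

  def main(args: Array[String]): Unit = {
    // {a = 5, a = x}  ==>  {x = 5}: the filter survives even though x is an alias.
    println(infer(Set(EqualTo(Attr("a"), Lit(5)), EqualTo(Attr("a"), Attr("x")))))
  }
}
```

Keeping an alias as an explicit equality constraint in `UnaryNode` (the `EqualNullSafe` addition in the diff below) is what lets this generic rule do the propagation, rather than special-casing aliases during inference.
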
--- .../catalyst/plans/logical/LogicalPlan.scala | 1 + .../plans/logical/QueryPlanConstraints.scala | 37 +----------- .../InferFiltersFromConstraintsSuite.scala | 59 +------------------ .../plans/ConstraintPropagationSuite.scala | 2 + .../org/apache/spark/sql/SQLQuerySuite.scala | 11 ++++ 5 files changed, 17 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ff2a0ec588567..c8ccd9bd03994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -255,6 +255,7 @@ abstract class UnaryNode extends LogicalPlan { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) + allConstraints += EqualNullSafe(e, a.toAttribute) case _ => // Don't change. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 9c0a30a47f839..5c7b8e5b97883 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -94,25 +94,16 @@ trait QueryPlanConstraints { self: LogicalPlan => case _ => Seq.empty[Attribute] } - // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so - // we may avoid producing recursive constraints. - private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( - expressions.collect { - case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) - } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) - // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. - /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an * additional constraint of the form `b = 5`. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints) var inferredConstraints = Set.empty[Expression] - aliasedConstraints.foreach { + constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = aliasedConstraints - eq + val candidateConstraints = constraints - eq inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) case _ => // No inference @@ -120,30 +111,6 @@ trait QueryPlanConstraints { self: LogicalPlan => inferredConstraints -- constraints } - /** - * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints. - * Thus non-converging inference can be prevented. - * E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions. - * Also, the size of constraints is reduced without losing any information. - * When the inferred filters are pushed down the operators that generate the alias, - * the alias names used in filters are replaced by the aliased expressions. 
- */ - private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression]) - : Set[Expression] = { - val attributesInEqualTo = constraints.flatMap { - case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil - case _ => Nil - } - var aliasedConstraints = constraints - attributesInEqualTo.foreach { a => - if (aliasMap.contains(a)) { - val child = aliasMap.get(a).get - aliasedConstraints = replaceConstraints(aliasedConstraints, child, a) - } - } - aliasedConstraints - } - private def replaceConstraints( constraints: Set[Expression], source: Expression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index a0708bf7eee9a..178c4b8c270a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -34,6 +34,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { PushDownPredicate, InferFiltersFromConstraints, CombineFilters, + SimplifyBinaryComparison, BooleanSimplification) :: Nil } @@ -160,64 +161,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join with alias: don't generate constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `Coalese(a, b)` from recursively creating complicated constraints through - // the constraint inference procedure. - val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - // We hide an `Alias` inside the child's child's expressions, to cover the situation reported - // in [SPARK-20700]. - .select('int_col, 'd, 'a).as("t") - .join(t2, Inner, - Some("t.a".attr === "t2.a".attr - && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a))) - && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b))) - && 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b)) - && 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b)) - && 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b))) - .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - .select('int_col, 'd, 'a).as("t") - .join( - t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && - 'a === Coalesce(Seq('a, 'a))), - Inner, - Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - - test("inner join with EqualTo expressions containing part of each other: don't generate " + - "constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating - // complicated constraints through the constraint inference procedure. 
- val originalQuery = t1 - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .where('a === 'd && 'c === 'e) - .join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) && - 'c === Coalesce(Seq('a, 'b))) - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .join(t2.where(IsNotNull('a) && IsNotNull('c)), - Inner, - Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - test("generate correct filters for alias that don't produce recursive constraints") { val t1 = testRelation.subquery('t1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 866ff0d33cbb2..a37e06d922642 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -134,6 +134,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7c9840a34eaa3..d4d0aa4f5f5eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2717,6 +2717,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23079: constraints should be inferred correctly with aliases") { + withTable("t") { + spark.range(5).write.saveAsTable("t") + val t = spark.read.table("t") + val left = t.withColumn("xid", $"id" + lit(1)).as("x") + val right = t.withColumnRenamed("id", "xid").as("y") + val df = left.join(right, "xid").filter("id = 3").toDF() + checkAnswer(df, Row(4, 3)) + } + } + test("SRARK-22266: the same aggregate function was calculated multiple times") { val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" val df = sql(query) From c132538a164cd8b55dbd7e8ffdc0c0782a0b588c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 17 Jan 2018 09:27:49 -0800 Subject: [PATCH 672/712] [SPARK-23020] Ignore Flaky Test: SparkLauncherSuite.testInProcessLauncher ## What changes were proposed in this pull request? Temporarily ignoring flaky test `SparkLauncherSuite.testInProcessLauncher` to de-flake the builds. This should be re-enabled when SPARK-23020 is merged. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20291 from sameeragarwal/disable-test-2. 
--- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..dffa609f1cbdf 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -25,6 +25,7 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; +import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static org.junit.Assume.*; @@ -120,7 +121,8 @@ public void testChildProcLauncher() throws Exception { assertEquals(0, app.waitFor()); } - @Test + // TODO: [SPARK-23020] Re-enable this + @Ignore public void testInProcessLauncher() throws Exception { // Because this test runs SparkLauncher in process and in client mode, it pollutes the system // properties, and that can cause test failures down the test pipeline. So restore the original From 86a845031824a5334db6a5299c6f5dcc982bc5b8 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:52:51 -0800 Subject: [PATCH 673/712] [SPARK-23033][SS] Don't use task level retry for continuous processing ## What changes were proposed in this pull request? Continuous processing tasks will fail on any attempt number greater than 0. ContinuousExecution will catch these failures and restart globally from the last recorded checkpoints. ## How was this patch tested? unit test Author: Jose Torres Closes #20225 from jose-torres/no-retry. --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 8 +-- .../ContinuousDataSourceRDDIter.scala | 5 ++ .../continuous/ContinuousExecution.scala | 2 +- .../ContinuousTaskRetryException.scala | 26 +++++++ .../spark/sql/streaming/StreamTest.scala | 9 ++- .../continuous/ContinuousSuite.scala | 71 +++++++++++-------- 6 files changed, 84 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 62f6a34a6b67a..27dbb3f7a8f31 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -808,16 +808,14 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { val query = kafka .writeStream .format("memory") - .outputMode("append") .queryName("kafkaColumnTypes") .trigger(defaultTrigger) .start() - var rows: Array[Row] = Array() eventually(timeout(streamingTimeout)) { - rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + assert(spark.table("kafkaColumnTypes").count == 1, + s"Unexpected results: ${spark.table("kafkaColumnTypes").collectAsList()}") } - val row = rows(0) + val row = spark.table("kafkaColumnTypes").head() assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 66eb42d4658f6..dcb3b54c4e160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,6 +52,11 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..45b794c70a50a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -24,7 +24,7 @@ import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala new file mode 100644 index 0000000000000..e0a6f6dd50bb3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.SparkException + +/** + * An exception thrown when a continuous processing task runs with a nonzero attempt ID. 
+ */ +class ContinuousTaskRetryException + extends SparkException("Continuous execution does not support task retry", null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..c75247e0f6ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -472,8 +472,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitInitialization(streamingTimeout.toMillis) currentStream match { case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } + s.lastExecution.executedPlan // will fail if lastExecution is null + } case _ => } } catch { @@ -645,7 +645,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9562c10feafe9..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,36 +17,18 @@ package org.apache.spark.sql.streaming.continuous -import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} -import java.nio.channels.ClosedByInterruptException -import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} +import java.util.UUID -import scala.reflect.ClassTag -import scala.util.control.ControlThrowable - -import com.google.common.util.concurrent.UncheckedExecutionException -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.{SparkContext, SparkEnv} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes -import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.streaming.util.StreamManualClock import 
org.apache.spark.sql.test.TestSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class ContinuousSuiteBase extends StreamTest { // We need more than the default local[2] to be able to schedule all partitions simultaneously. @@ -219,6 +201,41 @@ class ContinuousSuite extends ContinuousSuiteBase { StopStream) } + test("task failure kills the query") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + // Get an arbitrary task from this query to kill. It doesn't matter which one. + var taskId: Long = -1 + val listener = new SparkListener() { + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + taskId = start.taskInfo.taskId + } + } + spark.sparkContext.addSparkListener(listener) + try { + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(100)), + Execute(waitForRateSourceTriggers(_, 2)), + Execute { _ => + // Wait until a task is started, then kill its first attempt. + eventually(timeout(streamingTimeout)) { + assert(taskId != -1) + } + spark.sparkContext.killTaskAttempt(taskId) + }, + ExpectFailure[SparkException] { e => + e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException] + }) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + test("query without test harness") { val df = spark.readStream .format("rate") @@ -258,13 +275,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), - Execute { query => - val data = query.sink.asInstanceOf[MemorySinkV2].allData - val vals = data.map(_.getLong(0)).toSet - assert(scala.Range(0, 25000).forall { i => - vals.contains(i) - }) - }) + StopStream, + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))) + ) } test("automatic epoch advancement") { @@ -280,6 +293,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } @@ -311,6 +325,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { StopStream, StartStream(Trigger.Continuous(2012)), AwaitEpoch(50), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } From e946c63dd56d121cf898084ed7e9b5b0868b226e Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:58:44 -0800 Subject: [PATCH 674/712] [SPARK-23093][SS] Don't change run id when reconfiguring a continuous processing query. ## What changes were proposed in this pull request? Keep the run ID static, using a different ID for the epoch coordinator to avoid cross-execution message contamination. ## How was this patch tested? new and existing unit tests Author: Jose Torres Closes #20282 from jose-torres/fix-runid. 
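
The mechanism is small: the epoch coordinator's RPC endpoint is no longer keyed by `runId` alone but by a fresh id derived from it, so a reconfigured execution registers a brand-new endpoint while listeners keep observing the same `runId`. A plain-Scala sketch of that naming scheme (adapted from the diff below; not a Spark component by itself):

```scala
import java.util.UUID

// The runId stays fixed for the whole query run; every (re)configuration derives a fresh
// coordinator id from it, so messages addressed to an old epoch coordinator can never
// reach the endpoint registered by the new one.
object EpochCoordinatorIdSketch {
  def endpointName(id: String): String = s"EpochCoordinator-$id"

  def main(args: Array[String]): Unit = {
    val runId = UUID.randomUUID()                         // constant, visible to listeners
    val firstEpochCoordinatorId = s"$runId--${UUID.randomUUID()}"
    val afterReconfigureId = s"$runId--${UUID.randomUUID()}"
    println(endpointName(firstEpochCoordinatorId))
    println(endpointName(afterReconfigureId))             // distinct endpoint, same runId
  }
}
```
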
--- .../datasources/v2/DataSourceV2ScanExec.scala | 3 ++- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++-- .../execution/streaming/StreamExecution.scala | 3 +-- .../ContinuousDataSourceRDDIter.scala | 10 ++++---- .../continuous/ContinuousExecution.scala | 18 ++++++++----- .../continuous/EpochCoordinator.scala | 9 ++++--- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 25 +++++++++++++++++++ 8 files changed, 54 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 8c64df080242f..beb66738732be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -58,7 +58,8 @@ case class DataSourceV2ScanExec( case _: ContinuousReader => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetReaderPartitions(readTasks.size())) new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) .asInstanceOf[RDD[InternalRow]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..3dbdae7b4df9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -64,7 +64,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) val runTask = writer match { case w: ContinuousWriter => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) (context: TaskContext, iter: Iterator[InternalRow]) => @@ -135,7 +136,7 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow]): WriterCommitMessage = { val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) val currentMsg: WriterCommitMessage = null var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..e7982d7880ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of 
printing in logs. Format is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index dcb3b54c4e160..cd7065f5e6601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -59,7 +59,7 @@ class ContinuousDataSourceRDD( val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() - val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) + val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) // This queue contains two types of messages: // * (null, null) representing an epoch boundary. @@ -68,7 +68,7 @@ class ContinuousDataSourceRDD( val epochPollFailed = new AtomicBoolean(false) val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--${runId}--${context.partitionId()}") + s"epoch-poll--$coordinatorId--${context.partitionId()}") val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) epochPollExecutor.scheduleWithFixedDelay( epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) @@ -86,7 +86,7 @@ class ContinuousDataSourceRDD( epochPollExecutor.shutdown() }) - val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) new Iterator[UnsafeRow] { private val POLL_TIMEOUT_MS = 1000 @@ -150,7 +150,7 @@ class EpochPollRunnable( private[continuous] var failureReason: Throwable = _ private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get) + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong override def run(): Unit = { @@ -177,7 +177,7 @@ class DataReaderThread( failedFlag: AtomicBoolean) extends Thread( s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") { + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { private[continuous] var failureReason: Throwable = _ override def run(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 45b794c70a50a..c0507224f9be8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -57,6 +57,9 @@ class ContinuousExecution( @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources + // For use only in test harnesses. 
+ private[sql] var currentEpochCoordinatorId: String = _ + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in StreamExecutionThread " + @@ -149,7 +152,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -219,15 +221,19 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - sparkSession.sparkContext.setLocalProperty( + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) - sparkSession.sparkContext.setLocalProperty( - ContinuousExecution.RUN_ID_KEY, runId.toString) + // Add another random ID on top of the run ID, to distinguish epoch coordinators across + // reconfigurations. + val epochCoordinatorId = s"$runId--${UUID.randomUUID}" + currentEpochCoordinatorId = epochCoordinatorId + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get) + writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { @@ -359,5 +365,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" - val RUN_ID_KEY = "__run_id" + val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..90b3584aa0436 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -79,7 +79,7 @@ private[sql] case class ReportPartitionOffset( /** Helper object used to create reference to [[EpochCoordinator]]. */ private[sql] object EpochCoordinatorRef extends Logging { - private def endpointName(runId: String) = s"EpochCoordinator-$runId" + private def endpointName(id: String) = s"EpochCoordinator-$id" /** * Create a reference to a new [[EpochCoordinator]]. 
@@ -88,18 +88,19 @@ private[sql] object EpochCoordinatorRef extends Logging { writer: ContinuousWriter, reader: ContinuousReader, query: ContinuousExecution, + epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( writer, reader, query, startEpoch, session, env.rpcEnv) - val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator) + val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref } - def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized { - val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv) + def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv) logDebug("Retrieved existing EpochCoordinator endpoint") rpcEndpointRef } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index c75247e0f6ed8..efdb0e0e7cf1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -263,7 +263,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def apply(): AssertOnQuery = Execute { case s: ContinuousExecution => - val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get) + val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get) .askSync[Long](IncrementAndGetEpoch) s.awaitEpoch(newEpoch - 1) case _ => throw new IllegalStateException("microbatch cannot increment epoch") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 9ff02dee288fb..79d65192a14aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -174,6 +174,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("continuous processing listeners should receive QueryTerminatedEvent") { + val df = spark.readStream.format("rate").load() + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append, useV2Sink = true)( + StartStream(Trigger.Continuous(1000)), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + } + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset() From 4e6f8fb150ae09c7d1de6beecb2b98e5afa5da19 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 18 Jan 2018 
07:26:43 +0900 Subject: [PATCH 675/712] [SPARK-23047][PYTHON][SQL] Change MapVector to NullableMapVector in ArrowColumnVector ## What changes were proposed in this pull request? This PR changes usage of `MapVector` in Spark codebase to use `NullableMapVector`. `MapVector` is an internal Arrow class that is not supposed to be used directly. We should use `NullableMapVector` instead. ## How was this patch tested? Existing test. Author: Li Jin Closes #20239 from icexelloss/arrow-map-vector. --- .../sql/vectorized/ArrowColumnVector.java | 13 +++++-- .../vectorized/ArrowColumnVectorSuite.scala | 36 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 708333213f3f1..eb69001fe677e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -247,8 +247,8 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; + } else if (vector instanceof NullableMapVector) { + NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); childColumns = new ArrowColumnVector[mapVector.size()]; @@ -553,9 +553,16 @@ final int getArrayOffset(int rowId) { } } + /** + * Any call to "get" method will throw UnsupportedOperationException. + * + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined + * in the parent class. Any call to "get" method in this class is a bug in the code. 
+ * + */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(MapVector vector) { + StructAccessor(NullableMapVector vector) { super(vector); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 7304803a092c0..53432669e215d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -322,6 +322,42 @@ class ArrowColumnVectorSuite extends SparkFunSuite { allocator.close() } + test("non nullable struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null) + .createVector(allocator).asInstanceOf[NullableMapVector] + + vector.allocateNew() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] + + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) + + vector.setValueCount(2) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(columnVector.numNulls === 0) + + val row0 = columnVector.getStruct(0, 2) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1, 2) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + columnVector.close() + allocator.close() + } + test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) From 45ad97df87c89cb94ce9564e5773897b6d9326f5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 07:30:54 +0900 Subject: [PATCH 676/712] [SPARK-23132][PYTHON][ML] Run doctests in ml.image when testing ## What changes were proposed in this pull request? This PR proposes to actually run the doctests in `ml/image.py`. ## How was this patch tested? doctests in `python/pyspark/ml/image.py`. Author: hyukjinkwon Closes #20294 from HyukjinKwon/trigger-image. --- python/pyspark/ml/image.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index c9b840276f675..2d86c7f03860c 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -194,9 +194,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, :return: a :class:`DataFrame` with a single column of "images", see ImageSchema for details. - >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True) >>> df.count() - 4 + 5 .. 
versionadded:: 2.3.0 """ @@ -216,3 +216,25 @@ def readImages(self, path, recursive=False, numPartitions=-1, def _disallow_instance(_): raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") _ImageSchema.__init__ = _disallow_instance + + +def _test(): + import doctest + import pyspark.ml.image + globs = pyspark.ml.image.__dict__.copy() + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.image tests")\ + .getOrCreate() + globs['spark'] = spark + + (failure_count, test_count) = doctest.testmod( + pyspark.ml.image, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 7823d43ec0e9c4b8284bb4529b0e624c43bc9bb7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 17 Jan 2018 17:16:57 -0600 Subject: [PATCH 677/712] [MINOR] Fix typos in ML scaladocs ## What changes were proposed in this pull request? Fixed some typos found in ML scaladocs ## How was this patch tested? NA Author: Bryan Cutler Closes #20300 from BryanCutler/ml-doc-typos-MINOR. --- .../src/main/scala/org/apache/spark/ml/stat/Summarizer.scala | 2 +- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 9bed74a9f2c05..d40827edb6d64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -75,7 +75,7 @@ sealed abstract class SummaryBuilder { * val Row(meanVec) = meanDF.first() * }}} * - * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD * interface. */ @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8826ef3271bc1..88ff0dfd75e96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -93,7 +93,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setSeed(value: Long): this.type = set(seed, value) /** - * Set the mamixum level of parallelism to evaluate models in parallel. + * Set the maximum level of parallelism to evaluate models in parallel. * Default is 1 for serial evaluation * * @group expertSetParam @@ -112,7 +112,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St * for more information. * * @group expertSetParam - */@Since("2.3.0") + */ + @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") From bac0d661af6092dd26638223156827aceb901229 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:40:02 -0800 Subject: [PATCH 678/712] [SPARK-23119][SS] Minor fixes to V2 streaming APIs ## What changes were proposed in this pull request? - Added `InterfaceStability.Evolving` annotations - Improved docs. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20286 from tdas/SPARK-23119. 
--- .../v2/streaming/ContinuousReadSupport.java | 2 ++ .../streaming/reader/ContinuousDataReader.java | 2 ++ .../v2/streaming/reader/ContinuousReader.java | 9 +++++++-- .../v2/streaming/reader/MicroBatchReader.java | 5 +++++ .../sources/v2/streaming/reader/Offset.java | 18 +++++++++++++----- .../v2/streaming/reader/PartitionOffset.java | 3 +++ .../sources/v2/writer/DataSourceV2Writer.java | 5 ++++- 7 files changed, 36 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 3136cee1f655f..9a93a806b0efc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -19,6 +19,7 @@ import java.util.Optional; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; @@ -28,6 +29,7 @@ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability for continuous stream processing. */ +@InterfaceStability.Evolving public interface ContinuousReadSupport extends DataSourceV2 { /** * Creates a {@link ContinuousReader} to scan the data from this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java index ca9a290e97a02..3f13a4dbf5793 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java @@ -17,11 +17,13 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. */ +@InterfaceStability.Evolving public interface ContinuousDataReader extends DataReader { /** * Get the offset of the current record, or the start offset if no records have been read. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index f0b205869ed6c..745f1ce502443 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; @@ -27,11 +28,15 @@ * interface to allow reading in a continuous processing mode stream. * * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. 
*/ +@InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { /** - * Merge offsets coming from {@link ContinuousDataReader} instances in each partition to - * a single global offset. + * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 70ff756806032..02f37cebc7484 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; @@ -25,7 +26,11 @@ /** * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ +@InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { /** * Set the desired offset range for read tasks created from this reader. Read tasks will diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index 60b87f2ac0756..abba3e7188b13 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -17,12 +17,20 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; + /** - * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. - * During execution, Offsets provided by the data source implementation will be logged and used as - * restart checkpoints. Sources should provide an Offset implementation which they can use to - * reconstruct the stream position where the offset was taken. + * An abstract representation of progress through a {@link MicroBatchReader} or + * {@link ContinuousReader}. + * During execution, offsets provided by the data source implementation will be logged and used as + * restart checkpoints. Each source should provide an offset implementation which the source can use + * to reconstruct a position in the stream up to which data has been seen/processed. + * + * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to + * maintain compatibility with DataSource V1 APIs. This extension will be removed once we + * get rid of V1 completely. 
*/ +@InterfaceStability.Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is @@ -37,7 +45,7 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of /** * Equality based on JSON string representation. We leverage the * JSON representation for normalization between the Offset's - * in memory and on disk representations. + * in deserialized and serialized representations. */ @Override public boolean equals(Object obj) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index eca0085c8a8ce..4688b85f49f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -19,11 +19,14 @@ import java.io.Serializable; +import org.apache.spark.annotation.InterfaceStability; + /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will * provide a method to merge these into a global Offset. * * These offsets must be serializable. */ +@InterfaceStability.Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index fc37b9a516f82..317ac45bcfd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -22,11 +22,14 @@ import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * From 1002bd6b23ff78a010ca259ea76988ef4c478c6e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:41:43 -0800 Subject: [PATCH 679/712] [SPARK-23064][DOCS][SS] Added documentation for stream-stream joins ## What changes were proposed in this pull request? Added documentation for stream-stream joins ![image](https://user-images.githubusercontent.com/663212/35018744-e999895a-fad7-11e7-9d6a-8c7a73e6eb9c.png) ![image](https://user-images.githubusercontent.com/663212/35018775-157eb464-fad8-11e7-879e-47a2fcbd8690.png) ![image](https://user-images.githubusercontent.com/663212/35018784-27791a24-fad8-11e7-98f4-7ff246f62a74.png) ![image](https://user-images.githubusercontent.com/663212/35018791-36a80334-fad8-11e7-9791-f85efa7c6ba2.png) ## How was this patch tested? 
N/a Author: Tathagata Das Closes #20255 from tdas/join-docs. --- .../structured-streaming-programming-guide.md | 338 +++++++++++++++++- 1 file changed, 326 insertions(+), 12 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index de13e281916db..1779a4215e085 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1051,7 +1051,19 @@ output mode. ### Join Operations -Streaming DataFrames can be joined with static DataFrames to create new streaming DataFrames. Here are a few examples. +Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame +as well as another streaming Dataset/DataFrame. The result of the streaming join is generated +incrementally, similar to the results of streaming aggregations in the previous section. In this +section we will explore what types of joins (i.e. inner, outer, etc.) are supported in the above +cases. Note that in all the supported join types, the result of the join with a streaming +Dataset/DataFrame will be exactly the same as if it was with a static Dataset/DataFrame +containing the same data in the stream. + + +#### Stream-static joins + +Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some +types of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example.
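The simple example itself falls outside the hunk shown below; as a rough Scala sketch (the paths, the schema source, and the join column "type" are placeholders, not part of the original patch), a stream-static join looks like this:

{% highlight scala %}
// Static DataFrame, read once; streaming DataFrame, read incrementally as new files arrive.
val staticDf = spark.read.parquet("/data/static-lookup")                       // placeholder path
val streamingDf = spark.readStream.schema(staticDf.schema).parquet("/data/in") // placeholder path

// Inner equi-join of the stream with the static table on a common column.
val joined = streamingDf.join(staticDf, "type")

// A left outer join (stream on the left) is also supported; see the support matrix below.
val leftOuter = streamingDf.join(staticDf, "type", "left_outer")
{% endhighlight %}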
    @@ -1089,6 +1101,300 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat
    +Note that stream-static joins are not stateful, so no state management is necessary. +However, a few types of stream-static outer joins are not yet supported. +These are listed at the [end of this Join section](#support-matrix-for-joins-in-streaming-queries). + +#### Stream-stream Joins +In Spark 2.3, we have added support for stream-stream joins, that is, you can join two streaming +Datasets/DataFrames. The challenge of generating join results between two data streams is that, +at any point of time, the view of the dataset is incomplete for both sides of the join, making +it much harder to find matches between inputs. Any row received from one input stream can match +with any future, yet-to-be-received row from the other input stream. Hence, for both the input +streams, we buffer past input as streaming state, so that we can match every future input with +past input and accordingly generate joined results. Furthermore, similar to streaming aggregations, +we automatically handle late, out-of-order data and can limit the state using watermarks. +Let’s discuss the different types of supported stream-stream joins and how to use them. + +##### Inner Joins with optional Watermarking +Inner joins on any kind of columns along with any kind of join conditions are supported. +However, as the stream runs, the size of streaming state will keep growing indefinitely as +*all* past input must be saved, since any new input can match with any input from the past. +To avoid unbounded state, you have to define additional join conditions such that indefinitely +old inputs cannot match with future inputs and therefore can be cleared from the state. +In other words, you will have to do the following additional steps in the join. + +1. Define watermark delays on both inputs such that the engine knows how delayed the input can be +(similar to streaming aggregations) + +1. Define a constraint on event-time across the two inputs such that the engine can figure out when +old rows of one input are not going to be required (i.e. will not satisfy the time constraint) for +matches with the other input. This constraint can be defined in one of two ways. + + 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEEN rightTime AND rightTime + INTERVAL 1 HOUR`), + + 1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`). + +Let’s understand this with an example. + +Let’s say we want to join a stream of advertisement impressions (when an ad was shown) with +another stream of user clicks on advertisements to correlate when impressions led to +monetizable clicks. To allow the state cleanup in this stream-stream join, you will have to +specify the watermarking delays and the time constraints as follows. + +1. Watermark delays: Say, the impressions and the corresponding clicks can be late/out-of-order +in event-time by at most 2 and 3 hours, respectively. + +1. Event-time range condition: Say, a click can occur within a time range of 0 seconds to 1 hour +after the corresponding impression. + +The code would look like this. + +
    +
    + +{% highlight scala %} +import org.apache.spark.sql.functions.expr + +val impressions = spark.readStream. ... +val clicks = spark.readStream. ... + +// Apply watermarks on event-time columns +val impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +val clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +import static org.apache.spark.sql.functions.expr + +Dataset impressions = spark.readStream(). ... +Dataset clicks = spark.readStream(). ... + +// Apply watermarks on event-time columns +Dataset impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours"); +Dataset clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours"); + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour ") +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +from pyspark.sql.functions import expr + +impressions = spark.readStream. ... +clicks = spark.readStream. ... + +# Apply watermarks on event-time columns +impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +# Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
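As a hedged follow-up, not part of the original guide text: the joined stream is started like any other streaming query, and as noted later in this section only Append output mode is supported for joins as of Spark 2.3. The sink format, output path and checkpoint location below are placeholders; `impressionsWithWatermark`, `clicksWithWatermark` and `expr` come from the example above.

{% highlight scala %}
val query = impressionsWithWatermark.join(
    clicksWithWatermark,
    expr("""
      clickAdId = impressionAdId AND
      clickTime >= impressionTime AND
      clickTime <= impressionTime + interval 1 hour
      """))
  .writeStream
  .outputMode("append")                              // joins support only Append mode as of 2.3
  .format("parquet")                                 // placeholder sink
  .option("path", "/data/joined")                    // placeholder output path
  .option("checkpointLocation", "/data/checkpoints") // placeholder checkpoint location
  .start()
{% endhighlight %}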
    + +##### Outer Joins with Watermarking +While the watermark + event-time constraints are optional for inner joins, for left and right outer +joins they must be specified. This is because, to generate the NULL results in an outer join, the +engine must know when an input row is not going to match with anything in the future. Hence, the +watermark + event-time constraints must be specified for generating correct results. Therefore, +a query with an outer join will look quite like the ad-monetization example earlier, except that +there will be an additional parameter specifying it to be an outer join. + +
    +
    + +{% highlight scala %} + +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + joinType = "leftOuter" // can be "inner", "leftOuter", "rightOuter" + ) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour "), + "leftOuter" // can be "inner", "leftOuter", "rightOuter" +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + "leftOuter" # can be "inner", "leftOuter", "rightOuter" +) + +{% endhighlight %} + +
    +
    + +However, note that the outer NULL results will be generated with a delay (depending on the specified +watermark delay and the time range condition) because the engine has to wait that long to ensure +there were no matches and there will be no more matches in the future. + +##### Support matrix for joins in streaming queries + +
    spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + Sets the compression codec used when writing Parquet files. If either `compression` or + `parquet.compression` is specified in the table-specific options/properties, the precedence would be + `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + none, uncompressed, snappy, gzip, lzo.
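For illustration, a short Scala sketch of the precedence described above (`df` and the output path are placeholders, not part of the original docs): the write-time `compression` option overrides both the `parquet.compression` table property and the session-level setting.

{% highlight scala %}
// Session-wide default codec for Parquet output.
spark.conf.set("spark.sql.parquet.compression.codec", "snappy")

// The per-write option takes precedence over parquet.compression and the session setting.
df.write
  .option("compression", "gzip")
  .parquet("/tmp/events_parquet")   // placeholder output path
{% endhighlight %}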
    diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 740f12e7d13d4..bf59152c8c0cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -201,7 +201,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP val stages = jobData.stageIds.map { stageId => // This could be empty if the listener hasn't received information about the // stage or if the stage information has been garbage collected - store.stageData(stageId).lastOption.getOrElse { + store.asOption(store.lastStageAttempt(stageId)).getOrElse { new v1.StageData( v1.StageStatus.PENDING, stageId, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 11a6a34344976..7c6e06cf183ba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder import java.util.Date +import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} @@ -29,15 +30,14 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality -import org.apache.spark.status.AppStatusStore +import org.apache.spark.status._ import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ -import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.Utils /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { import ApiHelper._ - import StagePage._ private val TIMELINE_LEGEND = {
    @@ -67,17 +67,17 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = { - val localities = taskList.map(_.taskLocality) - val localityCounts = localities.groupBy(identity).mapValues(_.size) + private def getLocalitySummaryString(localitySummary: Map[String, Long]): String = { val names = Map( TaskLocality.PROCESS_LOCAL.toString() -> "Process local", TaskLocality.NODE_LOCAL.toString() -> "Node local", TaskLocality.RACK_LOCAL.toString() -> "Rack local", TaskLocality.ANY.toString() -> "Any") - val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => - s"${names(locality)}: $count" - } + val localityNamesAndCounts = names.flatMap { case (key, name) => + localitySummary.get(key).map { count => + s"$name: $count" + } + }.toSeq localityNamesAndCounts.sorted.mkString("; ") } @@ -108,7 +108,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" val stageData = parent.store - .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true)) + .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = false)) .getOrElse { val content =
    @@ -117,8 +117,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } - val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq - if (tasks.isEmpty) { + val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) + + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks + if (totalTasks == 0) { val content =

    Summary Metrics

    No tasks have started yet @@ -127,18 +130,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } + val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) val numCompleted = stageData.numCompleteTasks - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks - val totalTasksNumStr = if (totalTasks == tasks.size) { + val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${tasks.size}" + s"$totalTasks, showing ${storedTasks}" } - val externalAccumulables = stageData.accumulatorUpdates - val hasAccumulators = externalAccumulables.size > 0 - val summary =
      @@ -148,7 +147,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    • Locality Level Summary: - {getLocalitySummaryString(stageData, tasks)} + {getLocalitySummaryString(localitySummary)}
    • {if (hasInput(stageData)) {
    • @@ -266,7 +265,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - externalAccumulables.toSeq) + stageData.accumulatorUpdates.toSeq) val page: Int = { // If the user has changed to a larger page size, then go to page 1 in order to avoid @@ -280,16 +279,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( - parent.conf, + stageData, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", - tasks, - hasAccumulators, - hasInput(stageData), - hasOutput(stageData), - hasShuffleRead(stageData), - hasShuffleWrite(stageData), - hasBytesSpilled(stageData), currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -320,217 +312,155 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We | } |}); """.stripMargin - } + } } - val taskIdsInPage = if (taskTable == null) Set.empty[Long] - else taskTable.dataSource.slicedTaskIds + val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, + Array(0, 0.25, 0.5, 0.75, 1.0)) - // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined) - - val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { - None - } else { - def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = { - Distribution(data).get.getQuantiles() - } - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { - getDistributionQuantiles(times).map { millis => -
    {UIUtils.formatDuration(millis.toLong)}{Utils.bytesToString(d.toLong)}{UIUtils.formatDuration(millis.toLong)} - - Task Deserialization Time - - Duration - GC Time - - - - Result Serialization Time - - - - Getting Result Time - - - - Peak Execution Memory - - {Utils.bytesToString(size.toLong)}Scheduler Delay{s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"}{s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"}Input Size / Records + + {title} + + Output Size / Records{title} - - Shuffle Read Blocked Time - - - - Shuffle Read Size / Records - - - - Shuffle Remote Reads - - Shuffle Write Size / RecordsShuffle spill (memory)Shuffle spill (disk)
    {task.index} {task.taskId} {UIUtils.formatDate(new Date(task.launchTime))}{task.formatDuration}{UIUtils.formatDate(task.launchTime)}{formatDuration(task.duration)} - {UIUtils.formatDuration(task.schedulerDelay)} + {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} - {UIUtils.formatDuration(task.taskDeserializationTime)} + {formatDuration(task.taskMetrics.map(_.executorDeserializeTime))} - {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + {formatDuration(task.taskMetrics.map(_.jvmGcTime), hideZero = true)} - {UIUtils.formatDuration(task.serializationTime)} + {formatDuration(task.taskMetrics.map(_.resultSerializationTime))} - {UIUtils.formatDuration(task.gettingResultTime)} + {UIUtils.formatDuration(AppStatusUtils.gettingResultTime(task))} - {Utils.bytesToString(task.peakExecutionMemoryUsed)} + {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} {Unparsed(task.accumulators.get)}{task.input.get.inputReadable}{bytesRead} / {records}{task.output.get.outputReadable}{bytesWritten} / {records} - {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + {formatDuration(task.taskMetrics.map(_.shuffleReadMetrics.fetchWaitTime))} {task.shuffleRead.get.shuffleReadReadable}{ + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(totalBytesRead(m.shuffleReadMetrics)) + val records = m.shuffleReadMetrics.recordsRead + Unparsed(s"$bytesRead / $records") + } + } - {task.shuffleRead.get.shuffleReadRemoteReadable} + {formatBytes(task.taskMetrics.map(_.shuffleReadMetrics.remoteBytesRead))} {task.shuffleWrite.get.writeTimeReadable}{task.shuffleWrite.get.shuffleWriteReadable}{ + formatDuration( + task.taskMetrics.map { m => + TimeUnit.NANOSECONDS.toMillis(m.shuffleWriteMetrics.writeTime) + }, + hideZero = true) + }{ + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.shuffleWriteMetrics.bytesWritten) + val records = m.shuffleWriteMetrics.recordsWritten + Unparsed(s"$bytesWritten / $records") + } + }{task.bytesSpilled.get.memoryBytesSpilledReadable}{task.bytesSpilled.get.diskBytesSpilledReadable}{formatBytes(task.taskMetrics.map(_.memoryBytesSpilled))}{formatBytes(task.taskMetrics.map(_.diskBytesSpilled))}
    $peakExecutionMemory.0 b
    spark.kubernetes.driver.container.imagespark.kubernetes.container.image (none) - Container image to use for the driver. - This is usually of the form example.com/repo/spark-driver:v1.0.0. - This configuration is required and must be provided by the user. + Container image to use for the Spark application. + This is usually of the form example.com/repo/spark:v1.0.0. + This configuration is required and must be provided by the user, unless explicit + images are provided for each different container type. +
    spark.kubernetes.driver.container.image(value of spark.kubernetes.container.image) + Custom container image to use for the driver.
    spark.kubernetes.executor.container.image(none)(value of spark.kubernetes.container.image) - Container image to use for the executors. - This is usually of the form example.com/repo/spark-executor:v1.0.0. - This configuration is required and must be provided by the user. + Custom container image to use for executors.
    spark.kubernetes.initContainer.image(none)(value of spark.kubernetes.container.image) - Container image for the init-container of the driver and executors for downloading dependencies. This is usually of the form example.com/repo/spark-init:v1.0.0. This configuration is optional and must be provided by the user if any non-container local dependency is used and must be downloaded remotely. + Custom container image for the init container of both driver and executors.
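As a sketch of how these image properties are typically supplied (the API server address, image names and application jar are placeholders), a single shared image can be set and then selectively overridden per role on the spark-submit command line:

{% highlight bash %}
bin/spark-submit \
  --master k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port> \
  --deploy-mode cluster \
  --class org.apache.spark.examples.SparkPi \
  --conf spark.kubernetes.container.image=example.com/repo/spark:v1.0.0 \
  --conf spark.kubernetes.driver.container.image=example.com/repo/spark-driver:v1.0.0 \
  local:///opt/spark/examples/jars/spark-examples.jar
{% endhighlight %}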
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Left InputRight InputJoin Type
    StaticStaticAll types + Supported, since it's not on streaming data even though it + can be present in a streaming query +
    StreamStaticInnerSupported, not stateful
    Left OuterSupported, not stateful
    Right OuterNot supported
    Full OuterNot supported
    StaticStreamInnerSupported, not stateful
    Left OuterNot supported
    Right OuterSupported, not stateful
    Full OuterNot supported
    StreamStreamInner + Supported, optionally specify watermark on both sides + + time constraints for state cleanup +
    Left Outer + Conditionally supported, must specify watermark on right + time constraints for correct + results, optionally specify watermark on left for all state cleanup +
    Right Outer + Conditionally supported, must specify watermark on left + time constraints for correct + results, optionally specify watermark on right for all state cleanup +
    Full OuterNot supported
    + +Additional details on supported joins: + +- Joins can be cascaded, that is, you can do `df1.join(df2, ...).join(df3, ...).join(df4, ....)`. + +- As of Spark 2.3, you can use joins only when the query is in Append output mode. Other output modes are not yet supported. + +- As of Spark 2.3, you cannot use other non-map-like operations before joins. Here are a few examples of + what cannot be used. + + - Cannot use streaming aggregations before joins. + + - Cannot use mapGroupsWithState and flatMapGroupsWithState in Update mode before joins. + + ### Streaming Deduplication You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. @@ -1160,15 +1466,9 @@ Some of them are as follows. - Sorting operations are supported on streaming Datasets only after an aggregation and in Complete Output Mode. -- Outer joins between a streaming and a static Datasets are conditionally supported. - - + Full outer join with a streaming Dataset is not supported - - + Left outer join with a streaming Dataset on the right is not supported - - + Right outer join with a streaming Dataset on the left is not supported - -- Any kind of joins between two streaming Datasets is not yet supported. +- Few types of outer joins on streaming Datasets are not supported. See the + support matrix in the Join Operations section + for more details. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -1276,6 +1576,15 @@ Here is the compatibility matrix. Aggregations not allowed after flatMapGroupsWithState. + + Queries with joins + Append + + Update and Complete mode not supported yet. See the + support matrix in the Join Operations section + for more details on what types of joins are supported. 
+ + Other queries Append, Update @@ -2142,6 +2451,11 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat **Talks** -- Spark Summit 2017 Talk - [Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark](https://spark-summit.org/2017/events/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark/) -- Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) +- Spark Summit Europe 2017 + - Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark - + [Part 1 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark), [Part 2 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark-continues) + - Deep Dive into Stateful Stream Processing in Structured Streaming - [slides/video](https://databricks.com/session/deep-dive-into-stateful-stream-processing-in-structured-streaming) +- Spark Summit 2016 + - A Deep Dive into Structured Streaming - [slides/video](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) + From 02194702068291b3af77486d01029fb848c36d7b Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Wed, 17 Jan 2018 16:42:38 -0800 Subject: [PATCH 680/712] [SPARK-21996][SQL] read files with space in name for streaming ## What changes were proposed in this pull request? Structured streaming is now able to read files with space in file name (previously it would skip the file and output a warning) ## How was this patch tested? Added new unit test. Author: Xiayun Sun Closes #19247 from xysun/SPARK-21996. 
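The underlying issue is that the paths recorded by the file stream source are in URI form, so a file name containing a space comes back percent-encoded; a small Scala sketch of the decoding round-trip the fix below relies on (the path itself is a placeholder):

{% highlight scala %}
import java.net.URI
import org.apache.hadoop.fs.Path

// A file named "text text" is recorded by the source in URI-encoded form.
val logged = "file:/tmp/stream-src/text%20text"     // placeholder path

// Converting URI -> Path recovers the real file name, which is what the patched
// FileStreamSource now hands to the batch DataSource instead of the raw string.
val decoded = new Path(new URI(logged)).toString    // "file:/tmp/stream-src/text text"
{% endhighlight %}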
--- .../streaming/FileStreamSource.scala | 2 +- .../sql/streaming/FileStreamSourceSuite.scala | 50 ++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0debd7db84757..8c016abc5b643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -166,7 +166,7 @@ class FileStreamSource( val newDataSource = DataSource( sparkSession, - paths = files.map(_.path), + paths = files.map(f => new Path(new URI(f.path)).toString), userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 39bb572740617..5bb0f4d643bbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -74,11 +74,11 @@ abstract class FileStreamSourceTest protected def addData(source: FileStreamSource): Unit } - case class AddTextFileData(content: String, src: File, tmp: File) + case class AddTextFileData(content: String, src: File, tmp: File, tmpFilePrefix: String = "text") extends AddFileData { override def addData(source: FileStreamSource): Unit = { - val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val tempFile = Utils.tempFileWith(new File(tmp, tmpFilePrefix)) val finalFile = new File(src, tempFile.getName) src.mkdirs() require(stringToFile(tempFile, content).renameTo(finalFile)) @@ -408,6 +408,52 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-21996 read from text files -- file name has space") { + withTempDirs { case (src, tmp) => + val textStream = createFileStream("text", src.getCanonicalPath) + val filtered = textStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp, "text text"), + CheckAnswer("keep2", "keep3") + ) + } + } + + test("SPARK-21996 read from text files generated by file sink -- file name has space") { + val testTableName = "FileStreamSourceTest" + withTable(testTableName) { + withTempDirs { case (src, checkpoint) => + val output = new File(src, "text text") + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val query = ds.writeStream + .option("checkpointLocation", checkpoint.getCanonicalPath) + .format("text") + .start(output.getCanonicalPath) + + try { + inputData.addData("foo") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + query.stop() + } + + val df2 = spark.readStream.format("text").load(output.getCanonicalPath) + val query2 = df2.writeStream.format("memory").queryName(testTableName).start() + try { + query2.processAllAvailable() + checkDatasetUnorderly(spark.table(testTableName).as[String], "foo") + } finally { + query2.stop() + } + } + } + } + test("read from textfile") { withTempDirs { case (src, tmp) => val textStream = spark.readStream.textFile(src.getCanonicalPath) From 39d244d921d8d2d3ed741e8e8f1175515a74bdbd Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 14:51:05 +0900 Subject: [PATCH 681/712] [SPARK-23122][PYTHON][SQL] 
Deprecate register* for UDFs in SQLContext and Catalog in PySpark ## What changes were proposed in this pull request? This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0. These are inconsistent with Scala / Java APIs and also these basically do the same things with `spark.udf.register*`. Also, this PR moves the logcis from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuse the docstring. This PR also handles minor doc corrections. It also includes https://github.com/apache/spark/pull/20158 ## How was this patch tested? Manually tested, manually checked the API documentation and tests added to check if deprecated APIs call the aliases correctly. Author: hyukjinkwon Closes #20288 from HyukjinKwon/deprecate-udf. --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/catalog.py | 91 ++-------------- python/pyspark/sql/context.py | 137 ++++-------------------- python/pyspark/sql/functions.py | 4 +- python/pyspark/sql/group.py | 3 +- python/pyspark/sql/session.py | 6 +- python/pyspark/sql/tests.py | 20 ++++ python/pyspark/sql/udf.py | 182 +++++++++++++++++++++++++++++++- 8 files changed, 234 insertions(+), 210 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7164180a6a7b0..b900f0bd913c3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -400,6 +400,7 @@ def __hash__(self): "pyspark.sql.functions", "pyspark.sql.readwriter", "pyspark.sql.streaming", + "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.tests", ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 35fbe9e669adb..6aef0f22340be 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -224,92 +224,17 @@ def dropGlobalTempView(self, viewName): """ self._jcatalog.dropGlobalTempView(viewName) - @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. 
- :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = spark.catalog.registerFunction("stringLengthString", len) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = spark.udf.register("slen", slen) - >>> spark.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) - >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP - >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. - # This is to check whether the input function is a wrapped/native UserDefinedFunction - if hasattr(f, 'asNondeterministic'): - if returnType is not None: - raise TypeError( - "Invalid returnType: None is expected when f is a UserDefinedFunction, " - "but got %s." % returnType) - if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: - raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") - register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, - evalType=f.evalType, - deterministic=f.deterministic) - return_udf = f - else: - if returnType is None: - returnType = StringType() - register_udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - return_udf = register_udf._wrapped() - self._jsparkSession.udf().registerPython(name, register_udf._judf) - return return_udf + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + """ + warnings.warn( + "Deprecated in 2.3.0. 
Use spark.udf.register instead.", + DeprecationWarning) + return self._sparkSession.udf.register(name, f, returnType) @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 85479095af594..cc1cd1a5842d9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -29,9 +29,10 @@ from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import IntegerType, Row, StringType +from pyspark.sql.udf import UDFRegistration from pyspark.sql.utils import install_exception_handler -__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] +__all__ = ["SQLContext", "HiveContext"] class SQLContext(object): @@ -147,7 +148,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + return self.sparkSession.udf @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -172,113 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None): """ return self.sparkSession.range(start, end, step, numPartitions) - @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. 
- :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = sqlContext.udf.register("slen", slen) - >>> sqlContext.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) - >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP - >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. """ - return self.sparkSession.catalog.registerFunction(name, f, returnType) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self.sparkSession.udf.register(name, f, returnType) - @ignore_unicode_prefix @since(2.1) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a java UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. - :param name: name of the UDF - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> sqlContext.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> sqlContext.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> sqlContext.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") - >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] + """An alias for :func:`spark.udf.registerJavaFunction`. + See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. 
""" - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - - @ignore_unicode_prefix - @since(2.3) - def registerJavaUDAF(self, name, javaClassName): - """Register a java UDAF so it can be used in SQL statements. - - :param name: name of the UDAF - :param javaClassName: fully qualified name of java class - - >>> sqlContext.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") - >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.", + DeprecationWarning) + return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -590,24 +507,6 @@ def refreshTable(self, tableName): self._ssql_ctx.refreshTable(tableName) -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sqlContext): - self.sqlContext = sqlContext - - def register(self, name, f, returnType=None): - return self.sqlContext.registerFunction(name, f, returnType) - - def registerJavaFunction(self, name, javaClassName, returnType=None): - self.sqlContext.registerJavaFunction(name, javaClassName, returnType) - - def registerJavaUDAF(self, name, javaClassName): - self.sqlContext.registerJavaUDAF(name, javaClassName) - - register.__doc__ = SQLContext.registerFunction.__doc__ - - def _test(): import os import doctest diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f7b3f29764040..988c1d25259bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()): >>> import random >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. @@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): ... return pd.Series(np.random.randn(len(v)) >>> random = random.asNondeterministic() # doctest: +SKIP - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. 
""" diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 09fae46adf014..22061b83eb78c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -212,7 +212,8 @@ def apply(self, udf): This function does not support partial aggregation, and requires shuffling all the data in the :class:`DataFrame`. - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: a group map user-defined function returned by + :meth:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 604021c1f45cc..6c84023c43fb6 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -29,7 +29,6 @@ from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.sql.catalog import Catalog from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -280,6 +279,7 @@ def catalog(self): :return: :class:`Catalog` """ + from pyspark.sql.catalog import Catalog if not hasattr(self, "_catalog"): self._catalog = Catalog(self) return self._catalog @@ -291,8 +291,8 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.context import UDFRegistration - return UDFRegistration(self._wrapped) + from pyspark.sql.udf import UDFRegistration + return UDFRegistration(self) @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8906618666b14..f84aa3d68b808 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -372,6 +372,12 @@ def test_udf(self): [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. + sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + def test_udf2(self): self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ @@ -577,11 +583,25 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. 
+ sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + def test_non_existed_udaf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e80ab9165867..1943bb73f9ac2 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -19,11 +19,13 @@ """ import functools -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark import SparkContext, since +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +__all__ = ["UDFRegistration"] + def _wrap_function(sc, func, returnType): command = (func, returnType) @@ -181,3 +183,179 @@ def asNondeterministic(self): """ self.deterministic = False return self + + +class UDFRegistration(object): + """ + Wrapper for user-defined function registration. This instance can be accessed by + :attr:`spark.udf` or :attr:`sqlContext.udf`. + + .. versionadded:: 1.3.1 + """ + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + @since("1.3.1") + def register(self, name, f, returnType=None): + """Registers a Python function (including lambda function) or a user-defined function + in SQL statements. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. + :return: a user-defined function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. Please see below. + + 1. When `f` is a Python function: + + `returnType` defaults to string type and can be optionally specified. The produced + object must match the specified type. In this case, this API works as if + `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + 2. When `f` is a user-defined function: + + Spark uses the return type of the given user-defined function as the return type of + the registered user-defined function. `returnType` should not be specified. + In this case, this API works as if `register(name, f)`. 
+ + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + + .. note:: Registration for a user-defined function (case 2.) was added from + Spark 2.3.0. + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + @since(2.3) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction( + ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( + ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + """ + + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function so it can be used in SQL statements. 
+ + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.udf + globs = pyspark.sql.udf.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.udf tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udf, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 1c76a91e5fae11dcb66c453889e587b48039fdc9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 22:36:29 -0800 Subject: [PATCH 682/712] [SPARK-23052][SS] Migrate ConsoleSink to data source V2 api. ## What changes were proposed in this pull request? Migrate ConsoleSink to data source V2 api. Note that this includes a missing piece in DataStreamWriter required to specify a data source V2 writer. Note also that I've removed the "Rerun batch" part of the sink, because as far as I can tell this would never have actually happened. A MicroBatchExecution object will only commit each batch once for its lifetime, and a new MicroBatchExecution object would have a new ConsoleSink object which doesn't know it's retrying a batch. So I think this represents an anti-feature rather than a weakness in the V2 API. ## How was this patch tested? new unit test Author: Jose Torres Closes #20243 from jose-torres/console-sink. 
--- .../streaming/MicroBatchExecution.scala | 7 +- .../sql/execution/streaming/console.scala | 62 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../streaming/sources/ConsoleWriter.scala | 64 +++++ .../sources/PackedRowWriterFactory.scala | 60 +++++ .../sql/streaming/DataStreamWriter.scala | 16 +- ...pache.spark.sql.sources.DataSourceRegister | 8 + .../sources/ConsoleWriterSuite.scala | 135 ++++++++++ .../sources/StreamingDataSourceV2Suite.scala | 249 ++++++++++++++++++ .../test/DataStreamReaderWriterSuite.scala | 25 -- 10 files changed, 551 insertions(+), 84 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 70407f0580f97..7c3804547b736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -91,11 +91,14 @@ class MicroBatchExecution( nextSourceId += 1 StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, _, _, output, v1Relation) => + case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + if (v1Relation.isEmpty) { + throw new UnsupportedOperationException( + s"Data source $sourceName does not support microbatch processing.") + } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 StreamingExecutionRelation(source, output)(sparkSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 71eaabe273fea..94820376ff7e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,58 +17,36 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType - -class ConsoleSink(options: Map[String, String]) extends Sink with Logging { - // Number of rows to display, by default 20 rows - private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20) - - // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true) +import java.util.Optional - // Track the batch id - private var 
lastBatchId = -1L - - override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - val batchIdStr = if (batchId <= lastBatchId) { - s"Rerun batch: $batchId" - } else { - lastBatchId = batchId - s"Batch: $batchId" - } - - // scalastyle:off println - println("-------------------------------------------") - println(batchIdStr) - println("-------------------------------------------") - // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } +import scala.collection.JavaConverters._ - override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]" -} +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) extends BaseRelation { override def schema: StructType = data.schema } -class ConsoleSinkProvider extends StreamSinkProvider +class ConsoleSinkProvider extends DataSourceV2 + with MicroBatchWriteSupport with DataSourceRegister with CreatableRelationProvider { - def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { - new ConsoleSink(parameters) + + override def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + Optional.of(new ConsoleWriter(epochId, schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c0507224f9be8..462e7d9721d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -54,16 +54,13 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. 
private[sql] var currentEpochCoordinatorId: String = _ - override lazy val logicalPlan: LogicalPlan = { - assert(queryExecutionThread eq Thread.currentThread, - "logicalPlan must be initialized in StreamExecutionThread " + - s"but the current thread was ${Thread.currentThread}") + override val logicalPlan: LogicalPlan = { val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( @@ -72,7 +69,7 @@ class ContinuousExecution( ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) case StreamingRelationV2(_, sourceName, _, _, _) => - throw new AnalysisException( + throw new UnsupportedOperationException( s"Data source $sourceName does not support continuous processing.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala new file mode 100644 index 0000000000000..361979984bbec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.types.StructType + +/** + * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. + * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. 
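+ * The number of rows printed per batch and whether long values are truncated are controlled by
+ * the "numRows" and "truncate" options (defaulting to 20 and true, respectively).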
+ */ +class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) + extends DataSourceV2Writer with Logging { + // Number of rows to display, by default 20 rows + private val numRowsToShow = options.getInt("numRows", 20) + + // Truncate the displayed data if it is too long, by default it is true + private val isTruncated = options.getBoolean("truncate", true) + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + + override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + + override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { + val batch = messages.collect { + case PackedRowCommitMessage(rows) => rows + }.flatten + + // scalastyle:off println + println("-------------------------------------------") + println(s"Batch: $batchId") + println("-------------------------------------------") + // scalastyle:off println + spark.createDataFrame( + spark.sparkContext.parallelize(batch), schema) + .show(numRowsToShow, isTruncated) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + + override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala new file mode 100644 index 0000000000000..9282ba05bdb7b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} + +/** + * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery + * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver. + * + * Note that, because it sends all rows to the driver, this factory will generally be unsuitable + * for production-quality sinks. It's intended for use in tests. + */ +case object PackedRowWriterFactory extends DataWriterFactory[Row] { + def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + new PackedRowDataWriter() + } +} + +/** + * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most + * recent interval. 
+ */ +case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage + +/** + * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. + */ +class PackedRowDataWriter() extends DataWriter[Row] with Logging { + private val data = mutable.Buffer[Row]() + + override def write(row: Row): Unit = data.append(row) + + override def commit(): PackedRowCommitMessage = { + val msg = PackedRowCommitMessage(data.toArray) + data.clear() + msg + } + + override def abort(): Unit = data.clear() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..d24f0ddeab4de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -280,14 +280,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val sink = (ds.newInstance(), trigger) match { + case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w + case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException( + s"Data source $source does not support continuous writing") + case (w: MicroBatchWriteSupport, _) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c6973bf41d34b..a0b25b4e82364 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,11 @@ org.apache.spark.sql.sources.FakeSourceFour org.apache.fakesource.FakeExternalSourceOne org.apache.fakesource.FakeExternalSourceTwo org.apache.fakesource.FakeExternalSourceThree +org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly +org.apache.spark.sql.streaming.sources.FakeReadBothModes +org.apache.spark.sql.streaming.sources.FakeReadNeitherMode +org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly +org.apache.spark.sql.streaming.sources.FakeWriteBothModes +org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala new file mode 100644 index 0000000000000..60ffee9b9b42c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.StreamTest + +class ConsoleWriterSuite extends StreamTest { + import testImplicits._ + + test("console") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + input.addData(4, 5, 6) + query.processAllAvailable() + input.addData() + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + || 3| + |+-----+ + | + |------------------------------------------- + |Batch: 1 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 4| + || 5| + || 6| + |+-----+ + | + |------------------------------------------- + |Batch: 2 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + |+-----+ + | + |""".stripMargin) + } + + test("console with numRows") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + |+-----+ + |only showing top 2 rows + | + |""".stripMargin) + } + + test("console with truncation") { + val input = MemoryStream[String] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start() + try { + input.addData("123456789012345678901234567890") + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- 
+ |+--------------------+ + || value| + |+--------------------+ + ||12345678901234567...| + |+--------------------+ + | + |""".stripMargin) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala new file mode 100644 index 0000000000000..f152174b0a7f0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.sources + +import java.util.Optional + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +case class FakeReader() extends MicroBatchReader with ContinuousReader { + def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} + def getStartOffset: Offset = RateStreamOffset(Map()) + def getEndOffset: Offset = RateStreamOffset(Map()) + def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + def commit(end: Offset): Unit = {} + def readSchema(): StructType = StructType(Seq()) + def stop(): Unit = {} + def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + def setOffset(start: Optional[Offset]): Unit = {} + + def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = { + throw new IllegalStateException("fake source - cannot actually read") + } +} + +trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = FakeReader() +} + +trait FakeContinuousReadSupport extends ContinuousReadSupport { + override def 
createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): ContinuousReader = FakeReader() +} + +trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport { + def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +trait FakeContinuousWriteSupport extends ContinuousWriteSupport { + def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { + override def shortName(): String = "fake-read-microbatch-only" +} + +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-continuous-only" +} + +class FakeReadBothModes extends DataSourceRegister + with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-microbatch-continuous" +} + +class FakeReadNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-read-neither-mode" +} + +class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport { + override def shortName(): String = "fake-write-microbatch-only" +} + +class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-continuous-only" +} + +class FakeWriteBothModes extends DataSourceRegister + with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" +} + +class FakeWriteNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" +} + +class StreamingDataSourceV2Suite extends StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + val fakeCheckpoint = Utils.createTempDir() + spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath) + } + + val readFormats = Seq( + "fake-read-microbatch-only", + "fake-read-continuous-only", + "fake-read-microbatch-continuous", + "fake-read-neither-mode") + val writeFormats = Seq( + "fake-write-microbatch-only", + "fake-write-continuous-only", + "fake-write-microbatch-continuous", + "fake-write-neither-mode") + val triggers = Seq( + Trigger.Once(), + Trigger.ProcessingTime(1000), + Trigger.Continuous(1000)) + + private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + query.stop() + } + + private def testNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val ex = intercept[UnsupportedOperationException] { + testPositiveCase(readFormat, writeFormat, trigger) + } + assert(ex.getMessage.contains(errorMsg)) + } + + private def testPostCreationNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + + 
eventually(timeout(streamingTimeout)) { + assert(query.exception.isDefined) + assert(query.exception.get.cause != null) + assert(query.exception.get.cause.getMessage.contains(errorMsg)) + } + } + + // Get a list of (read, write, trigger) tuples for test cases. + val cases = readFormats.flatMap { read => + writeFormats.flatMap { write => + triggers.map(t => (write, t)) + }.map { + case (write, t) => (read, write, t) + } + } + + for ((read, write, trigger) <- cases) { + testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + (readSource, writeSource, trigger) match { + // Valid microbatch queries. + case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t) + if !t.isInstanceOf[ContinuousTrigger] => + testPositiveCase(read, write, trigger) + + // Valid continuous queries. + case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) => + testPositiveCase(read, write, trigger) + + // Invalid - can't read at all + case (r, _, _) + if !r.isInstanceOf[MicroBatchReadSupport] + && !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support streamed reading") + + // Invalid - trigger is continuous but writer is not + case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support continuous writing") + + // Invalid - can't write at all + case (_, w, _) + if !w.isInstanceOf[MicroBatchWriteSupport] + && !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are continuous but reader is not + case (r, _: ContinuousWriteSupport, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support continuous processing") + + // Invalid - trigger is microbatch but writer is not + case (_, w, t) + if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are microbatch but reader is not + case (r, _, t) + if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + testPostCreationNegativeCase(read, write, trigger, + s"Data source $read does not support microbatch processing") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index aa163d2211c38..8212fb912ec57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -422,21 +422,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink can be correctly loaded") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream - .format("console") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime(2.seconds)) - .start() - - 
sq.awaitTermination(2000L) - } - test("prevent all column partitioning") { withTempDir { dir => val path = dir.getCanonicalPath @@ -450,16 +435,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink should not require checkpointLocation") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream.format("console").start() - sq.stop() - } - private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = { import testImplicits._ val ms = new MemoryStream[Int](0, sqlContext) From 7a2248341396840628eef398aa512cac3e3bd55f Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 19:18:55 +0800 Subject: [PATCH 683/712] [SPARK-23140][SQL] Add DataSourceV2Strategy to Hive Session state's planner ## What changes were proposed in this pull request? `DataSourceV2Strategy` is missing in `HiveSessionStateBuilder`'s planner, which will throw exception as described in [SPARK-23140](https://issues.apache.org/jira/browse/SPARK-23140). ## How was this patch tested? Manual test. Author: jerryshao Closes #20305 from jerryshao/SPARK-23140. --- .../sql/hive/HiveSessionStateBuilder.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index dc92ad3b0c1ac..12c74368dd184 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -96,22 +96,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val sparkSession: SparkSession = session override def extraPlanningStrategies: Seq[Strategy] = - super.extraPlanningStrategies ++ customPlanningStrategies - - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ - extraPlanningStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy(conf), - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } + super.extraPlanningStrategies ++ customPlanningStrategies ++ Seq(HiveTableScans, Scripts) } } From e28eb431146bcdcaf02a6f6c406ca30920592a6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 18 Jan 2018 21:24:39 +0800 Subject: [PATCH 684/712] [SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL ## What changes were proposed in this pull request? When there is an operation between Decimals and the result is a number which is not representable exactly with the result's precision and scale, Spark is returning `NULL`. This was done to reflect Hive's behavior, but it is against SQL ANSI 2011, which states that "If the result cannot be represented exactly in the result type, then whether it is rounded or truncated is implementation-defined". Moreover, Hive now changed its behavior in order to respect the standard, thanks to HIVE-15331. Therefore, the PR propose to: - update the rules to determine the result precision and scale according to the new Hive's ones introduces in HIVE-15331; - round the result of the operations, when it is not representable exactly with the result's precision and scale, instead of returning `NULL` - introduce a new config `spark.sql.decimalOperations.allowPrecisionLoss` which default to `true` (ie. 
the new behavior) in order to allow users to switch back to the previous one. Hive behavior reflects SQLServer's one. The only difference is that the precision and scale are adjusted for all the arithmetic operations in Hive, while SQL Server is said to do so only for multiplications and divisions in the documentation. This PR follows Hive's behavior. A more detailed explanation is available here: https://mail-archives.apache.org/mod_mbox/spark-dev/201712.mbox/%3CCAEorWNAJ4TxJR9NBcgSFMD_VxTg8qVxusjP%2BAJP-x%2BJV9zH-yA%40mail.gmail.com%3E. ## How was this patch tested? modified and added UTs. Comparisons with results of Hive and SQLServer. Author: Marco Gaido Closes #20023 from mgaido91/SPARK-22036. --- docs/sql-programming-guide.md | 5 + .../catalyst/analysis/DecimalPrecision.scala | 114 +++++--- .../sql/catalyst/expressions/literals.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 12 + .../apache/spark/sql/types/DecimalType.scala | 45 +++- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../analysis/DecimalPrecisionSuite.scala | 20 +- .../native/decimalArithmeticOperations.sql | 47 ++++ .../decimalArithmeticOperations.sql.out | 245 ++++++++++++++++-- .../native/decimalPrecision.sql.out | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 18 -- 11 files changed, 434 insertions(+), 82 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 258c769ff593b..3e2e48a0ef249 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1793,6 +1793,11 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). + - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. + - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. 
It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index a8100b9b24aac..ab63131b07573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -42,8 +43,10 @@ import org.apache.spark.sql.types._ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. 
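+ * For example, multiplying two DECIMAL(38, 18) operands now yields DECIMAL(38, 6) instead of
+ * DECIMAL(38, 36): the fractional part is rounded to 6 digits rather than letting any product
+ * with more than two integer digits overflow to null.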
* * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited * precision, do the math on unlimited-precision numbers, then introduce casts back to the @@ -56,6 +59,7 @@ import org.apache.spark.sql.types._ * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value */ // scalastyle:on object DecimalPrecision extends TypeCoercionRule { @@ -93,41 +97,76 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) } - val resultType = DecimalType.bounded(intDig + decDig, decDig) 
val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), @@ -137,9 +176,6 @@ object DecimalPrecision extends TypeCoercionRule { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // TODO: MaxOf, MinOf, etc might want other rules - // SUM and AVERAGE are handled by the implementations of those expressions } /** @@ -243,17 +279,35 @@ object DecimalPrecision extends TypeCoercionRule { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. 
+ case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] + && l.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] + && r.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b } } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 383203a209833..cd176d941819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -58,7 +58,7 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 16fbb0c3e9e21..cc4f4bf332459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,6 +1064,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = + buildConf("spark.sql.decimalOperations.allowPrecisionLoss") + .internal() + .doc("When true (default), establishing the result type of an arithmetic operation " + + "happens according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the " + + "decimal part of the result if an exact representation is not possible. 
Otherwise, NULL " + + "is returned in those cases, as previously.") + .booleanConf + .createWithDefault(true) + val SQL_STRING_REDACTION_PATTERN = ConfigBuilder("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + @@ -1441,6 +1451,8 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6e050c18b8acb..ef3b67c0d48d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} /** @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) @@ -136,10 +137,52 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } + private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match { + case v: Short => fromBigDecimal(BigDecimal(v)) + case v: Int => fromBigDecimal(BigDecimal(v)) + case v: Long => fromBigDecimal(BigDecimal(v)) + case _ => forType(literal.dataType) + } + + private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = { + DecimalType(Math.max(d.precision, d.scale), d.scale) + } + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } + /** + * Scale adjustment implementation is based on Hive's one, which is itself inspired to + * SQLServer's one. In particular, when a result precision is greater than + * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a + * result from being truncated. + * + * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. + */ + private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { + // Assumptions: + assert(precision >= scale) + assert(scale >= 0) + + if (precision <= MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + DecimalType(precision, scale) + } else { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. 
+ val intDigits = precision - scale + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) + + DecimalType(MAX_PRECISION, adjustedScale) + } + } + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f4514205d3ae0..cd8579584eada 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) - assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) - assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6)) assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 60e46a9910a8b..c86dc18dfa680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) - checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) } - checkType(Multiply(d1, u), DecimalType(38, 19)) - checkType(Multiply(d2, u), DecimalType(38, 20)) - checkType(Multiply(i, u), DecimalType(38, 18)) - checkType(Multiply(u, u), DecimalType(38, 36)) + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) - checkType(Divide(u, d1), DecimalType(38, 18)) - checkType(Divide(u, d2), DecimalType(38, 19)) - checkType(Divide(u, i), DecimalType(38, 23)) - checkType(Divide(u, u), DecimalType(38, 18)) + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c8e108ac2c45e..c6d8a49d4b93a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -22,6 +22,51 @@ select a / b from t; select a % b from t; select pmod(a, b) from t; +-- tests for decimals handling in operations +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; + +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss are truncated +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 + +-- return NULL instead of rounding, according to old Spark versions' behavior +set spark.sql.decimalOperations.allowPrecisionLoss=false; + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; @@ -31,3 +76,5 @@ select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL select 123456789123456789.1234567890 * 1.123456789123456789; select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index ce02f6adc456c..4d70fe19d539f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 32 -- !query 0 @@ -35,48 +35,257 @@ NULL -- !query 4 -select (5e36 + 0.1) + 5e36 +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet -- !query 4 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 4 output -NULL + -- !query 5 -select (-4e36 - 0.1) - 7e36 +insert into decimals_test 
values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) -- !query 5 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 5 output -NULL + -- !query 6 -select 12345678901234567890.0 * 12345678901234567890.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 6 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct -- !query 6 output -NULL +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 -- !query 7 -select 1e35 / 0.1 +select id, a*10, b/10 from decimals_test order by id -- !query 7 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct -- !query 7 output -NULL +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 -- !query 8 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 10.3 * 3.0 -- !query 8 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 8 output -NULL +30.9 -- !query 9 -select 0.001 / 9876543210987654321098765432109876543.2 +select 10.3000 * 3.0 -- !query 9 schema -struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 9 output +30.9 + + +-- !query 10 +select 10.30000 * 30.0 +-- !query 10 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 10 output +309 + + +-- !query 11 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 11 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 11 output +30.9 + + +-- !query 12 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 12 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 12 output +30.9 + + +-- !query 13 +select (5e36 + 0.1) + 5e36 +-- !query 13 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 13 output +NULL + + +-- !query 14 +select (-4e36 - 0.1) - 7e36 +-- !query 14 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 14 output +NULL + + +-- !query 15 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 15 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 15 output NULL + + +-- !query 16 +select 1e35 / 0.1 +-- !query 16 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 16 output +NULL + + +-- !query 17 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 17 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- 
!query 17 output +138698367904130467.654320988515622621 + + +-- !query 18 +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'spark' expecting (line 3, pos 4) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +----^^^ + + +-- !query 19 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 19 schema +struct +-- !query 19 output +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 + + +-- !query 20 +select id, a*10, b/10 from decimals_test order by id +-- !query 20 schema +struct +-- !query 20 output +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 + + +-- !query 21 +select 10.3 * 3.0 +-- !query 21 schema +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +-- !query 21 output +30.9 + + +-- !query 22 +select 10.3000 * 3.0 +-- !query 22 schema +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +-- !query 22 output +30.9 + + +-- !query 23 +select 10.30000 * 30.0 +-- !query 23 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 23 output +309 + + +-- !query 24 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 24 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 24 output +30.9 + + +-- !query 25 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 25 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 25 output +30.9 + + +-- !query 26 +select (5e36 + 0.1) + 5e36 +-- !query 26 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 26 output +NULL + + +-- !query 27 +select (-4e36 - 0.1) - 7e36 +-- !query 27 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 27 output +NULL + + +-- !query 28 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 28 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 28 output +NULL + + +-- !query 29 +select 1e35 / 0.1 +-- !query 29 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 29 output +NULL + + +-- !query 30 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 30 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 30 output +138698367904130467.654320988515622621 + + +-- !query 31 +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'table' expecting (line 3, pos 5) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test 
+-----^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out index ebc8201ed5a1d..6ee7f59d69877 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out @@ -2329,7 +2329,7 @@ struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(C -- !query 280 SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t -- !query 280 schema -struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,18)> -- !query 280 output 1 @@ -2661,7 +2661,7 @@ struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BI -- !query 320 SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t -- !query 320 schema -struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,18)> -- !query 320 output 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d4d0aa4f5f5eb..083a0c0b1b9a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1517,24 +1517,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(null)) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) - } - test("SPARK-10215 Div of Decimal returns null") { val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") From 5063b7481173ad72bd0dc941b5cf3c9b26a591e4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 18 Jan 2018 22:33:04 +0900 Subject: [PATCH 685/712] [SPARK-23141][SQL][PYSPARK] Support data type string as a returnType for registerJavaFunction. ## What changes were proposed in this pull request? Currently `UDFRegistration.registerJavaFunction` doesn't support data type string as a `returnType` whereas `UDFRegistration.register`, `udf`, or `pandas_udf` does. We can support it for `UDFRegistration.registerJavaFunction` as well. ## How was this patch tested? Added a doctest and existing tests. Author: Takuya UESHIN Closes #20307 from ueshin/issues/SPARK-23141. 
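For reference, a minimal usage sketch of the new behavior (assuming an active `spark` session and the Java UDF class used in the test suite, `test.org.apache.spark.sql.JavaStringLength`; in practice you would substitute your own class name):

```python
# Sketch based on the doctest added in this patch; "integer" is a
# DDL-formatted type string and can now be passed instead of a
# pyspark.sql.types.DataType object.
spark.udf.registerJavaFunction(
    "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
spark.sql("SELECT javaStringLength3('test')").collect()
# [Row(UDF:javaStringLength3(test)=4)]
```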
--- python/pyspark/sql/functions.py | 6 ++++-- python/pyspark/sql/udf.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 988c1d25259bc..961b3267b44cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2108,7 +2108,8 @@ def udf(f=None, returnType=StringType()): can fail on special rows, the workaround is to incorporate the condition into the functions. :param f: python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) @@ -2148,7 +2149,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 1943bb73f9ac2..c77f19f89a442 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -206,7 +206,8 @@ def register(self, name, f, returnType=None): :param f: a Python function, or a user-defined function. The user-defined function can be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and :meth:`pyspark.sql.functions.pandas_udf`. - :param returnType: the return type of the registered user-defined function. + :param returnType: the return type of the registered user-defined function. The value can + be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :return: a user-defined function. `returnType` can be optionally specified when `f` is a Python function but not @@ -303,21 +304,30 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): :param name: name of the user-defined function :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the registered Java function. The value can be either + a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> spark.udf.registerJavaFunction( ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) >>> spark.sql("SELECT javaStringLength('test')").collect() [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") >>> spark.sql("SELECT javaStringLength2('test')").collect() [Row(UDF:javaStringLength2(test)=4)] + + >>> spark.udf.registerJavaFunction( + ... 
"javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer") + >>> spark.sql("SELECT javaStringLength3('test')").collect() + [Row(UDF:javaStringLength3(test)=4)] """ jdt = None if returnType is not None: + if not isinstance(returnType, DataType): + returnType = _parse_datatype_string(returnType) jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) From cf7ee1767ddadce08dce050fc3b40c77cdd187da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 10:19:36 -0800 Subject: [PATCH 686/712] [SPARK-23147][UI] Fix task page table IndexOutOfBound Exception ## What changes were proposed in this pull request? Stage's task page table will throw an exception when there's no complete tasks. Furthermore, because the `dataSize` doesn't take running tasks into account, so sometimes UI cannot show the running tasks. Besides table will only be displayed when first task is finished according to the default sortColumn("index"). ![screen shot 2018-01-18 at 8 50 08 pm](https://user-images.githubusercontent.com/850797/35100052-470b4cae-fc95-11e7-96a2-ad9636e732b3.png) To reproduce this issue, user could try `sc.parallelize(1 to 20, 20).map { i => Thread.sleep(10000); i }.collect()` or `sc.parallelize(1 to 20, 20).map { i => Thread.sleep((20 - i) * 1000); i }.collect` to reproduce the above issue. Here propose a solution to fix it. Not sure if it is a right fix, please help to review. ## How was this patch tested? Manual test. Author: jerryshao Closes #20315 from jerryshao/SPARK-23147. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7c6e06cf183ba..af78373ddb4b2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -676,7 +676,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks + override def dataSize: Int = stage.numTasks override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { From 9678941f54ebc5db935ed8d694e502086e2a31c0 Mon Sep 17 00:00:00 2001 From: Fernando Pereira Date: Thu, 18 Jan 2018 13:02:03 -0600 Subject: [PATCH 687/712] [SPARK-23029][DOCS] Specifying default units of configuration entries ## What changes were proposed in this pull request? This PR completes the docs, specifying the default units assumed in configuration entries of type size. This is crucial since unit-less values are accepted and the user might assume the base unit is bytes, which in most cases it is not, leading to hard-to-debug problems. ## How was this patch tested? This patch updates only documentation only. Author: Fernando Pereira Closes #20269 from ferdonline/docs_units. 
--- .../scala/org/apache/spark/SparkConf.scala | 6 +- .../spark/internal/config/package.scala | 47 ++++---- docs/configuration.md | 100 ++++++++++-------- 3 files changed, 85 insertions(+), 68 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d77303e6fdf8b..f53b2bed74c6e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -640,9 +640,9 @@ private[spark] object SparkConf extends Logging { translation = s => s"${s.toLong * 10}s")), "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), - "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${(s.toDouble * 1000).toInt}k")), + "spark.kryoserializer.buffer" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index eb12ddf961314..bbfcfbaa7363c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -38,10 +38,13 @@ package object config { ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + .doc("Amount of memory to use for the driver process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.driver.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per driver in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -62,6 +65,7 @@ package object config { .createWithDefault(false) private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .doc("Buffer size to use when writing to output streams, in KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .createWithDefaultString("100k") @@ -81,10 +85,13 @@ package object config { ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + .doc("Amount of memory to use per executor process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.executor.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per executor in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -353,7 +360,7 @@ package object config { private[spark] val BUFFER_WRITE_CHUNK_SIZE = ConfigBuilder("spark.buffer.write.chunkSize") .internal() - .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.") + .doc("The chunk size in bytes during writing out the bytes of ChunkedByteBuffer.") .bytesConf(ByteUnit.BYTE) .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + " ChunkedByteBuffer should not larger than Int.MaxValue.") @@ -368,9 +375,9 @@ package object config { 
private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = ConfigBuilder("spark.shuffle.accurateBlockThreshold") - .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + - "record the size accurately if it's above this config. This helps to prevent OOM by " + - "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .doc("Threshold in bytes above which the size of shuffle blocks in " + + "HighlyCompressedMapStatus is accurately recorded. This helps to prevent OOM " + + "by avoiding underestimating shuffle block size when fetch shuffle blocks.") .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) @@ -389,23 +396,23 @@ package object config { private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") - .doc("This configuration limits the number of remote blocks being fetched per reduce task" + - " from a given host port. When a large number of blocks are being requested from a given" + - " address in a single fetch or simultaneously, this could crash the serving executor or" + - " Node Manager. This is especially useful to reduce the load on the Node Manager when" + - " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") + .doc("This configuration limits the number of remote blocks being fetched per reduce task " + + "from a given host port. When a large number of blocks are being requested from a given " + + "address in a single fetch or simultaneously, this could crash the serving executor or " + + "Node Manager. This is especially useful to reduce the load on the Node Manager when " + + "external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") .intConf .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") - .doc("Remote block will be fetched to disk when size of the block is " + - "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + - "affect both shuffle fetch and block manager remote block fetch. For users who " + - "enabled external shuffle service, this feature can only be worked when external shuffle" + - " service is newer than Spark 2.2.") + .doc("Remote block will be fetched to disk when size of the block is above this threshold " + + "in bytes. This is to avoid a giant request takes too much memory. We can enable this " + + "config by setting a specific value(e.g. 200m). Note this configuration will affect " + + "both shuffle fetch and block manager remote block fetch. For users who enabled " + + "external shuffle service, this feature can only be worked when external shuffle" + + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -419,9 +426,9 @@ package object config { private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") - .doc("Size of the in-memory buffer for each shuffle file output stream. " + - "These buffers reduce the number of disk seeks and system calls made " + - "in creating intermediate shuffle files.") + .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + + "otherwise specified. 
These buffers reduce the number of disk seeks and system calls " + + "made in creating intermediate shuffle files.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -430,7 +437,7 @@ package object config { private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") .doc("The file system for this buffer size after each partition " + - "is written in unsafe shuffle writer.") + "is written in unsafe shuffle writer. In KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -438,7 +445,7 @@ package object config { private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") - .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .doc("The buffer size, in bytes, to use when writing the sorted records to an on-disk file.") .bytesConf(ByteUnit.BYTE) .checkValue(v => v > 0 && v <= Int.MaxValue, s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") diff --git a/docs/configuration.md b/docs/configuration.md index 1189aea2aa71f..eecb39dcafc9e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -58,6 +58,10 @@ The following format is accepted: 1t or 1tb (tebibytes = 1024 gibibytes) 1p or 1pb (pebibytes = 1024 tebibytes) +While numbers without units are generally interpreted as bytes, a few are interpreted as KiB or MiB. +See documentation of individual configuration properties. Specifying units is desirable where +possible. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For @@ -136,9 +140,9 @@ of the most common options to set are: spark.driver.maxResultSize 1g - Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). - Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size - is above this limit. + Limit of total size of serialized results of all partitions for each Spark action (e.g. + collect) in bytes. Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total + size is above this limit. Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory and memory overhead of objects in JVM). Setting a proper limit can protect the driver from out-of-memory errors. @@ -148,10 +152,10 @@ of the most common options to set are: spark.driver.memory 1g - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
    Note: In client mode, this config must not be set through the SparkConf + Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in MiB + unless otherwise specified (e.g. 1g, 2g). +
    + Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option or in your default properties file. @@ -161,27 +165,28 @@ of the most common options to set are: spark.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is - memory that accounts for things like VM overheads, interned strings, other native overheads, etc. - This tends to grow with the container size (typically 6-10%). This option is currently supported - on YARN and Kubernetes. + The amount of off-heap memory to be allocated per driver in cluster mode, in MiB unless + otherwise specified. This is memory that accounts for things like VM overheads, interned strings, + other native overheads, etc. This tends to grow with the container size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. spark.executor.memory 1g - Amount of memory to use per executor process (e.g. 2g, 8g). + Amount of memory to use per executor process, in MiB unless otherwise specified. + (e.g. 2g, 8g). spark.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that - accounts for things like VM overheads, interned strings, other native overheads, etc. This tends - to grow with the executor size (typically 6-10%). This option is currently supported on YARN and - Kubernetes. + The amount of off-heap memory to be allocated per executor, in MiB unless otherwise specified. + This is memory that accounts for things like VM overheads, interned strings, other native + overheads, etc. This tends to grow with the executor size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. @@ -431,8 +436,9 @@ Apart from these, the following properties are also available, and may be useful 512m Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). If the memory - used during aggregation goes above this amount, it will spill the data into disks. + format as JVM memory strings with a size unit suffix ("k", "m", "g" or "t") + (e.g. 512m, 2g). + If the memory used during aggregation goes above this amount, it will spill the data into disks. @@ -540,9 +546,10 @@ Apart from these, the following properties are also available, and may be useful spark.reducer.maxSizeInFlight 48m - Maximum size of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + Maximum size of map outputs to fetch simultaneously from each reduce task, in MiB unless + otherwise specified. Since each output requires us to create a buffer to receive it, this + represents a fixed memory overhead per reduce task, so keep it small unless you have a + large amount of memory. @@ -570,9 +577,9 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The remote block will be fetched to disk when size of the block is above this threshold. + The remote block will be fetched to disk when size of the block is above this threshold in bytes. 
This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, this feature can only be worked when external shuffle service is newer than Spark 2.2. @@ -589,8 +596,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.file.buffer 32k - Size of the in-memory buffer for each shuffle file output stream. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + Size of the in-memory buffer for each shuffle file output stream, in KiB unless otherwise + specified. These buffers reduce the number of disk seeks and system calls made in creating + intermediate shuffle files. @@ -651,7 +659,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.index.cache.size 100m - Cache entries limited to the specified memory footprint. + Cache entries limited to the specified memory footprint in bytes. @@ -685,9 +693,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.accurateBlockThreshold 100 * 1024 * 1024 - When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the - size accurately if it's above this config. This helps to prevent OOM by avoiding - underestimating shuffle block size when fetch shuffle blocks. + Threshold in bytes above which the size of shuffle blocks in HighlyCompressedMapStatus is + accurately recorded. This helps to prevent OOM by avoiding underestimating shuffle + block size when fetch shuffle blocks. @@ -779,7 +787,7 @@ Apart from these, the following properties are also available, and may be useful spark.eventLog.buffer.kb 100k - Buffer size in KB to use when writing to output streams. + Buffer size to use when writing to output streams, in KiB unless otherwise specified. @@ -917,7 +925,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.lz4.blockSize 32k - Block size used in LZ4 compression, in the case when LZ4 compression codec + Block size in bytes used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used. @@ -925,7 +933,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.snappy.blockSize 32k - Block size used in Snappy compression, in the case when Snappy compression codec + Block size in bytes used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. @@ -941,7 +949,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.zstd.bufferSize 32k - Buffer size used in Zstd compression, in the case when Zstd compression codec + Buffer size in bytes used in Zstd compression, in the case when Zstd compression codec is used. Lowering this size will lower the shuffle memory usage when Zstd is used, but it might increase the compression cost because of excessive JNI call overhead. 
@@ -1001,8 +1009,8 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.max 64m - Maximum allowable size of Kryo serialization buffer. This must be larger than any - object you attempt to serialize and must be less than 2048m. + Maximum allowable size of Kryo serialization buffer, in MiB unless otherwise specified. + This must be larger than any object you attempt to serialize and must be less than 2048m. Increase this if you get a "buffer limit exceeded" exception inside Kryo. @@ -1010,9 +1018,9 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer 64k - Initial size of Kryo's serialization buffer. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max if needed. + Initial size of Kryo's serialization buffer, in KiB unless otherwise specified. + Note that there will be one buffer per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max if needed. @@ -1086,7 +1094,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled false - If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory + use is enabled, then spark.memory.offHeap.size must be positive. @@ -1094,7 +1103,8 @@ Apart from these, the following properties are also available, and may be useful 0 The absolute amount of memory in bytes which can be used for off-heap allocation. - This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This setting has no impact on heap memory usage, so if your executors' total memory consumption + must fit within some hard limit then be sure to shrink your JVM heap size accordingly. This must be set to a positive value when spark.memory.offHeap.enabled=true. @@ -1202,9 +1212,9 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.blockSize 4m - Size of each piece of a block for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + Size of each piece of a block for TorrentBroadcastFactory, in KiB unless otherwise + specified. Too large a value decreases parallelism during broadcast (makes it slower); however, + if it is too small, BlockManager might take a performance hit. @@ -1312,7 +1322,7 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryMapThreshold 2m - Size of a block above which Spark memory maps when reading a block from disk. + Size in bytes of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system. 
@@ -2490,4 +2500,4 @@ Also, you can modify or add configurations at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \ --conf spark.hadoop.abc.def=xyz \ myApp.jar -{% endhighlight %} \ No newline at end of file +{% endhighlight %} From 2d41f040a34d6483919fd5d491cf90eee5429290 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:25:52 -0800 Subject: [PATCH 688/712] [SPARK-23143][SS][PYTHON] Added python API for setting continuous trigger ## What changes were proposed in this pull request? Self-explanatory. ## How was this patch tested? New python tests. Author: Tathagata Das Closes #20309 from tdas/SPARK-23143. --- python/pyspark/sql/streaming.py | 23 +++++++++++++++++++---- python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 24ae3776a217b..e2a97acb5e2a7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -786,7 +786,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None, once=None): + def trigger(self, processingTime=None, once=None, continuous=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -802,23 +802,38 @@ def trigger(self, processingTime=None, once=None): >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') >>> # trigger the query for just once batch of data >>> writer = sdf.writeStream.trigger(once=True) + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(continuous='5 seconds') """ + params = [processingTime, once, continuous] + + if params.count(None) == 3: + raise ValueError('No trigger provided') + elif params.count(None) < 2: + raise ValueError('Multiple triggers not allowed.') + jTrigger = None if processingTime is not None: - if once is not None: - raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) interval = processingTime.strip() jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( interval) + elif once is not None: if once is not True: raise ValueError('Value for once must be True. Got: %s' % once) jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: - raise ValueError('No trigger provided') + if type(continuous) != str or len(continuous.strip()) == 0: + raise ValueError('Value for continuous must be a non empty string. 
Got: %s' % + continuous) + interval = continuous.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous( + interval) + self._jwrite = self._jwrite.trigger(jTrigger) return self diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f84aa3d68b808..25483594f2725 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1538,6 +1538,12 @@ def test_stream_trigger(self): except ValueError: pass + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + # Should take only keyword args try: df.writeStream.trigger('5 seconds') From bf34d665b9c865e00fac7001500bf6d521c2dff9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:33:39 -0800 Subject: [PATCH 689/712] [SPARK-23144][SS] Added console sink for continuous processing ## What changes were proposed in this pull request? Refactored ConsoleWriter into ConsoleMicrobatchWriter and ConsoleContinuousWriter. ## How was this patch tested? new unit test Author: Tathagata Das Closes #20311 from tdas/SPARK-23144. --- .../sql/execution/streaming/console.scala | 20 +++-- .../streaming/sources/ConsoleWriter.scala | 80 ++++++++++++++----- .../sources/ConsoleWriterSuite.scala | 26 +++++- 3 files changed, 96 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 94820376ff7e7..f2aa3259731d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional -import scala.collection.JavaConverters._ - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,16 +36,25 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) class ConsoleSinkProvider extends DataSourceV2 with MicroBatchWriteSupport + with ContinuousWriteSupport with DataSourceRegister with CreatableRelationProvider { override def createMicroBatchWriter( queryId: String, - epochId: Long, + batchId: Long, schema: StructType, mode: OutputMode, options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleWriter(epochId, schema, options)) + Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) + } + + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + Optional.of(new ConsoleContinuousWriter(schema, options)) } def createRelation( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 361979984bbec..6fb61dff60045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,45 +20,85 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType -/** - * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. - * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) - extends DataSourceV2Writer with Logging { +/** Common methods used to create writes for the the console sink */ +trait ConsoleWriter extends Logging { + + def options: DataSourceV2Options + // Number of rows to display, by default 20 rows - private val numRowsToShow = options.getInt("numRows", 20) + protected val numRowsToShow = options.getInt("numRows", 20) // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.getBoolean("truncate", true) + protected val isTruncated = options.getBoolean("truncate", true) assert(SparkSession.getActiveSession.isDefined) - private val spark = SparkSession.getActiveSession.get + protected val spark = SparkSession.getActiveSession.get + + def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { - val batch = messages.collect { + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { case PackedRowCommitMessage(rows) => rows }.flatten // scalastyle:off println println("-------------------------------------------") - println(s"Batch: $batchId") + println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark.createDataFrame( - spark.sparkContext.parallelize(batch), schema) + spark + .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } +} + + +/** + * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. 
+ */ +class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) + extends DataSourceV2Writer with ConsoleWriter { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Batch: $batchId") + } + + override def toString(): String = { + s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +/** + * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) + extends ContinuousWriter with ConsoleWriter { + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Continuous processing epoch $epochId") + } + + override def toString(): String = { + s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala index 60ffee9b9b42c..55acf2ba28d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamTest, Trigger} class ConsoleWriterSuite extends StreamTest { import testImplicits._ - test("console") { + test("microbatch - default") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -77,7 +79,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with numRows") { + test("microbatch - with numRows") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -106,7 +108,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with truncation") { + test("microbatch - truncation") { val input = MemoryStream[String] val captured = new ByteArrayOutputStream() @@ -132,4 +134,20 @@ class ConsoleWriterSuite extends StreamTest { | |""".stripMargin) } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } } From f568e9cf76f657d094f1d036ab5a95f2531f5761 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Thu, 18 Jan 2018 14:00:12 -0800 Subject: [PATCH 690/712] [SPARK-23133][K8S] Fix passing java options to Executor Pass through spark java options to the executor in context of docker 
image. Closes #20296 andrusha: Deployed two version of containers to local k8s, checked that java options were present in the updated image on the running executor. Manual test Author: Andrew Korzhuev Closes #20322 from foxish/patch-1. --- .../docker/src/main/dockerfiles/spark/entrypoint.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 0c28c75857871..b9090dc2852a5 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -42,7 +42,7 @@ shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" fi @@ -54,7 +54,7 @@ case "$SPARK_K8S_CMD" in driver) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_DRIVER_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY @@ -67,7 +67,7 @@ case "$SPARK_K8S_CMD" in executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" From e01919e834d301e13adc8919932796ebae900576 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 19 Jan 2018 07:36:06 +0900 Subject: [PATCH 691/712] [SPARK-23094] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? There were two related fixes regarding `from_json`, `get_json_object` and `json_tuple` ([Fix #1](https://github.com/apache/spark/commit/c8803c06854683c8761fdb3c0e4c55d5a9e22a95), [Fix #2](https://github.com/apache/spark/commit/86174ea89b39a300caaba6baffac70f3dc702788)), but they weren't comprehensive it seems. I wanted to extend those fixes to all the parsers, and add tests for each case. ## How was this patch tested? Regression tests Author: Burak Yavuz Closes #20302 from brkyvz/json-invfix. 
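A rough PySpark analogue of the new regression tests (the file path and schema below are illustrative assumptions): a record prefixed with null bytes should surface in `_corrupt_record` instead of breaking the parser.

```python
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Assumed setup: `path` points to a text file holding one record with leading
# \x00 bytes ("\x00\x00\x00A\x01AAA") and one valid record ('{"a":1}').
schema = StructType([
    StructField("a", IntegerType()),
    StructField("_corrupt_record", StringType()),
])
df = spark.read.schema(schema).option("multiLine", "false").json(path)
df.show()  # the malformed line lands in _corrupt_record rather than raising
```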
--- .../catalyst/json/CreateJacksonParser.scala | 5 +-- .../sources/JsonHadoopFsRelationSuite.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa5..b1672e7e2fca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,10 +40,11 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - jsonFactory.createParser(record.getBytes, 0, record.getLength) + val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 49be30435ad2f..27f398ebf301a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -105,4 +107,36 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } + + test("invalid json with leading nulls - from file (multiLine=true)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val expected = s"""$badJson\n{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("invalid json with leading nulls - from file (multiLine=false)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("invalid json with leading nulls - from dataset") { + import testImplicits._ + checkAnswer( + spark.read.json(Seq(badJson).toDS()), + Row(badJson)) + } } From 5d7c4ba4d73a72f26d591108db3c20b4a6c84f3f Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 18 Jan 2018 14:44:22 -0800 Subject: [PATCH 692/712] [SPARK-22962][K8S] Fail fast if submission client local files are used ## What changes were proposed in this pull request? 
In the Kubernetes mode, fails fast in the submission process if any submission client local dependencies are used as the use case is not supported yet. ## How was this patch tested? Unit tests, integration tests, and manual tests. vanzin foxish Author: Yinan Li Closes #20320 from liyinan926/master. --- docs/running-on-kubernetes.md | 5 ++- .../k8s/submit/DriverConfigOrchestrator.scala | 14 ++++++++- .../DriverConfigOrchestratorSuite.scala | 31 ++++++++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 08ec34c63ba3f..d6b1735ce5550 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -117,7 +117,10 @@ This URI is the location of the example jar that is already in the Docker image. If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to by their appropriate remote URIs. Also, application dependencies can be pre-mounted into custom-built Docker images. Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the -`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. +`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. The `local://` scheme is also required when referring to +dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission +client's local file system is currently not yet supported. + ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index c9cc300d65569..ae70904621184 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -20,7 +20,7 @@ import java.util.UUID import com.google.common.primitives.Longs -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -117,6 +117,12 @@ private[spark] class DriverConfigOrchestrator( .map(_.split(",")) .getOrElse(Array.empty[String]) + // TODO(SPARK-23153): remove once submission client local dependencies are supported. 
+ if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { + throw new SparkException("The Kubernetes mode does not yet support referencing application " + + "dependencies in the local file system.") + } + val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, @@ -162,6 +168,12 @@ private[spark] class DriverConfigOrchestrator( initContainerBootstrapStep } + private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { + files.exists { uri => + Utils.resolveURI(uri).getScheme == "file" + } + } + private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { files.exists { uri => Utils.resolveURI(uri).getScheme != "local" diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 65274c6f50e01..033d303e946fd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.deploy.k8s.submit -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.steps._ @@ -117,6 +117,35 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[DriverMountSecretsStep]) } + test("Submission using client local dependencies") { + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + var orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + + sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") + orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + } + private def validateStepTypes( orchestrator: DriverConfigOrchestrator, types: Class[_ <: DriverConfigurationStep]*): Unit = { From 4cd2ecc0c7222fef1337e04f1948333296c3be86 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 16:29:45 -0800 Subject: [PATCH 693/712] [SPARK-23142][SS][DOCS] Added docs for continuous processing ## What changes were proposed in this pull request? Added documentation for continuous processing. Modified two locations. - Modified the overview to have a mention of Continuous Processing. - Added a new section on Continuous Processing at the end. ![image](https://user-images.githubusercontent.com/663212/35083551-a3dd23f6-fbd4-11e7-9e7e-90866f131ca9.png) ![image](https://user-images.githubusercontent.com/663212/35083618-d844027c-fbd4-11e7-9fde-75992cc517bd.png) ## How was this patch tested? N/A Author: Tathagata Das Closes #20308 from tdas/SPARK-23142. 
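As a quick, self-contained illustration of the feature documented by this commit (the master, app name, and checkpoint path below are made up for the example), a continuous query can be run locally against the rate source with a console sink, both of which the new guide text lists as supported in continuous mode:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object ContinuousRateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("continuous-rate-sketch")
      .getOrCreate()

    val query = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .where("value % 2 = 0")                     // only map-like operations are supported
      .writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/continuous-rate-sketch")
      .trigger(Trigger.Continuous("1 second"))    // the only change vs. a micro-batch query
      .start()

    query.awaitTermination()
  }
}
```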
--- .../structured-streaming-programming-guide.md | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 1779a4215e085..2ddba2f0d942e 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,9 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. + +In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss the Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in @@ -2434,6 +2436,100 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat + +# Continuous Processing [Experimental] +**Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieves latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations).
+ +To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, + +
    +
+{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start() +{% endhighlight %} +
    +
    +{% highlight java %} +import org.apache.spark.sql.streaming.Trigger; + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start(); +{% endhighlight %} +
    +
    +{% highlight python %} +spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ + .load() \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .trigger(continuous="1 second") \ # only change in query + .start() + +{% endhighlight %} +
    +
+ +A checkpoint interval of 1 second means that the continuous processing engine will record the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. + +## Supported Queries +As of Spark 2.3, only the following types of queries are supported in the continuous processing mode. + +- *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). + + All SQL functions are supported except aggregation functions (since aggregations are not yet supported), `current_timestamp()` and `current_date()` (deterministic computations using time are challenging). + +- *Sources*: + + Kafka source: All options are supported. + + Rate source: Good for testing. The only options that are supported in the continuous mode are `numPartitions` and `rowsPerSecond`. + +- *Sinks*: + + Kafka sink: All options are supported. + + Memory sink: Good for debugging. + + Console sink: Good for debugging. All options are supported. Note that the console will print at every checkpoint interval that you have specified in the continuous trigger. + +See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. + +## Caveats +- The continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to run all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. +- Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. +- There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. + # Additional Information **Further Reading** From 6121e91b7f5c9513d68674e4d5edbc3a4a5fd5fd Mon Sep 17 00:00:00 2001 From: brandonJY Date: Thu, 18 Jan 2018 18:57:49 -0600 Subject: [PATCH 694/712] [DOCS] change to dataset for java code in structured-streaming-kafka-integration document ## What changes were proposed in this pull request? In the latest structured-streaming-kafka-integration document, the Java code example for Kafka integration is using `DataFrame`, shouldn't it be changed to `DataSet`? ## How was this patch tested? manual test has been performed to test the updated example Java code in Spark 2.2.1 with Kafka 1.0 Author: brandonJY Closes #20312 from brandonJY/patch-2.
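The reason for the change below is that `DataFrame` exists only as a Scala type alias for `Dataset[Row]`; there is no `DataFrame` class in the Java API, so the Java examples have to be typed as `Dataset<Row>`. A minimal sketch (the broker address and topic name are placeholders):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class KafkaDatasetSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("kafka-dataset-sketch")
        .getOrCreate();

    // Java has no DataFrame type, so the streaming source is a Dataset<Row>.
    Dataset<Row> df = spark
        .readStream()
        .format("kafka")
        .option("kafka.bootstrap.servers", "host1:port1")
        .option("subscribe", "topic1")
        .load();

    df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)").printSchema();
  }
}
```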
--- docs/structured-streaming-kafka-integration.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index bab0be8ddeb9f..461c29ce1ba89 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -61,7 +61,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -70,7 +70,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -79,7 +79,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -171,7 +171,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -180,7 +180,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -191,7 +191,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") From 568055da93049c207bb830f244ff9b60c638837c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 19 Jan 2018 11:37:08 +0800 Subject: [PATCH 695/712] [SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when casting PythonUserDefinedType to String. ## What changes were proposed in this pull request? This is a follow-up of #20246. If a UDT in Python doesn't have its corresponding Scala UDT, cast to string will be the raw string of the internal value, e.g. `"org.apache.spark.sql.catalyst.expressions.UnsafeArrayDataxxxxxxxx"` if the internal type is `ArrayType`. This pr fixes it by using its `sqlType` casting. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20306 from ueshin/issues/SPARK-23054/fup1. 
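A small PySpark sketch of the behaviour the new test below covers; it reuses the example UDTs that ship with the PySpark test suite (test-only helpers, used here purely for illustration):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT, PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import StructField, StructType

spark = SparkSession.builder.appName("udt-cast-sketch").getOrCreate()

schema = StructType([
    StructField("point", ExamplePointUDT(), False),    # UDT with a Scala counterpart
    StructField("pypoint", PythonOnlyUDT(), False),    # Python-only UDT
])
df = spark.createDataFrame([(ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))], schema)

# With the sqlType-based cast, the Python-only UDT renders as "[3.0, 4.0]"
# rather than the raw string of its internal representation.
df.select(col("point").cast("string"), col("pypoint").cast("string")).show(truncate=False)
```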
--- python/pyspark/sql/tests.py | 11 +++++++++++ .../apache/spark/sql/catalyst/expressions/Cast.scala | 2 ++ .../org/apache/spark/sql/test/ExamplePointUDT.scala | 2 ++ 3 files changed, 15 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 25483594f2725..4fee2ecde391b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1189,6 +1189,17 @@ def test_union_with_udt(self): ] ) + def test_cast_to_string_with_udt(self): + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + from pyspark.sql.functions import col + row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) + schema = StructType([StructField("point", ExamplePointUDT(), False), + StructField("pypoint", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + + result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() + self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a95ebe301b9d1..79b051670e9e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case pudt: PythonUserDefinedType => castToString(pudt.sqlType) case udt: UserDefinedType[_] => buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) @@ -838,6 +839,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => val udtRef = ctx.addReferenceObj("udt", udt) (c, evPrim, evNull) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index a73e4272950a4..8bab7e1c58762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab case that: ExamplePoint => this.x == that.x && this.y == that.y case _ => false } + + override def toString(): String = s"($x, $y)" } /** From 9c4b99861cda3f9ec44ca8c1adc81a293508190c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 19 Jan 2018 01:38:08 -0800 Subject: [PATCH 696/712] [BUILD][MINOR] Fix java style check issues ## What changes were proposed in this pull request? This patch fixes a few recently introduced java style check errors in master and release branch. As an aside, given that [java linting currently fails](https://github.com/apache/spark/pull/10763 ) on machines with a clean maven cache, it'd be great to find another workaround to [re-enable the java style checks](https://github.com/apache/spark/blob/3a07eff5af601511e97a05e6fea0e3d48f74c4f0/dev/run-tests.py#L577) as part of Spark PRB. /cc zsxwing JoshRosen srowen for any suggestions ## How was this patch tested? 
Manual Check Author: Sameer Agarwal Closes #20323 from sameeragarwal/java. --- .../spark/sql/sources/v2/writer/DataSourceV2Writer.java | 6 ++++-- .../org/apache/spark/sql/vectorized/ArrowColumnVector.java | 5 +++-- .../apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 317ac45bcfd74..f1ef411423162 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -28,8 +28,10 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter( + * String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter( + * String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index eb69001fe677e..bfd1b4cb0ef12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -556,8 +556,9 @@ final int getArrayOffset(int rowId) { /** * Any call to "get" method will throw UnsupportedOperationException. * - * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined - * in the parent class. Any call to "get" method in this class is a bug in the code. + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses + * getStruct() method defined in the parent class. Any call to "get" method in this class is a + * bug in the code. 
* */ private static class StructAccessor extends ArrowVectorAccessor { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 44e5146d7c553..98d6a53b54d28 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,7 +69,8 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = + new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); return this; } From 60203fca6a605ad158184e1e0ce5187e144a3ea7 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 19 Jan 2018 12:43:23 +0200 Subject: [PATCH 697/712] [SPARK-23127][DOC] Update FeatureHasher guide for categoricalCols parameter Update user guide entry for `FeatureHasher` to match the Scala / Python doc, to describe the `categoricalCols` parameter. ## How was this patch tested? Doc only Author: Nick Pentreath Closes #20293 from MLnick/SPARK-23127-catCol-userguide. --- docs/ml-features.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 72643137d96b1..10183c3e78c76 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -222,9 +222,9 @@ The `FeatureHasher` transformer operates on multiple columns. Each column may co numeric or categorical features. Behavior and handling of column data types is as follows: - Numeric columns: For numeric features, the hash value of the column name is used to map the -feature value to its index in the feature vector. Numeric features are never treated as -categorical, even when they are integers. You must explicitly convert numeric columns containing -categorical features to strings first. +feature value to its index in the feature vector. By default, numeric features are not treated +as categorical (even when they are integers). To treat them as categorical, specify the relevant +columns using the `categoricalCols` parameter. - String columns: For categorical features, the hash value of the string "column_name=value" is used to map to the vector index, with an indicator value of `1.0`. Thus, categorical features are "one-hot" encoded (similarly to using [OneHotEncoder](ml-features.html#onehotencoder) with From b74366481cc87490adf4e69d26389ec737548c15 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Jan 2018 12:48:42 +0200 Subject: [PATCH 698/712] [SPARK-23048][ML] Add OneHotEncoderEstimator document and examples ## What changes were proposed in this pull request? We have `OneHotEncoderEstimator` now and `OneHotEncoder` will be deprecated since 2.3.0. We should add `OneHotEncoderEstimator` into mllib document. We also need to provide corresponding examples for `OneHotEncoderEstimator` which are used in the document too. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20257 from viirya/SPARK-23048. 
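As a companion to the `FeatureHasher` documentation change above (placed here before the `OneHotEncoderEstimator` diff that follows), a minimal sketch of the `categoricalCols` parameter it describes; the column names and values are invented for the example:

```scala
import org.apache.spark.ml.feature.FeatureHasher
import org.apache.spark.sql.SparkSession

object FeatureHasherSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("feature-hasher-sketch").getOrCreate()
    import spark.implicits._

    // "plan" is an integer-coded category: without categoricalCols it would be
    // hashed as a plain numeric feature.
    val df = Seq(
      (2.2, true, 1, "foo"),
      (3.3, false, 2, "bar")
    ).toDF("real", "bool", "plan", "string")

    val hasher = new FeatureHasher()
      .setInputCols(Array("real", "bool", "plan", "string"))
      .setCategoricalCols(Array("plan"))   // treat the integer column as categorical
      .setOutputCol("features")

    hasher.transform(df).select("features").show(truncate = false)
  }
}
```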
--- docs/ml-features.md | 28 ++++++++----- ...=> JavaOneHotEncoderEstimatorExample.java} | 41 ++++++++----------- ...py => onehot_encoder_estimator_example.py} | 29 +++++++------ ...la => OneHotEncoderEstimatorExample.scala} | 40 ++++++++---------- 4 files changed, 68 insertions(+), 70 deletions(-) rename examples/src/main/java/org/apache/spark/examples/ml/{JavaOneHotEncoderExample.java => JavaOneHotEncoderEstimatorExample.java} (62%) rename examples/src/main/python/ml/{onehot_encoder_example.py => onehot_encoder_estimator_example.py} (65%) rename examples/src/main/scala/org/apache/spark/examples/ml/{OneHotEncoderExample.scala => OneHotEncoderEstimatorExample.scala} (65%) diff --git a/docs/ml-features.md b/docs/ml-features.md index 10183c3e78c76..466a8fbe99cf6 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -775,35 +775,43 @@ for more details on the API. -## OneHotEncoder +## OneHotEncoder (Deprecated since 2.3.0) -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. +Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces an `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030). + +`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. + +## OneHotEncoderEstimator + +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first. + +`OneHotEncoderEstimator` can transform multiple columns, returning an one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler). + +`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). **Examples**
    -Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %}
    -Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) +Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %}
    -Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% include_example python/ml/onehot_encoder_estimator_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java similarity index 62% rename from examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java rename to examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java index 99af37676ba98..6f93cff94b725 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java @@ -23,9 +23,8 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.ml.feature.OneHotEncoderEstimator; +import org.apache.spark.ml.feature.OneHotEncoderModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -35,41 +34,37 @@ import org.apache.spark.sql.types.StructType; // $example off$ -public class JavaOneHotEncoderExample { +public class JavaOneHotEncoderEstimatorExample { public static void main(String[] args) { SparkSession spark = SparkSession .builder() - .appName("JavaOneHotEncoderExample") + .appName("JavaOneHotEncoderEstimatorExample") .getOrCreate(); + // Note: categorical features are usually first encoded with StringIndexer // $example on$ List data = Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") + RowFactory.create(0.0, 1.0), + RowFactory.create(1.0, 0.0), + RowFactory.create(2.0, 1.0), + RowFactory.create(0.0, 2.0), + RowFactory.create(0.0, 1.0), + RowFactory.create(2.0, 0.0) ); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) + new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()) }); Dataset df = spark.createDataFrame(data, schema); - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - Dataset indexed = indexer.transform(df); + OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - - Dataset encoded = encoder.transform(indexed); + OneHotEncoderModel model = encoder.fit(df); + Dataset encoded = model.transform(df); encoded.show(); // $example off$ diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_estimator_example.py similarity index 65% rename from examples/src/main/python/ml/onehot_encoder_example.py rename to examples/src/main/python/ml/onehot_encoder_estimator_example.py index e1996c7f0a55b..2723e681cea7c 100644 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ b/examples/src/main/python/ml/onehot_encoder_estimator_example.py @@ -18,32 +18,31 @@ from __future__ import print_function # $example on$ -from 
pyspark.ml.feature import OneHotEncoder, StringIndexer +from pyspark.ml.feature import OneHotEncoderEstimator # $example off$ from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ .builder\ - .appName("OneHotEncoderExample")\ + .appName("OneHotEncoderEstimatorExample")\ .getOrCreate() + # Note: categorical features are usually first encoded with StringIndexer # $example on$ df = spark.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + ], ["categoryIndex1", "categoryIndex2"]) - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - - encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) + encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryVec1", "categoryVec2"]) + model = encoder.fit(df) + encoded = model.transform(df) encoded.show() # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala similarity index 65% rename from examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala index 274cc1268f4d1..45d816808ed8e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala @@ -19,38 +19,34 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +import org.apache.spark.ml.feature.OneHotEncoderEstimator // $example off$ import org.apache.spark.sql.SparkSession -object OneHotEncoderExample { +object OneHotEncoderEstimatorExample { def main(args: Array[String]): Unit = { val spark = SparkSession .builder - .appName("OneHotEncoderExample") + .appName("OneHotEncoderEstimatorExample") .getOrCreate() + // Note: categorical features are usually first encoded with StringIndexer // $example on$ val df = spark.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") - - val encoded = encoder.transform(indexed) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + )).toDF("categoryIndex1", "categoryIndex2") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("categoryIndex1", "categoryIndex2")) + .setOutputCols(Array("categoryVec1", "categoryVec2")) + val model = encoder.fit(df) + + val encoded = model.transform(df) encoded.show() // $example off$ From e41400c3c8aace9eb72e6134173f222627fb0faf Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 19 Jan 2018 19:46:48 +0800 Subject: [PATCH 699/712] [SPARK-23089][STS] Recreate session log directory if it doesn't exist ## What changes were proposed in this pull request? When creating a session directory, Thrift should create the parent directory (i.e. 
/tmp/base_session_log_dir) if it is not present. It is common that many tools delete empty directories, so the directory may be deleted. This can cause the session log to be disabled. This was fixed in HIVE-12262: this PR brings it in Spark too. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20281 from mgaido91/SPARK-23089. --- .../hive/service/cli/session/HiveSessionImpl.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 47bfaa86021d6..108074cce3d6d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -223,6 +223,18 @@ private void configureSession(Map sessionConfMap) throws HiveSQL @Override public void setOperationLogSessionDir(File operationLogRootDir) { + if (!operationLogRootDir.exists()) { + LOG.warn("The operation log root directory is removed, recreating: " + + operationLogRootDir.getAbsolutePath()); + if (!operationLogRootDir.mkdirs()) { + LOG.warn("Unable to create operation log root directory: " + + operationLogRootDir.getAbsolutePath()); + } + } + if (!operationLogRootDir.canWrite()) { + LOG.warn("The operation log root directory is not writable: " + + operationLogRootDir.getAbsolutePath()); + } sessionLogDir = new File(operationLogRootDir, sessionHandle.getHandleIdentifier().toString()); isOperationLogEnabled = true; if (!sessionLogDir.exists()) { From e1c33b6cd14e4e1123814f4d040e3520db7d1ec9 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 19 Jan 2018 08:22:24 -0600 Subject: [PATCH 700/712] [SPARK-23024][WEB-UI] Spark ui about the contents of the form need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark ui about the contents of the form need to have hidden and show features, when the table records very much. Because sometimes you do not care about the record of the table, you just want to see the contents of the next table, but you have to scroll the scroll bar for a long time to see the contents of the next table. Currently we have about 500 workers, but I just wanted to see the logs for the running applications table. I had to scroll through the scroll bars for a long time to see the logs for the running applications table. In order to ensure functional consistency, I modified the Master Page, Worker Page, Job Page, Stage Page, Task Page, Configuration Page, Storage Page, Pool Page. fix before: ![1](https://user-images.githubusercontent.com/26266482/34805936-601ed628-f6bb-11e7-8dd3-d8413573a076.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/34805949-6af8afba-f6bb-11e7-89f4-ab16584916fb.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20216 from guoxiaolongzte/SPARK-23024. 
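The page diffs below all switch to the same collapsible-table markup pattern: a clickable header wired to the `collapseTable()` helper in `webui.js`, followed by a div wrapping the existing table. A hedged sketch of that pattern follows, using the `collapse-aggregated-workers`/`aggregated-workers` class names registered in `webui.js`; attribute details may differ slightly from the actual patch, and `workers`/`workerTable` stand in for values the real page code already has in scope:

```scala
// Placeholder data standing in for the page's existing state.
val workers = Seq("worker-1", "worker-2")
val workerTable = <table class="table">{workers.map(w => <tr><td>{w}</td></tr>)}</table>

// Clickable header plus collapsible body; clicking toggles the table and the
// arrow, and webui.js remembers the collapsed state across page reloads.
val workersSection =
  <span class="collapse-aggregated-workers collapse-table"
        onClick="collapseTable('collapse-aggregated-workers','aggregated-workers')">
    <h4>
      <span class="collapse-table-arrow arrow-open"></span>
      <a>Workers ({workers.length})</a>
    </h4>
  </span> ++
  <div class="aggregated-workers collapsible-table">
    {workerTable}
  </div>
```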
--- .../org/apache/spark/ui/static/webui.js | 30 ++++++++ .../deploy/master/ui/ApplicationPage.scala | 25 +++++-- .../spark/deploy/master/ui/MasterPage.scala | 63 ++++++++++++++--- .../spark/deploy/worker/ui/WorkerPage.scala | 52 +++++++++++--- .../apache/spark/ui/env/EnvironmentPage.scala | 48 +++++++++++-- .../apache/spark/ui/jobs/AllJobsPage.scala | 39 +++++++++-- .../apache/spark/ui/jobs/AllStagesPage.scala | 67 +++++++++++++++--- .../org/apache/spark/ui/jobs/JobPage.scala | 68 ++++++++++++++++--- .../org/apache/spark/ui/jobs/PoolPage.scala | 13 +++- .../org/apache/spark/ui/jobs/StagePage.scala | 12 +++- .../apache/spark/ui/storage/StoragePage.scala | 12 +++- 11 files changed, 373 insertions(+), 56 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 0fa1fcf25f8b9..e575c4c78970d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -50,4 +50,34 @@ function collapseTable(thisName, table){ // to remember if it's collapsed on each page reload $(function() { collapseTablePageLoad('collapse-aggregated-metrics','aggregated-metrics'); + collapseTablePageLoad('collapse-aggregated-executors','aggregated-executors'); + collapseTablePageLoad('collapse-aggregated-removedExecutors','aggregated-removedExecutors'); + collapseTablePageLoad('collapse-aggregated-workers','aggregated-workers'); + collapseTablePageLoad('collapse-aggregated-activeApps','aggregated-activeApps'); + collapseTablePageLoad('collapse-aggregated-activeDrivers','aggregated-activeDrivers'); + collapseTablePageLoad('collapse-aggregated-completedApps','aggregated-completedApps'); + collapseTablePageLoad('collapse-aggregated-completedDrivers','aggregated-completedDrivers'); + collapseTablePageLoad('collapse-aggregated-runningExecutors','aggregated-runningExecutors'); + collapseTablePageLoad('collapse-aggregated-runningDrivers','aggregated-runningDrivers'); + collapseTablePageLoad('collapse-aggregated-finishedExecutors','aggregated-finishedExecutors'); + collapseTablePageLoad('collapse-aggregated-finishedDrivers','aggregated-finishedDrivers'); + collapseTablePageLoad('collapse-aggregated-runtimeInformation','aggregated-runtimeInformation'); + collapseTablePageLoad('collapse-aggregated-sparkProperties','aggregated-sparkProperties'); + collapseTablePageLoad('collapse-aggregated-systemProperties','aggregated-systemProperties'); + collapseTablePageLoad('collapse-aggregated-classpathEntries','aggregated-classpathEntries'); + collapseTablePageLoad('collapse-aggregated-activeJobs','aggregated-activeJobs'); + collapseTablePageLoad('collapse-aggregated-completedJobs','aggregated-completedJobs'); + collapseTablePageLoad('collapse-aggregated-failedJobs','aggregated-failedJobs'); + collapseTablePageLoad('collapse-aggregated-poolTable','aggregated-poolTable'); + collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); + collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); + collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); + collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); + collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); + 
collapseTablePageLoad('collapse-aggregated-completedStages','aggregated-completedStages'); + collapseTablePageLoad('collapse-aggregated-failedStages','aggregated-failedStages'); + collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); + collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); + collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); }); \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 68e57b7564ad1..f699c75085fe1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -100,12 +100,29 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
    -

    Executor Summary ({allExecutors.length})

    - {executorsTable} + +

    + + Executor Summary ({allExecutors.length}) +

    +
    +
    + {executorsTable} +
    { if (removedExecutors.nonEmpty) { -

    Removed Executors ({removedExecutors.length})

    ++ - removedExecutorsTable + +

    + + Removed Executors ({removedExecutors.length}) +

    +
    ++ +
    + {removedExecutorsTable} +
    } }
    diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index bc0bf6a1d9700..c629937606b51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -128,15 +128,31 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    -

    Workers ({workers.length})

    - {workerTable} + +

    + + Workers ({workers.length}) +

    +
    +
    + {workerTable} +
    -

    Running Applications ({activeApps.length})

    - {activeAppsTable} + +

    + + Running Applications ({activeApps.length}) +

    +
    +
    + {activeAppsTable} +
    @@ -144,8 +160,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {if (hasDrivers) {
    -

    Running Drivers ({activeDrivers.length})

    - {activeDriversTable} + +

    + + Running Drivers ({activeDrivers.length}) +

    +
    +
    + {activeDriversTable} +
    } @@ -154,8 +179,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    -

    Completed Applications ({completedApps.length})

    - {completedAppsTable} + +

    + + Completed Applications ({completedApps.length}) +

    +
    +
    + {completedAppsTable} +
    @@ -164,8 +198,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { if (hasDrivers) {
    -

    Completed Drivers ({completedDrivers.length})

    - {completedDriversTable} + +

    + + Completed Drivers ({completedDrivers.length}) +

    +
    +
    + {completedDriversTable} +
    } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index ce84bc4dae32c..8b98ae56fc108 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -77,24 +77,60 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
    -

    Running Executors ({runningExecutors.size})

    - {runningExecutorTable} + +

    + + Running Executors ({runningExecutors.size}) +

    +
    +
    + {runningExecutorTable} +
    { if (runningDrivers.nonEmpty) { -

    Running Drivers ({runningDrivers.size})

    ++ - runningDriverTable + +

    + + Running Drivers ({runningDrivers.size}) +

    +
    ++ +
    + {runningDriverTable} +
    } } { if (finishedExecutors.nonEmpty) { -

    Finished Executors ({finishedExecutors.size})

    ++ - finishedExecutorTable + +

    + + Finished Executors ({finishedExecutors.size}) +

    +
    ++ +
    + {finishedExecutorTable} +
    } } { if (finishedDrivers.nonEmpty) { -

    Finished Drivers ({finishedDrivers.size})

    ++ - finishedDriverTable + +

    + + Finished Drivers ({finishedDrivers.size}) +

    +
    ++ +
    + {finishedDriverTable} +
    } }
    diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index 43adab7a35d65..902eb92b854f2 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -48,10 +48,50 @@ private[ui] class EnvironmentPage( classPathHeaders, classPathRow, appEnv.classpathEntries, fixedWidth = true) val content = -

    Runtime Information

    {runtimeInformationTable} -

    Spark Properties

    {sparkPropertiesTable} -

    System Properties

    {systemPropertiesTable} -

    Classpath Entries

    {classpathEntriesTable} + +

    + + Runtime Information +

    +
    +
    + {runtimeInformationTable} +
    + +

    + + Spark Properties +

    +
    +
    + {sparkPropertiesTable} +
    + +

    + + System Properties +

    +
    +
    + {systemPropertiesTable} +
    + +

    + + Classpath Entries +

    +
    +
    + {classpathEntriesTable} +
    UIUtils.headerSparkPage("Environment", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index ff916bb6a5759..e3b72f1f34859 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -363,16 +363,43 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We store.executorList(false), startTime) if (shouldShowActiveJobs) { - content ++=

    Active Jobs ({activeJobs.size})

    ++ - activeJobsTable + content ++= + +

    + + Active Jobs ({activeJobs.size}) +

    +
    ++ +
    + {activeJobsTable} +
    } if (shouldShowCompletedJobs) { - content ++=

    Completed Jobs ({completedJobNumStr})

    ++ - completedJobsTable + content ++= + +

    + + Completed Jobs ({completedJobNumStr}) +

    +
    ++ +
    + {completedJobsTable} +
    } if (shouldShowFailedJobs) { - content ++=

    Failed Jobs ({failedJobs.size})

    ++ - failedJobsTable + content ++= + +

    + + Failed Jobs ({failedJobs.size}) +

    +
    ++ +
    + {failedJobsTable} +
    } val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index b1e343451e28e..606dc1e180e5b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -116,26 +116,75 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { var content = summary ++ { if (sc.isDefined && isFairScheduler) { -

    Fair Scheduler Pools ({pools.size})

    ++ poolTable.toNodeSeq + +

    + + Fair Scheduler Pools ({pools.size}) +

    +
    ++ +
    + {poolTable.toNodeSeq} +
    } else { Seq.empty[Node] } } if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq + content ++= + +

    + + Active Stages ({activeStages.size}) +

    +
    ++ +
    + {activeStagesTable.toNodeSeq} +
    } if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingStages.size})

    ++ - pendingStagesTable.toNodeSeq + content ++= + +

    + + Pending Stages ({pendingStages.size}) +

    +
    ++ +
    + {pendingStagesTable.toNodeSeq} +
    } if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStageNumStr})

    ++ - completedStagesTable.toNodeSeq + content ++= + +

    + + Completed Stages ({completedStageNumStr}) +

    +
    ++ +
    + {completedStagesTable.toNodeSeq} +
    } if (shouldShowFailedStages) { - content ++=

    Failed Stages ({numFailedStages})

    ++ - failedStagesTable.toNodeSeq + content ++= + +

    + + Failed Stages ({numFailedStages}) +

    +
    ++ +
    + {failedStagesTable.toNodeSeq} +
    } UIUtils.headerSparkPage("Stages for All Jobs", content, parent) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index bf59152c8c0cd..c27f30c21a843 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -340,24 +340,72 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP jobId, store.operationGraphForJob(jobId)) if (shouldShowActiveStages) { - content ++=

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq + content ++= + +

    + + Active Stages ({activeStages.size}) +

    +
    ++ +
    + {activeStagesTable.toNodeSeq} +
    } if (shouldShowPendingStages) { - content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

    + + Pending Stages ({pendingOrSkippedStages.size}) +

    +
    ++ +
    + {pendingOrSkippedStagesTable.toNodeSeq} +
    } if (shouldShowCompletedStages) { - content ++=

    Completed Stages ({completedStages.size})

    ++ - completedStagesTable.toNodeSeq + content ++= + +

    + + Completed Stages ({completedStages.size}) +

    +
    ++ +
    + {completedStagesTable.toNodeSeq} +
    } if (shouldShowSkippedStages) { - content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

    + + Skipped Stages ({pendingOrSkippedStages.size}) +

    +
    ++ +
    + {pendingOrSkippedStagesTable.toNodeSeq} +
    } if (shouldShowFailedStages) { - content ++=

    Failed Stages ({failedStages.size})

    ++ - failedStagesTable.toNodeSeq + content ++= + +

    + + Failed Stages ({failedStages.size}) +

    +
    ++ +
    + {failedStagesTable.toNodeSeq} +
    } UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 98fbd7aceaa11..a3e1f13782e30 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -51,7 +51,18 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { val poolTable = new PoolTable(Map(pool -> uiPool), parent) var content =

    Summary

    ++ poolTable.toNodeSeq if (activeStages.nonEmpty) { - content ++=

    Active Stages ({activeStages.size})

    ++ activeStagesTable.toNodeSeq + content ++= + +

    + + Active Stages ({activeStages.size}) +

    +
    ++ +
    + {activeStagesTable.toNodeSeq} +
    } UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index af78373ddb4b2..25bee33028393 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -486,8 +486,16 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ aggMetrics ++ maybeAccumulableTable ++ -

    Tasks ({totalTasksNumStr})

    ++ - taskTableHTML ++ jsForScrollingDownToTaskTable + +

    + + Tasks ({totalTasksNumStr}) +

    +
    ++ +
    + {taskTableHTML ++ jsForScrollingDownToTaskTable} +
    UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b8aec9890247a..68d946574a37b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -41,8 +41,16 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends Nil } else {
    -

    RDDs

    - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + +

    + + RDDs +

    +
    +
    + {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} +
    } } From 6c39654efcb2aa8cb4d082ab7277a6fa38fb48e4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 19 Jan 2018 22:47:18 +0800 Subject: [PATCH 701/712] [SPARK-23000][TEST] Keep Derby DB Location Unchanged After Session Cloning ## What changes were proposed in this pull request? After session cloning in `TestHive`, the conf of the singleton SparkContext for derby DB location is changed to a new directory. The new directory is created in `HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false)`. This PR is to keep the conf value of `ConfVars.METASTORECONNECTURLKEY.varname` unchanged during the session clone. ## How was this patch tested? The issue can be reproduced by the command: > build/sbt -Phive "hive/test-only org.apache.spark.sql.hive.HiveSessionStateSuite org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite" Also added a test case. Author: gatorsmile Closes #20328 from gatorsmile/fixTestFailure. --- .../org/apache/spark/sql/SessionStateSuite.scala | 5 +---- .../apache/spark/sql/hive/test/TestHive.scala | 8 +++++++- .../spark/sql/hive/HiveSessionStateSuite.scala | 16 +++++++++++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 5d75f5835bf9e..4efae4c46c2e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll -import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite @@ -28,8 +26,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener -class SessionStateSuite extends SparkFunSuite - with BeforeAndAfterEach with BeforeAndAfterAll { +class SessionStateSuite extends SparkFunSuite { /** * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b6be00dbb3a73..c84131fc3212a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -180,7 +180,13 @@ private[hive] class TestHiveSparkSession( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", // scratch directory used by Hive's metastore client ConfVars.SCRATCHDIR.varname -> TestHiveContext.makeScratchDir().toURI.toString, - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") ++ + // After session cloning, the JDBC connect string for a JDBC metastore should not be changed. 
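
To make the intent of this change concrete, a hedged, condensed sketch (not part of the patch) of the invariant it preserves: cloning a TestHive-backed session must not relocate the Derby metastore, i.e. the value of `ConfVars.METASTORECONNECTURLKEY` in the shared `hadoopConfiguration` stays the same before and after the clone. The `spark` value below stands for an existing test session and is an assumption for illustration; `cloneSession()` is assumed to be callable from test code in the `org.apache.spark.sql` package, as in the test added later in this patch.

```scala
import org.apache.hadoop.hive.conf.HiveConf.ConfVars

// Illustrative check only: `spark` is assumed to be a running TestHive-backed SparkSession.
val hadoopConf = spark.sparkContext.hadoopConfiguration
val before = hadoopConf.get(ConfVars.METASTORECONNECTURLKEY.varname)

// Cloning the session should reuse the existing metastore JDBC URL instead of
// generating a fresh temporary Derby directory.
spark.cloneSession()

val after = hadoopConf.get(ConfVars.METASTORECONNECTURLKEY.varname)
assert(before == after, "session cloning must not relocate the Derby metastore")
```
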
+ existingSharedState.map { state => + val connKey = + state.sparkContext.hadoopConfiguration.get(ConfVars.METASTORECONNECTURLKEY.varname) + ConfVars.METASTORECONNECTURLKEY.varname -> connKey + } metastoreTempConf.foreach { case (k, v) => sc.hadoopConfiguration.set(k, v) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index f7da3c4cbb0aa..ecc09cdcdbeaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterEach +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -25,8 +25,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton /** * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. */ -class HiveSessionStateSuite extends SessionStateSuite - with TestHiveSingleton with BeforeAndAfterEach { +class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { override def beforeAll(): Unit = { // Reuse the singleton session @@ -39,4 +38,15 @@ class HiveSessionStateSuite extends SessionStateSuite activeSession = null super.afterAll() } + + test("Clone then newSession") { + val sparkSession = hiveContext.sparkSession + val conf = sparkSession.sparkContext.hadoopConfiguration + val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + sparkSession.cloneSession() + sparkSession.sharedState.externalCatalog.client.newSession() + val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + assert(oldValue == newValue, + "cloneSession and then newSession should not affect the Derby directory") + } } From 606a7485f12c5d5377c50258006c353ba5e49c3f Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 19 Jan 2018 09:28:35 -0600 Subject: [PATCH 702/712] [SPARK-23085][ML] API parity for mllib.linalg.Vectors.sparse ## What changes were proposed in this pull request? `ML.Vectors#sparse(size: Int, elements: Seq[(Int, Double)])` support zero-length ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20275 from zhengruifeng/SparseVector_size. --- .../scala/org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../org/apache/spark/ml/linalg/VectorsSuite.scala | 14 ++++++++++++++ .../org/apache/spark/mllib/linalg/Vectors.scala | 3 +-- .../apache/spark/mllib/linalg/VectorsSuite.scala | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 941b6eca568d3..5824e463ca1aa 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -565,7 +565,7 @@ class SparseVector @Since("2.0.0") ( // validate the data { - require(size >= 0, "The size of the requested sparse vector must be greater than 0.") + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 79acef8214d88..0a316f57f811b 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -366,4 +366,18 @@ class VectorsSuite extends SparkMLFunSuite { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fd9605c013625..6e68d9684a672 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -326,8 +326,6 @@ object Vectors { */ @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0, "The size of the requested sparse vector must be greater than 0.") - val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 indices.foreach { i => @@ -758,6 +756,7 @@ class SparseVector @Since("1.0.0") ( @Since("1.0.0") val indices: Array[Int], @Since("1.0.0") val values: Array[Double]) extends Vector { + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 4074bead421e6..217b4a35438fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -495,4 +495,18 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(mlDenseVectorToArray(dv) === mlDenseVectorToArray(newDV)) assert(mlSparseVectorToArray(sv) === mlSparseVectorToArray(newSV)) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } From d8aaa771e249b3f54b57ce24763e53fd65a0dbf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Jan 2018 08:58:21 -0800 Subject: [PATCH 703/712] [SPARK-23149][SQL] polish ColumnarBatch ## What changes were proposed in this pull request? Several cleanups in `ColumnarBatch` * remove `schema`. 
The `ColumnVector`s inside `ColumnarBatch` already have the data type information, we don't need this `schema`. * remove `capacity`. `ColumnarBatch` is just a wrapper of `ColumnVector`s, not builders, it doesn't need a capacity property. * remove `DEFAULT_BATCH_SIZE`. As a wrapper, `ColumnarBatch` can't decide the batch size, it should be decided by the reader, e.g. parquet reader, orc reader, cached table reader. The default batch size should also be defined by the reader. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20316 from cloud-fan/columnar-batch. --- .../orc/OrcColumnarBatchReader.java | 49 +++++++------------ .../SpecificParquetRecordReaderBase.java | 12 ++--- .../VectorizedParquetRecordReader.java | 24 ++++----- .../vectorized/ColumnVectorUtils.java | 18 +++---- .../spark/sql/vectorized/ColumnarBatch.java | 20 +------- .../VectorizedHashMapGenerator.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 5 +- .../python/ArrowEvalPythonExec.scala | 8 +-- .../execution/python/ArrowPythonRunner.scala | 2 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 3 +- .../vectorized/ColumnarBatchSuite.scala | 7 ++- .../sql/sources/v2/DataSourceV2Suite.scala | 3 +- 13 files changed, 61 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 36fdf2bdf84d2..89bae4326e93b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,18 +49,8 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - - /** - * The default size of batch. We use this value for ORC reader to make it consistent with Spark's - * columnar batch, because their default batch sizes are different like the following: - * - * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 - * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 - */ - private static final int DEFAULT_SIZE = 4 * 1024; - - // ORC File Reader - private Reader reader; + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -98,22 +88,22 @@ public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @Override - public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + public ColumnarBatch getCurrentValue() { return columnarBatch; } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() throws IOException { return recordReader.getProgress(); } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { return nextBatch(); } @@ -134,16 +124,15 @@ public void close() throws IOException { * Please note that `initBatch` is needed to be called after this. 
*/ @Override - public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) - throws IOException, InterruptedException { + public void initialize( + InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { FileSplit fileSplit = (FileSplit)inputSplit; Configuration conf = taskAttemptContext.getConfiguration(); - reader = OrcFile.createReader( + Reader reader = OrcFile.createReader( fileSplit.getPath(), OrcFile.readerOptions(conf) .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) .filesystem(fileSplit.getPath().getFileSystem(conf))); - Reader.Options options = OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); recordReader = reader.rows(options); @@ -159,7 +148,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(DEFAULT_SIZE); + batch = orcSchema.createRowBatch(CAPACITY); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -171,19 +160,17 @@ public void initBatch( resultSchema = resultSchema.add(f); } - int capacity = DEFAULT_SIZE; - if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, capacity); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } @@ -196,7 +183,7 @@ public void initBatch( } } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); } else { // Just wrap the ORC column vector instead of copying it to Spark column vector. orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; @@ -206,8 +193,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. 
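
As a side note on the "missing columns" handling in this reader: a requested column that is absent from the file is modeled as a vector that is filled with nulls once and then marked constant, so it never needs to be re-populated per batch. Below is a small illustrative sketch in Scala (the reader itself is Java); the capacity value and column type are assumptions for the example, mirroring the `CAPACITY` constant above.

```scala
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.IntegerType

// Illustrative only: a requested column that does not exist in the file becomes a
// constant all-null vector, filled a single time for the whole batch capacity.
val capacity = 4 * 1024                       // assumed batch capacity
val missingCol = new OnHeapColumnVector(capacity, IntegerType)
missingCol.putNulls(0, capacity)              // every slot is null
missingCol.setIsConstant()                    // consumers can skip re-filling it per batch
```
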
if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); - missingCol.putNulls(0, capacity); + OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); + missingCol.putNulls(0, CAPACITY); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -219,14 +206,14 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; } } - columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); + columnarBatch = new ColumnarBatch(orcVectorWrappers); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 80c2f491b48ce..e65cd252c3ddf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -170,7 +170,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont * Returns the list of files at 'path' recursively. This skips files that are ignored normally * by MapReduce. */ - public static List listDirectory(File path) throws IOException { + public static List listDirectory(File path) { List result = new ArrayList<>(); if (path.isDirectory()) { for (File f: path.listFiles()) { @@ -231,7 +231,7 @@ protected void initialize(String path, List columns) throws IOException } @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @@ -259,7 +259,7 @@ public ValuesReaderIntIterator(ValuesReader delegate) { } @Override - int nextInt() throws IOException { + int nextInt() { return delegate.readInteger(); } } @@ -279,15 +279,15 @@ int nextInt() throws IOException { protected static final class NullIntIterator extends IntIterator { @Override - int nextInt() throws IOException { return 0; } + int nextInt() { return 0; } } /** * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. */ - protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, - ColumnDescriptor descriptor) throws IOException { + protected static IntIterator createRLEIterator( + int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); return new RLEIntIterator( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index cd745b1f0e4e3..bb1b23611a7d7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,6 +50,9 @@ * TODO: make this always return ColumnarBatches. 
*/ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; + /** * Batch of rows that we assemble and the current index we've returned. Every time this * batch is used up (batchIdx == numBatched), we populated the batch. @@ -152,7 +155,7 @@ public void close() throws IOException { } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { resultBatch(); if (returnColumnarBatch) return nextBatch(); @@ -165,13 +168,13 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } @Override - public Object getCurrentValue() throws IOException, InterruptedException { + public Object getCurrentValue() { if (returnColumnarBatch) return columnarBatch; return columnarBatch.getRow(batchIdx - 1); } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() { return (float) rowsReturned / totalRowCount; } @@ -181,7 +184,7 @@ public float getProgress() throws IOException, InterruptedException { // Columns 0,1: data columns // Column 2: partitionValues[0] // Column 3: partitionValues[1] - public void initBatch( + private void initBatch( MemoryMode memMode, StructType partitionColumns, InternalRow partitionValues) { @@ -195,13 +198,12 @@ public void initBatch( } } - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } - columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { @@ -213,13 +215,13 @@ public void initBatch( // Initialize missing columns with nulls. 
for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } } - public void initBatch() { + private void initBatch() { initBatch(MEMORY_MODE, null, null); } @@ -255,7 +257,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index b5cbe8e2839ba..5ee8cc8da2309 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -118,19 +118,19 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } } else { if (t == DataTypes.BooleanType) { - dst.appendBoolean(((Boolean)o).booleanValue()); + dst.appendBoolean((Boolean) o); } else if (t == DataTypes.ByteType) { - dst.appendByte(((Byte) o).byteValue()); + dst.appendByte((Byte) o); } else if (t == DataTypes.ShortType) { - dst.appendShort(((Short)o).shortValue()); + dst.appendShort((Short) o); } else if (t == DataTypes.IntegerType) { - dst.appendInt(((Integer)o).intValue()); + dst.appendInt((Integer) o); } else if (t == DataTypes.LongType) { - dst.appendLong(((Long)o).longValue()); + dst.appendLong((Long) o); } else if (t == DataTypes.FloatType) { - dst.appendFloat(((Float)o).floatValue()); + dst.appendFloat((Float) o); } else if (t == DataTypes.DoubleType) { - dst.appendDouble(((Double)o).doubleValue()); + dst.appendDouble((Double) o); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); dst.appendByteArray(b, 0, b.length); @@ -192,7 +192,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i */ public static ColumnarBatch toBatch( StructType schema, MemoryMode memMode, Iterator row) { - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + int capacity = 4 * 1024; WritableColumnVector[] columnVectors; if (memMode == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema); @@ -208,7 +208,7 @@ public static ColumnarBatch toBatch( } n++; } - ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity); + ColumnarBatch batch = new ColumnarBatch(columnVectors); batch.setNumRows(n); return batch; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 9ae1c6d9993f0..4dc826cf60c15 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; -import org.apache.spark.sql.types.StructType; /** * This class wraps multiple ColumnVectors as a row-wise table. 
It provides a row view of this @@ -28,10 +27,6 @@ * the entire data loading process. */ public final class ColumnarBatch { - public static final int DEFAULT_BATCH_SIZE = 4 * 1024; - - private final StructType schema; - private final int capacity; private int numRows; private final ColumnVector[] columns; @@ -82,7 +77,6 @@ public void remove() { * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { - assert(numRows <= this.capacity); this.numRows = numRows; } @@ -96,16 +90,6 @@ public void setNumRows(int numRows) { */ public int numRows() { return numRows; } - /** - * Returns the schema that makes up this batch. - */ - public StructType schema() { return schema; } - - /** - * Returns the max capacity (in number of rows) for this batch. - */ - public int capacity() { return capacity; } - /** * Returns the column at `ordinal`. */ @@ -120,10 +104,8 @@ public InternalRow getRow(int rowId) { return row; } - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { - this.schema = schema; + public ColumnarBatch(ColumnVector[] columns) { this.columns = columns; - this.capacity = capacity; this.row = new MutableColumnarRow(columns); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0cf9b53ce1d5d..eb48584d0c1ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -94,7 +94,7 @@ class VectorizedHashMapGenerator( | | public $generatedClassName() { | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); - | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity); + | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. 
| ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcd1aa0890ba3..7487564ed64da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -175,7 +175,7 @@ private[sql] object ArrowConverters { new ArrowColumnVector(vector).asInstanceOf[ColumnVector] }.toArray - val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount) + val batch = new ColumnarBatch(columns) batch.setNumRows(root.getRowCount) batch.rowIterator().asScala } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3565ee3af1b9f..28b3875505cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,11 +78,10 @@ case class InMemoryTableScanExec( } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } - val columnarBatch = new ColumnarBatch( - columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]]) columnarBatch.setNumRows(rowCount) - for (i <- 0 until attributes.length) { + for (i <- attributes.indices) { ColumnAccessor.decompress( cachedColumnarBatch.buffers(columnIndices(i)), columnarBatch.column(i).asInstanceOf[WritableColumnVector], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c06bc7b66ff39..47b146f076b62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -74,8 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) + val outputTypes = output.drop(child.output.length).map(_.dataType) // DO NOT use iter.grouped(). See BatchIterator. 
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) @@ -90,8 +89,9 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() - assert(schemaOut.equals(batch.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) + assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " + + s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") batch.rowIterator.asScala } else { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index dc5ba96e69aec..5fcdcddca7d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -138,7 +138,7 @@ class ArrowPythonRunner( if (reader != null && batchLoaded) { batchLoaded = reader.loadNextBatch() if (batchLoaded) { - val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + val batch = new ColumnarBatch(vectors) batch.setNumRows(root.getRowCount) batch } else { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 98d6a53b54d28..a5d77a90ece42 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,8 +69,7 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = - new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = new ColumnarBatch(vectors); return this; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 675f06b31b970..cd90681ecabc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -875,14 +875,13 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val capacity = 4 * 1024 val columns = schema.fields.map { field => allocate(capacity, field.dataType, memMode) } - val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) + val batch = new ColumnarBatch(columns.toArray) assert(batch.numCols() == 4) assert(batch.numRows() == 0) - assert(batch.capacity() > 0) assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] @@ -1153,7 +1152,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType))) - val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11) + val batch = new ColumnarBatch(columnVectors.toArray) 
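
A short usage sketch of the simplified `ColumnarBatch` API after this change: the batch is built from the column vectors alone, the row count is set by the producer, and schema and capacity now live with the reader rather than with the batch. The column names, types, and capacity below are assumptions for illustration only.

```scala
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.{IntegerType, LongType, StructType}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// The producer (e.g. a file reader) decides the capacity; the batch no longer tracks it.
val capacity = 4 * 1024
val schema = new StructType().add("id", LongType).add("value", IntegerType)
val vectors = OnHeapColumnVector.allocateColumns(capacity, schema)

// New constructor: only the vectors are passed in, no schema or capacity argument.
val batch = new ColumnarBatch(vectors.asInstanceOf[Array[ColumnVector]])
batch.setNumRows(0)
assert(batch.numCols() == 2)
```
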
batch.setNumRows(11) assert(batch.numCols() == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a89f7c55bf4f7..0ca29524c6d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -333,8 +333,7 @@ class BatchReadTask(start: Int, end: Int) private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch( - new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + private lazy val batch = new ColumnarBatch(Array(i, j)) private var current = start From 73d3b230f3816a854a181c0912d87b180e347271 Mon Sep 17 00:00:00 2001 From: foxish Date: Fri, 19 Jan 2018 10:23:13 -0800 Subject: [PATCH 704/712] [SPARK-23104][K8S][DOCS] Changes to Kubernetes scheduler documentation ## What changes were proposed in this pull request? Docs changes: - Adding a warning that the backend is experimental. - Removing a defunct internal-only option from documentation - Clarifying that node selectors can be used right away, and other minor cosmetic changes ## How was this patch tested? Docs only change Author: foxish Closes #20314 from foxish/ambiguous-docs. --- docs/cluster-overview.md | 4 ++-- docs/running-on-kubernetes.md | 43 ++++++++++++++++------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 658e67f99dd71..7277e2fb2731d 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,8 +52,8 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -* [Kubernetes](running-on-kubernetes.html) -- [Kubernetes](https://kubernetes.io/docs/concepts/overview/what-is-kubernetes/) -is an open-source platform that provides container-centric infrastructure. +* [Kubernetes](running-on-kubernetes.html) -- an open-source system for automating deployment, scaling, + and management of containerized applications. A third-party project (not supported by the Spark project) exists to add support for [Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d6b1735ce5550..3c7586e8544ba 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -8,6 +8,10 @@ title: Running Spark on Kubernetes Spark can run on clusters managed by [Kubernetes](https://kubernetes.io). This feature makes use of native Kubernetes scheduler that has been added to Spark. +**The Kubernetes scheduler is currently experimental. +In future versions, there may be behavioral changes around configuration, +container images and entrypoints.** + # Prerequisites * A runnable distribution of Spark 2.3 or above. @@ -41,11 +45,10 @@ logs and remains in "completed" state in the Kubernetes API until it's eventuall Note that in the completed state, the driver pod does *not* use any computational or memory resources. -The driver and executor pod scheduling is handled by Kubernetes. 
It will be possible to affect Kubernetes scheduling -decisions for driver and executor pods using advanced primitives like -[node selectors](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) -and [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) -in a future release. +The driver and executor pod scheduling is handled by Kubernetes. It is possible to schedule the +driver and executor pods on a subset of available nodes through a [node selector](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) +using the configuration property for it. It will be possible to use more advanced +scheduling hints like [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) in a future release. # Submitting Applications to Kubernetes @@ -62,8 +65,10 @@ use with the Kubernetes backend. Example usage is: - ./bin/docker-image-tool.sh -r -t my-tag build - ./bin/docker-image-tool.sh -r -t my-tag push +```bash +$ ./bin/docker-image-tool.sh -r -t my-tag build +$ ./bin/docker-image-tool.sh -r -t my-tag push +``` ## Cluster Mode @@ -94,7 +99,7 @@ must consist of lower case alphanumeric characters, `-`, and `.` and must start If you have a Kubernetes cluster setup, one way to discover the apiserver URL is by executing `kubectl cluster-info`. ```bash -kubectl cluster-info +$ kubectl cluster-info Kubernetes master is running at http://127.0.0.1:6443 ``` @@ -105,7 +110,7 @@ authenticating proxy, `kubectl proxy` to communicate to the Kubernetes API. The local proxy can be started by: ```bash -kubectl proxy +$ kubectl proxy ``` If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0.1:8001` can be used as the argument to @@ -173,7 +178,7 @@ Logs can be accessed using the Kubernetes API and the `kubectl` CLI. When a Spar to stream logs from the application using: ```bash -kubectl -n= logs -f +$ kubectl -n= logs -f ``` The same logs can also be accessed through the @@ -186,7 +191,7 @@ The UI associated with any application can be accessed locally using [`kubectl port-forward`](https://kubernetes.io/docs/tasks/access-application-cluster/port-forward-access-application-cluster/#forward-a-local-port-to-a-port-on-the-pod). ```bash -kubectl port-forward 4040:4040 +$ kubectl port-forward 4040:4040 ``` Then, the Spark driver UI can be accessed on `http://localhost:4040`. @@ -200,13 +205,13 @@ are errors during the running of the application, often, the best way to investi To get some basic information about the scheduling decisions made around the driver pod, you can run: ```bash -kubectl describe pod +$ kubectl describe pod ``` If the pod has encountered a runtime error, the status can be probed further using: ```bash -kubectl logs +$ kubectl logs ``` Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark @@ -254,7 +259,7 @@ To create a custom service account, a user can use the `kubectl create serviceac following command creates a service account named `spark`: ```bash -kubectl create serviceaccount spark +$ kubectl create serviceaccount spark ``` To grant a service account a `Role` or `ClusterRole`, a `RoleBinding` or `ClusterRoleBinding` is needed. To create @@ -263,7 +268,7 @@ for `ClusterRoleBinding`) command. 
For example, the following command creates an namespace and grants it to the `spark` service account created above: ```bash -kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default +$ kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default ``` Note that a `Role` can only be used to grant access to resources (like pods) within a single namespace, whereas a @@ -543,14 +548,6 @@ specific to Spark on Kubernetes. to avoid name conflicts. - - spark.kubernetes.executor.podNamePrefix - (none) - - Prefix for naming the executor pods. - If not set, the executor pod name is set to driver pod name suffixed by an integer. - - spark.kubernetes.executor.lostCheck.maxAttempts 10 From 07296a61c29eb074553956b6c0f92810ecf7bab2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 10:25:18 -0800 Subject: [PATCH 705/712] [INFRA] Close stale PR. Closes #20185. From fed2139f053fac4a9a6952ff0ab1cc2a5f657bd0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:26:37 -0600 Subject: [PATCH 706/712] [SPARK-20664][CORE] Delete stale application data from SHS. Detect the deletion of event log files from storage, and remove data about the related application attempt in the SHS. Also contains code to fix SPARK-21571 based on code by ericvandenbergfb. Author: Marcelo Vanzin Closes #20138 from vanzin/SPARK-20664. --- .../deploy/history/FsHistoryProvider.scala | 297 +++++++++++------- .../history/FsHistoryProviderSuite.scala | 117 ++++++- .../deploy/history/HistoryServerSuite.scala | 4 +- 3 files changed, 306 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 94c80ebd55e74..f9d0b5ee4e23e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} import java.util.{Date, ServiceLoader, UUID} -import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} +import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -29,7 +29,7 @@ import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import com.google.common.util.concurrent.MoreExecutors import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem @@ -116,8 +116,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs // and applications between check task and clean task. - private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() - .setNameFormat("spark-history-task-%d").setDaemon(true).build()) + private val pool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-history-task-%d") // The modification time of the newest log detected during the last scan. 
Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) @@ -174,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (!conf.contains("spark.testing")) { + if (Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() @@ -275,7 +274,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { Some(load(appId).toApplicationInfo()) } catch { - case e: NoSuchElementException => + case _: NoSuchElementException => None } } @@ -405,49 +404,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val newLastScanTime = getNewLastScanTime() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") - // scan for modified applications, replay and merge them - val logInfos = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) + + val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) && - recordedFileSize(entry.getPath()) < entry.getLen() + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + } + .filter { entry => + try { + val info = listing.read(classOf[LogInfo], entry.getPath().toString()) + if (info.fileSize < entry.getLen()) { + // Log size has changed, it should be parsed. + true + } else { + // If the SHS view has a valid application, update the time the file was last seen so + // that the entry is not deleted from the SHS listing. + if (info.appId.isDefined) { + listing.write(info.copy(lastProcessed = newLastScanTime)) + } + false + } + } catch { + case _: NoSuchElementException => + // If the file is currently not being tracked by the SHS, add an entry for it and try + // to parse it. This will allow the cleaner code to detect the file as stale later on + // if it was not possible to parse it. + listing.write(LogInfo(entry.getPath().toString(), newLastScanTime, None, None, + entry.getLen())) + entry.getLen() > 0 + } } .sortWith { case (entry1, entry2) => entry1.getModificationTime() > entry2.getModificationTime() } - if (logInfos.nonEmpty) { - logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") + if (updated.nonEmpty) { + logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - var tasks = mutable.ListBuffer[Future[_]]() - - try { - for (file <- logInfos) { - tasks += replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(file) + val tasks = updated.map { entry => + try { + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) }) + } catch { + // let the iteration over the updated entries break, since an exception on + // replayExecutor.submit (..) 
indicates the ExecutorService is unable + // to take any more submissions at this time + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + null } - } catch { - // let the iteration over logInfos break, since an exception on - // replayExecutor.submit (..) indicates the ExecutorService is unable - // to take any more submissions at this time - - case e: Exception => - logError(s"Exception while submitting event log for replay", e) - } + }.filter(_ != null) pendingReplayTasksCount.addAndGet(tasks.size) + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. tasks.foreach { task => try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. task.get() } catch { case e: InterruptedException => @@ -459,13 +479,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + // Delete all information about applications whose log files disappeared from storage. + // This is done by identifying the event logs which were not touched by the current + // directory scan. + // + // Only entries with valid applications are cleaned up here. Cleaning up invalid log + // files is done by the periodic cleaner task. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .last(newLastScanTime - 1) + .asScala + .toList + stale.foreach { log => + log.appId.foreach { appId => + cleanAppData(appId, log.attemptId, log.logPath) + listing.delete(classOf[LogInfo], log.logPath) + } + } + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } - private def getNewLastScanTime(): Long = { + private def cleanAppData(appId: String, attemptId: Option[String], logPath: String): Unit = { + try { + val app = load(appId) + val (attempt, others) = app.attempts.partition(_.info.attemptId == attemptId) + + assert(attempt.isEmpty || attempt.size == 1) + val isStale = attempt.headOption.exists { a => + if (a.logPath != new Path(logPath).getName()) { + // If the log file name does not match, then probably the old log file was from an + // in progress application. Just return that the app should be left alone. + false + } else { + val maybeUI = synchronized { + activeUIs.remove(appId -> attemptId) + } + + maybeUI.foreach { ui => + ui.invalidate() + ui.ui.store.close() + } + + diskManager.foreach(_.release(appId, attemptId, delete = true)) + true + } + } + + if (isStale) { + if (others.nonEmpty) { + val newAppInfo = new ApplicationInfoWrapper(app.info, others) + listing.write(newAppInfo) + } else { + listing.delete(classOf[ApplicationInfoWrapper], appId) + } + } + } catch { + case _: NoSuchElementException => + } + } + + private[history] def getNewLastScanTime(): Long = { val fileName = "." + UUID.randomUUID().toString val path = new Path(logDir, fileName) val fos = fs.create(path) @@ -530,7 +607,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. 
*/ - protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -544,73 +621,78 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.addListener(listener) replay(fileStatus, bus, eventsFilter = eventsFilter) - listener.applicationInfo.foreach { app => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + val (appId, attemptId) = listener.applicationInfo match { + case Some(app) => + // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a + // discussion on the UI lifecycle. + synchronized { + activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - } - addListing(app) + addListing(app) + (Some(app.info.id), app.attempts.head.info.attemptId) + + case _ => + // If the app hasn't written down its app ID to the logs, still record the entry in the + // listing db, with an empty ID. This will make the log eligible for deletion if the app + // does not make progress after the configured max log age. + (None, None) } - listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) + listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** * Delete event logs from the log directory according to the clean policy defined by the user. */ - private[history] def cleanLogs(): Unit = { - var iterator: Option[KVStoreIterator[ApplicationInfoWrapper]] = None - try { - val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - - // Iterate descending over all applications whose oldest attempt happened before maxTime. - iterator = Some(listing.view(classOf[ApplicationInfoWrapper]) - .index("oldestAttempt") - .reverse() - .first(maxTime) - .closeableIterator()) - - iterator.get.asScala.foreach { app => - // Applications may have multiple attempts, some of which may not need to be deleted yet. - val (remaining, toDelete) = app.attempts.partition { attempt => - attempt.info.lastUpdated.getTime() >= maxTime - } + private[history] def cleanLogs(): Unit = Utils.tryLog { + val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - if (remaining.nonEmpty) { - val newApp = new ApplicationInfoWrapper(app.info, remaining) - listing.write(newApp) - } + val expired = listing.view(classOf[ApplicationInfoWrapper]) + .index("oldestAttempt") + .reverse() + .first(maxTime) + .asScala + .toList + expired.foreach { app => + // Applications may have multiple attempts, some of which may not need to be deleted yet. 
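
To summarize the mechanism this change adds (a hedged paraphrase, not additional patch code): every log file seen during a scan has its `LogInfo.lastProcessed` stamped with the scan time, so any listing entry whose `lastProcessed` is older than the current scan corresponds to an event log that has disappeared from storage, and its UI, disk store, and listing data can be removed. The self-contained model below uses illustrative names (`LogEntry`, `findStale`) rather than the real KVStore API.

```scala
// Minimal model of the stale-entry detection, independent of the KVStore API.
// Each entry remembers when it was last seen on disk; entries not touched by the
// current scan are the ones whose backing event log was deleted.
case class LogEntry(logPath: String, lastProcessed: Long, appId: Option[String])

def findStale(entries: Seq[LogEntry], newLastScanTime: Long): Seq[LogEntry] =
  entries.filter(_.lastProcessed < newLastScanTime)

val entries = Seq(
  LogEntry("app-1.inprogress", lastProcessed = 2L, appId = Some("app-1")),
  LogEntry("app-2",            lastProcessed = 1L, appId = Some("app-2"))  // not seen this scan
)
// Only app-2's log vanished between scans, so only its listing data is cleaned up.
assert(findStale(entries, newLastScanTime = 2L).map(_.appId) == Seq(Some("app-2")))
```
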
+ val (remaining, toDelete) = app.attempts.partition { attempt => + attempt.info.lastUpdated.getTime() >= maxTime + } - toDelete.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) - try { - listing.delete(classOf[LogInfo], logPath.toString()) - } catch { - case _: NoSuchElementException => - logDebug(s"Log info entry for $logPath not found.") - } - try { - fs.delete(logPath, true) - } catch { - case e: AccessControlException => - logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") - case t: IOException => - logError(s"IOException in cleaning ${attempt.logPath}", t) - } - } + if (remaining.nonEmpty) { + val newApp = new ApplicationInfoWrapper(app.info, remaining) + listing.write(newApp) + } - if (remaining.isEmpty) { - listing.delete(app.getClass(), app.id) - } + toDelete.foreach { attempt => + logInfo(s"Deleting expired event log for ${attempt.logPath}") + val logPath = new Path(logDir, attempt.logPath) + listing.delete(classOf[LogInfo], logPath.toString()) + cleanAppData(app.id, attempt.info.attemptId, logPath.toString()) + deleteLog(logPath) + } + + if (remaining.isEmpty) { + listing.delete(app.getClass(), app.id) + } + } + + // Delete log files that don't have a valid application and exceed the configured max age. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .reverse() + .first(maxTime) + .asScala + .toList + stale.foreach { log => + if (log.appId.isEmpty) { + logInfo(s"Deleting invalid / corrupt event log ${log.logPath}") + deleteLog(new Path(log.logPath)) + listing.delete(classOf[LogInfo], log.logPath) } - } catch { - case t: Exception => logError("Exception while cleaning logs", t) - } finally { - iterator.foreach(_.close()) } } @@ -631,12 +713,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // an error the other way -- if we report a size bigger (ie later) than the file that is // actually read, we may never refresh the app. FileStatus is guaranteed to be static // after it's created, so we get a file size that is no bigger than what is actually read. - val logInput = EventLoggingListener.openEventLog(logPath, fs) - try { - bus.replay(logInput, logPath.toString, !isCompleted, eventsFilter) + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !isCompleted, eventsFilter) logInfo(s"Finished parsing $logPath") - } finally { - logInput.close() } } @@ -703,18 +782,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) | application count=$count}""".stripMargin } - /** - * Return the last known size of the given event log, recorded the last time the file - * system scanner detected a change in the file. 
- */ - private def recordedFileSize(log: Path): Long = { - try { - listing.read(classOf[LogInfo], log.toString()).fileSize - } catch { - case _: NoSuchElementException => 0L - } - } - private def load(appId: String): ApplicationInfoWrapper = { listing.read(classOf[ApplicationInfoWrapper], appId) } @@ -773,11 +840,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") val lease = dm.lease(status.getLen(), isCompressed) val newStorePath = try { - val store = KVUtils.open(lease.tmpPath, metadata) - try { + Utils.tryWithResource(KVUtils.open(lease.tmpPath, metadata)) { store => rebuildAppStore(store, status, attempt.info.lastUpdated.getTime()) - } finally { - store.close() } lease.commit(appId, attempt.info.attemptId) } catch { @@ -806,6 +870,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) } + private def deleteLog(log: Path): Unit = { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } + } + } private[history] object FsHistoryProvider { @@ -832,8 +907,16 @@ private[history] case class FsHistoryProviderMetadata( uiVersion: Long, logDir: String) +/** + * Tracking info for event logs detected in the configured log directory. Tracks both valid and + * invalid logs (e.g. unparseable logs, recorded as logs with no app ID) so that the cleaner + * can know what log files are safe to delete. + */ private[history] case class LogInfo( @KVIndexParam logPath: String, + @KVIndexParam("lastProcessed") lastProcessed: Long, + appId: Option[String], + attemptId: Option[String], fileSize: Long) private[history] class AttemptInfoWrapper( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 84ee01c7f5aaf..787de59edf465 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, spy, verify} +import org.mockito.Mockito.{doReturn, mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -149,8 +149,10 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { var mergeApplicationListingCall = 0 - override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { - super.mergeApplicationListing(fileStatus) + override protected def mergeApplicationListing( + fileStatus: FileStatus, + lastSeen: Long): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen) mergeApplicationListingCall += 1 } } @@ -663,6 +665,115 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc freshUI.get.ui.store.job(0) } + test("clean up stale app information") { + val storeDir = Utils.createTempDir() + val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + val 
provider = spy(new FsHistoryProvider(conf)) + val appId = "new1" + + // Write logs for two app attempts. + doReturn(1L).when(provider).getNewLastScanTime() + val attempt1 = newLogFile(appId, Some("1"), inProgress = false) + writeFile(attempt1, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + val attempt2 = newLogFile(appId, Some("2"), inProgress = false) + writeFile(attempt2, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 2) + } + + // Load the app's UI. + val ui = provider.getAppUI(appId, Some("1")) + assert(ui.isDefined) + + // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since + // attempt 2 still exists, listing data should be there. + doReturn(2L).when(provider).getNewLastScanTime() + attempt1.delete() + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 1) + } + assert(!ui.get.valid) + assert(provider.getAppUI(appId, None) === None) + + // Delete the second attempt's log file. Now everything should go away. + doReturn(3L).when(provider).getNewLastScanTime() + attempt2.delete() + updateAndCheck(provider) { list => + assert(list.isEmpty) + } + } + + test("SPARK-21571: clean up removes invalid history files") { + // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that + // until we figure out what's causing the problem. + val clock = new ManualClock(TimeUnit.DAYS.toMillis(120)) + val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") + val provider = new FsHistoryProvider(conf, clock) { + override def getNewLastScanTime(): Long = clock.getTimeMillis() + } + + // Create 0-byte size inprogress and complete files + var logCount = 0 + var validLogCount = 0 + + val emptyInProgress = newLogFile("emptyInprogressLogFile", None, inProgress = true) + emptyInProgress.createNewFile() + emptyInProgress.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val slowApp = newLogFile("slowApp", None, inProgress = true) + slowApp.createNewFile() + slowApp.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val emptyFinished = newLogFile("emptyFinishedLogFile", None, inProgress = false) + emptyFinished.createNewFile() + emptyFinished.setLastModified(clock.getTimeMillis()) + logCount += 1 + + // Create an incomplete log file, has an end record but no start record. + val corrupt = newLogFile("nonEmptyCorruptLogFile", None, inProgress = false) + writeFile(corrupt, true, None, SparkListenerApplicationEnd(0)) + corrupt.setLastModified(clock.getTimeMillis()) + logCount += 1 + + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Move the clock forward 1 day and scan the files again. They should still be there. + clock.advance(TimeUnit.DAYS.toMillis(1)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Update the slow app to contain valid info. Code should detect the change and not clean + // it up. 
+ writeFile(slowApp, true, None, + SparkListenerApplicationStart(slowApp.getName(), Some(slowApp.getName()), 1L, "test", None)) + slowApp.setLastModified(clock.getTimeMillis()) + validLogCount += 1 + + // Move the clock forward another 2 days and scan the files again. This time the cleaner should + // pick up the invalid files and get rid of them. + clock.advance(TimeUnit.DAYS.toMillis(2)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === validLogCount) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87778dda0e2c8..7aa60f2b60796 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -48,7 +48,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{ResetSystemProperties, ShutdownHookManager, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -564,7 +564,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit() + ShutdownHookManager.registerShutdownDeleteDir(logDir) } test("ui and api authorization checks") { From aa3a1276f9e23ffbb093d00743e63cd4369f9f57 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:32:20 -0600 Subject: [PATCH 707/712] [SPARK-23103][CORE] Ensure correct sort order for negative values in LevelDB. The code was sorting "0" as "less than" negative values, which is a little wrong. Fix is simple, most of the changes are the added test and related cleanup. Author: Marcelo Vanzin Closes #20284 from vanzin/SPARK-23103. --- .../spark/util/kvstore/LevelDBTypeInfo.java | 2 +- .../spark/util/kvstore/DBIteratorSuite.java | 7 +- .../spark/util/kvstore/LevelDBSuite.java | 77 ++++++++++--------- .../spark/status/AppStatusListenerSuite.scala | 8 +- 4 files changed, 52 insertions(+), 42 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java index 232ee41dd0b1f..f4d359234cb9e 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java @@ -493,7 +493,7 @@ byte[] toKey(Object value, byte prefix) { byte[] key = new byte[bytes * 2 + 2]; long longValue = ((Number) value).longValue(); key[0] = prefix; - key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + key[1] = longValue >= 0 ? 
POSITIVE_MARKER : NEGATIVE_MARKER; for (int i = 0; i < key.length - 2; i++) { int masked = (int) ((longValue >>> (4 * i)) & 0xF); diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java index 9a81f86812cde..1e062437d1803 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java @@ -73,7 +73,9 @@ default BaseComparator reverse() { private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); - private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> { + return Integer.valueOf(t1.num).compareTo(t2.num); + }; private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); /** @@ -112,7 +114,8 @@ public void setup() throws Exception { t.key = "key" + i; t.id = "id" + i; t.name = "name" + RND.nextInt(MAX_ENTRIES); - t.num = RND.nextInt(MAX_ENTRIES); + // Force one item to have an integer value of zero to test the fix for SPARK-23103. + t.num = (i != 0) ? (int) RND.nextLong() : 0; t.child = "child" + (i % MIN_ENTRIES); allEntries.add(t); } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 2b07d249d2022..b8123ac81d29a 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; @@ -74,11 +76,7 @@ public void testReopenAndVersionCheckDb() throws Exception { @Test public void testObjectWriteReadDelete() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); try { db.read(CustomType1.class, t.key); @@ -106,17 +104,9 @@ public void testObjectWriteReadDelete() throws Exception { @Test public void testMultipleObjectWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "key1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; - - CustomType1 t2 = new CustomType1(); - t2.key = "key2"; - t2.id = "id"; - t2.name = "name2"; - t2.child = "child2"; + CustomType1 t1 = createCustomType1(1); + CustomType1 t2 = createCustomType1(2); + t2.id = t1.id; db.write(t1); db.write(t2); @@ -142,11 +132,7 @@ public void testMultipleObjectWriteReadDelete() throws Exception { @Test public void testMultipleTypesWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; + CustomType1 t1 = createCustomType1(1); IntKeyType t2 = new IntKeyType(); t2.key = 2; @@ -188,10 +174,7 @@ public void testMultipleTypesWriteReadDelete() throws Exception { public void testMetadata() throws Exception { 
assertNull(db.getMetadata(CustomType1.class)); - CustomType1 t = new CustomType1(); - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.setMetadata(t); assertEquals(t, db.getMetadata(CustomType1.class)); @@ -202,11 +185,7 @@ public void testMetadata() throws Exception { @Test public void testUpdate() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.write(t); @@ -222,13 +201,7 @@ public void testUpdate() throws Exception { @Test public void testSkip() throws Exception { for (int i = 0; i < 10; i++) { - CustomType1 t = new CustomType1(); - t.key = "key" + i; - t.id = "id" + i; - t.name = "name" + i; - t.child = "child" + i; - - db.write(t); + db.write(createCustomType1(i)); } KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); @@ -240,6 +213,36 @@ public void testSkip() throws Exception { assertFalse(it.hasNext()); } + @Test + public void testNegativeIndexValues() throws Exception { + List expected = Arrays.asList(-100, -50, 0, 50, 100); + + expected.stream().forEach(i -> { + try { + db.write(createCustomType1(i)); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + List results = StreamSupport + .stream(db.view(CustomType1.class).index("int").spliterator(), false) + .map(e -> e.num) + .collect(Collectors.toList()); + + assertEquals(expected, results); + } + + private CustomType1 createCustomType1(int i) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.num = i; + t.child = "child" + i; + return t; + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index ca66b6b9db890..e7981bec6d64b 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -894,15 +894,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val dropped = stages.drop(1).head // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be - // calculcated, we need at least one finished task. + // calculated, we need at least one finished task. The code in AppStatusStore uses + // `executorRunTime` to detect valid tasks, so that metric needs to be updated in the + // task end event. time += 1 val task = createTasks(1, Array("1")).head listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) time += 1 task.markFinished(TaskState.FINISHED, time) + val metrics = TaskMetrics.empty + metrics.setExecutorRunTime(42L) listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, - "taskType", Success, task, null)) + "taskType", Success, task, metrics)) new AppStatusStore(store) .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) From f6da41b0150725fe96ccb2ee3b48840b207f47eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:14:24 -0800 Subject: [PATCH 708/712] [SPARK-23135][UI] Fix rendering of accumulators in the stage page. This follows the behavior of 2.2: only named accumulators with a value are rendered. 
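For illustration only, the rendering rule above can be sketched as a simple null-check filter over a simplified, hypothetical accumulator model (this is not the UI code changed below, just a hedged sketch of the behavior):

```scala
// Minimal sketch of the rule above: keep only named accumulators that carry a value.
case class Acc(name: String, value: String)

def renderable(accs: Seq[Acc]): Seq[Acc] =
  accs.filter(a => a.name != null && a.value != null)

// Example: only the first entry would be rendered.
renderable(Seq(Acc("records read", "42"), Acc(null, "7"), Acc("bytes", null)))
```
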
Screenshot: ![accs](https://user-images.githubusercontent.com/1694083/35065700-df409114-fb82-11e7-87c1-550c3f674371.png) Author: Marcelo Vanzin Closes #20299 from vanzin/SPARK-23135. --- .../org/apache/spark/ui/jobs/StagePage.scala | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 25bee33028393..0eb3190205c3e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -260,7 +260,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Seq[Node] = { - {acc.name}{acc.value} + if (acc.name != null && acc.value != null) { + {acc.name}{acc.value} + } else { + Nil + } } val accumulableTable = UIUtils.listingTable( accumulableHeaders, @@ -864,7 +868,7 @@ private[ui] class TaskPagedTable( {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} {if (hasAccumulators(stage)) { - accumulatorsInfo(task) + {accumulatorsInfo(task)} }} {if (hasInput(stage)) { metricInfo(task) { m => @@ -920,8 +924,12 @@ private[ui] class TaskPagedTable( } private def accumulatorsInfo(task: TaskData): Seq[Node] = { - task.accumulatorUpdates.map { acc => - Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + task.accumulatorUpdates.flatMap { acc => + if (acc.name != null && acc.update.isDefined) { + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")) ++
    + } else { + Nil + } } } @@ -985,7 +993,9 @@ private object ApiHelper { "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, "Errors" -> TaskIndexNames.ERROR) - def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasAccumulators(stageData: StageData): Boolean = { + stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } + } def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 From 793841c6b8b98b918dcf241e29f60ef125914db9 Mon Sep 17 00:00:00 2001 From: Kent Yao <11215016@zju.edu.cn> Date: Fri, 19 Jan 2018 15:49:29 -0800 Subject: [PATCH 709/712] [SPARK-21771][SQL] remove useless hive client in SparkSQLEnv ## What changes were proposed in this pull request? Once a meta hive client is created, it generates its SessionState which creates a lot of session related directories, some deleteOnExit, some does not. if a hive client is useless we may not create it at the very start. ## How was this patch tested? N/A cc hvanhovell cloud-fan Author: Kent Yao <11215016@zju.edu.cn> Closes #18983 from yaooqinn/patch-1. --- .../org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 6b19f971b73bb..cbd75ad12d430 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,8 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] - .client.newSession() + .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) From 396cdfbea45232bacbc03bfaf8be4ea85d47d3fd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 19 Jan 2018 22:46:34 -0800 Subject: [PATCH 710/712] [SPARK-23091][ML] Incorrect unit test for approxQuantile ## What changes were proposed in this pull request? Narrow bound on approx quantile test to epsilon from 2*epsilon to match paper ## How was this patch tested? Existing tests. Author: Sean Owen Closes #20324 from srowen/SPARK-23091. 
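For context, the guarantee that the tightened assertions rely on can be sketched as follows. This is a hedged, illustrative example only: it assumes an existing local `SparkSession` named `spark`, and the numbers are made up rather than taken from the suite.

```scala
import spark.implicits._

// A column whose values equal their rank, so rank error can be read off the value directly.
val n = 1000
val df = (1 to n).toDF("x")

val eps = 0.05
val Array(median) = df.stat.approxQuantile("x", Array(0.5), eps)

// approxQuantile promises a value whose rank is within eps * n of the target rank,
// which is the single-epsilon bound the test now asserts instead of 2 * eps * n.
assert(math.abs(median - 0.5 * n) <= eps * n)
```
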
--- .../apache/spark/sql/DataFrameStatSuite.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 5169d2b5fc6b2..8eae35325faea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -154,24 +154,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) - val error_single = 2 * 1000 * epsilon - val error_double = 2 * 2000 * epsilon + val errorSingle = 1000 * epsilon + val errorDouble = 2.0 * errorSingle - assert(math.abs(single1 - q1 * n) < error_single) - assert(math.abs(double2 - 2 * q2 * n) < error_double) - assert(math.abs(s1 - q1 * n) < error_single) - assert(math.abs(s2 - q2 * n) < error_single) - assert(math.abs(d1 - 2 * q1 * n) < error_double) - assert(math.abs(d2 - 2 * q2 * n) < error_double) + assert(math.abs(single1 - q1 * n) <= errorSingle) + assert(math.abs(double2 - 2 * q2 * n) <= errorDouble) + assert(math.abs(s1 - q1 * n) <= errorSingle) + assert(math.abs(s2 - q2 * n) <= errorSingle) + assert(math.abs(d1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(d2 - 2 * q2 * n) <= errorDouble) // Multiple columns val Array(Array(ms1, ms2), Array(md1, md2)) = df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon) - assert(math.abs(ms1 - q1 * n) < error_single) - assert(math.abs(ms2 - q2 * n) < error_single) - assert(math.abs(md1 - 2 * q1 * n) < error_double) - assert(math.abs(md2 - 2 * q2 * n) < error_double) + assert(math.abs(ms1 - q1 * n) <= errorSingle) + assert(math.abs(ms2 - q2 * n) <= errorSingle) + assert(math.abs(md1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(md2 - 2 * q2 * n) <= errorDouble) } // quantile should be in the range [0.0, 1.0] From 84a076e0e9a38a26edf7b702c24fdbbcf1e697b9 Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Sat, 20 Jan 2018 14:34:37 -0800 Subject: [PATCH 711/712] [SPARK-23165][DOC] Spelling mistake fix in quick-start doc. ## What changes were proposed in this pull request? Fix spelling in quick-start doc. ## How was this patch tested? Doc only. Author: Shashwat Anand Closes #20336 from ashashwat/SPARK-23165. --- docs/cloud-integration.md | 4 ++-- docs/configuration.md | 14 +++++++------- docs/graphx-programming-guide.md | 4 ++-- docs/monitoring.md | 8 ++++---- docs/quick-start.md | 6 +++--- docs/running-on-mesos.md | 2 +- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/sql-programming-guide.md | 8 ++++---- docs/storage-openstack-swift.md | 2 +- docs/streaming-programming-guide.md | 4 ++-- docs/structured-streaming-kafka-integration.md | 4 ++-- docs/structured-streaming-programming-guide.md | 6 +++--- docs/submitting-applications.md | 8 ++++---- 14 files changed, 37 insertions(+), 37 deletions(-) diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 751a192da4ffd..c150d9efc06ff 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -180,10 +180,10 @@ under the path, not the number of *new* files, so it can become a slow operation The size of the window needs to be set to handle this. 1. 
Files only appear in an object store once they are completely written; there -is no need for a worklow of write-then-rename to ensure that files aren't picked up +is no need for a workflow of write-then-rename to ensure that files aren't picked up while they are still being written. Applications can write straight to the monitored directory. -1. Streams should only be checkpointed to an store implementing a fast and +1. Streams should only be checkpointed to a store implementing a fast and atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index eecb39dcafc9e..e7f2419cc2fa4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -79,7 +79,7 @@ Then, you can supply configuration values at runtime: {% endhighlight %} The Spark shell and [`spark-submit`](submitting-applications.html) -tool support two ways to load configurations dynamically. The first are command line options, +tool support two ways to load configurations dynamically. The first is command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. Running `./bin/spark-submit --help` will show the entire list of these options. @@ -413,7 +413,7 @@ Apart from these, the following properties are also available, and may be useful false Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), - or it will be displayed before the driver exiting. It also can be dumped into disk by + or it will be displayed before the driver exits. It also can be dumped into disk by sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. @@ -446,7 +446,7 @@ Apart from these, the following properties are also available, and may be useful true Reuse Python worker or not. If yes, it will use a fixed number of Python workers, - does not need to fork() a Python process for every tasks. It will be very useful + does not need to fork() a Python process for every task. It will be very useful if there is large broadcast, then the broadcast will not be needed to transferred from JVM to Python worker for every task. @@ -1294,7 +1294,7 @@ Apart from these, the following properties are also available, and may be useful spark.files.openCostInBytes 4194304 (4 MB) - The estimated cost to open a file, measured by the number of bytes could be scanned in the same + The estimated cost to open a file, measured by the number of bytes could be scanned at the same time. This is used when putting multiple files into a partition. It is better to over estimate, then the partitions with small files will be faster than partitions with bigger files. @@ -1855,8 +1855,8 @@ Apart from these, the following properties are also available, and may be useful spark.user.groups.mapping org.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user are determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can configured by this property. + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. 
A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider which can be specified to resolve a list of groups for a user. Note: This implementation supports only a Unix/Linux based environment. Windows environment is @@ -2465,7 +2465,7 @@ should be included on Spark's classpath: The location of these configuration files varies across Hadoop versions, but a common location is inside of `/etc/hadoop/conf`. Some tools create -configurations on-the-fly, but offer a mechanisms to download copies of them. +configurations on-the-fly, but offer a mechanism to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/conf/spark-env.sh` to a location containing the configuration files. diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 46225dc598da8..5c97a248df4bc 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,7 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodically checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): @@ -928,7 +928,7 @@ switch to 2D-partitioning or other heuristics included in GraphX.

    -Once the edges have be partitioned the key challenge to efficient graph-parallel computation is +Once the edges have been partitioned the key challenge to efficient graph-parallel computation is efficiently joining vertex attributes with the edges. Because real-world graphs typically have more edges than vertices, we move vertex attributes to the edges. Because not all partitions will contain edges adjacent to all vertices we internally maintain a routing table which identifies where diff --git a/docs/monitoring.md b/docs/monitoring.md index f8d3ce91a0691..6f6cfc1288d73 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -118,7 +118,7 @@ The history server can be configured as follows: The number of applications to retain UI data for in the cache. If this cap is exceeded, then the oldest applications will be removed from the cache. If an application is not in the cache, - it will have to be loaded from disk if its accessed from the UI. + it will have to be loaded from disk if it is accessed from the UI. @@ -407,7 +407,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running -The number of jobs and stages which can retrieved is constrained by the same retention +The number of jobs and stages which can be retrieved is constrained by the same retention mechanism of the standalone Spark UI; `"spark.ui.retainedJobs"` defines the threshold value triggering garbage collection on jobs, and `spark.ui.retainedStages` that for stages. Note that the garbage collection takes place on playback: it is possible to retrieve @@ -422,10 +422,10 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future as a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version. -Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is +Note that even when examining the UI of running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. diff --git a/docs/quick-start.md b/docs/quick-start.md index 200b97230e866..07c520cbee6be 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -67,7 +67,7 @@ res3: Long = 15 ./bin/pyspark -Or if PySpark is installed with pip in your current enviroment: +Or if PySpark is installed with pip in your current environment: pyspark @@ -156,7 +156,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count() {% endhighlight %} -Here, we use the `explode` function in `select`, to transfrom a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". 
To collect the word counts in our shell, we can call `collect`: +Here, we use the `explode` function in `select`, to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: {% highlight python %} >>> wordCounts.collect() @@ -422,7 +422,7 @@ $ YOUR_SPARK_HOME/bin/spark-submit \ Lines with a: 46, Lines with b: 23 {% endhighlight %} -If you have PySpark pip installed into your enviroment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. +If you have PySpark pip installed into your environment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. {% highlight bash %} # Use the Python interpreter to run your application diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 382cbfd5301b0..2bb5ecf1b8509 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -154,7 +154,7 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispacther, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. +By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e7edec5990363..e4f5a0c659e66 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -445,7 +445,7 @@ To use a custom metrics.properties for the application master and executors, upd yarn.nodemanager.log-aggregation.roll-monitoring-interval-seconds should be configured in yarn-site.xml. This feature can only be used with Hadoop 2.6.4+. The Spark log4j appender needs be changed to use - FileAppender or another appender that can handle the files being removed while its running. Based + FileAppender or another appender that can handle the files being removed while it is running. Based on the file name configured in the log4j configuration (like spark.log), the user should set the regex (spark*) to include all the log files that need to be aggregated. diff --git a/docs/security.md b/docs/security.md index 15aadf07cf873..bebc28ddbfb0e 100644 --- a/docs/security.md +++ b/docs/security.md @@ -62,7 +62,7 @@ component-specific configuration namespaces used to override the default setting -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). 
+The full breakdown of available SSL options can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. ### YARN mode diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3e2e48a0ef249..502c0a8c37e01 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1253,7 +1253,7 @@ provide a ClassTag. (Note that this is different than the Spark SQL JDBC server, which allows other applications to run queries using Spark SQL). -To get started you will need to include the JDBC driver for you particular database on the +To get started you will need to include the JDBC driver for your particular database on the spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: @@ -1793,7 +1793,7 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. - - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. @@ -1821,7 +1821,7 @@ options. transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g., `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in Python and R is not a language feature, the concept of Dataset does not apply to these languages’ - APIs. Instead, `DataFrame` remains the primary programing abstraction, which is analogous to the + APIs. 
Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the single-node data frame notion in these languages. - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union` @@ -1997,7 +1997,7 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +of either language should use `SQLContext` and `DataFrame`. In general these classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index f4bb2353e3c49..1dd54719b21aa 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -42,7 +42,7 @@ Create core-site.xml and place it inside Spark's conf The main category of parameters that should be configured are the authentication parameters required by Keystone. -The following table contains a list of Keystone mandatory parameters. PROVIDER can be +The following table contains a list of Keystone mandatory parameters. PROVIDER can be any (alphanumeric) name. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 868acc41226dc..ffda36d64a770 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -74,7 +74,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. -// The master requires 2 cores to prevent from a starvation scenario. +// The master requires 2 cores to prevent a starvation scenario. val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) @@ -172,7 +172,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Note that we defined the transformation using a [FlatMapFunction](api/scala/index.html#org.apache.spark.api.java.function.FlatMapFunction) object. As we will discover along the way, there are a number of such convenience classes in the Java API -that help define DStream transformations. +that help defines DStream transformations. Next, we want to count these words. diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 461c29ce1ba89..5647ec6bc5797 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -125,7 +125,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") ### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, -you can create an Dataset/DataFrame for a defined range of offsets. +you can create a Dataset/DataFrame for a defined range of offsets.
    @@ -597,7 +597,7 @@ Note that the following Kafka params cannot be set and the Kafka source or sink - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. - **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use -DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. +DataFrame operations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2ddba2f0d942e..2ef5d3168a87b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,7 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able choose the mode based on your application requirements. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. In this guide, we are going to walk you through the programming model and the APIs. 
We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. @@ -1121,7 +1121,7 @@ Let’s discuss the different types of supported stream-stream joins and how to ##### Inner Joins with optional Watermarking Inner joins on any kind of columns along with any kind of join conditions are supported. However, as the stream runs, the size of streaming state will keep growing indefinitely as -*all* past input must be saved as the any new input can match with any input from the past. +*all* past input must be saved as any new input can match with any input from the past. To avoid unbounded state, you have to define additional join conditions such that indefinitely old inputs cannot match with future inputs and therefore can be cleared from the state. In other words, you will have to do the following additional steps in the join. @@ -1839,7 +1839,7 @@ aggDF \ .format("console") \ .start() -# Have all the aggregates in an in memory table. The query name will be the table name +# Have all the aggregates in an in-memory table. The query name will be the table name aggDF \ .writeStream \ .queryName("aggregates") \ diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 0473ab73a5e6c..a3643bf0838a1 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -5,7 +5,7 @@ title: Submitting Applications The `spark-submit` script in Spark's `bin` directory is used to launch applications on a cluster. It can use all of Spark's supported [cluster managers](cluster-overview.html#cluster-manager-types) -through a uniform interface so you don't have to configure your application specially for each one. +through a uniform interface so you don't have to configure your application especially for each one. # Bundling Your Application's Dependencies If your code depends on other projects, you will need to package them alongside @@ -58,7 +58,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Currently, standalone mode does not support cluster mode for Python +the drivers and the executors. Currently, the standalone mode does not support cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, @@ -68,7 +68,7 @@ There are a few options available that are specific to the [cluster manager](cluster-overview.html#cluster-manager-types) that is being used. For example, with a [Spark standalone cluster](spark-standalone.html) with `cluster` deploy mode, you can also specify `--supervise` to make sure that the driver is automatically restarted if it -fails with non-zero exit code. To enumerate all such options available to `spark-submit`, +fails with a non-zero exit code. To enumerate all such options available to `spark-submit`, run it with `--help`. Here are a few examples of common options: {% highlight bash %} @@ -192,7 +192,7 @@ debugging information by running `spark-submit` with the `--verbose` option. 
# Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included in the driver and executor classpaths. Directory expansion does not work with `--jars`. Spark uses the following URL scheme to allow different strategies for disseminating jars: From 00d169156d4b1c91d2bcfd788b254b03c509dc41 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 20 Jan 2018 14:49:49 -0800 Subject: [PATCH 712/712] [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing What changes were proposed in this pull request? Pass ‘spark.sql.parquet.compression.codec’ value to ‘parquet.compression’. Pass ‘spark.sql.orc.compression.codec’ value to ‘orc.compress’. How was this patch tested? Add test. Note: This is the same issue mentioned in #19218 . That branch was deleted mistakenly, so make a new pr instead. gatorsmile maropu dongjoon-hyun discipleforteen Author: fjh100456 Author: Takeshi Yamamuro Author: Wenchen Fan Author: gatorsmile Author: Yinan Li Author: Marcelo Vanzin Author: Juliusz Sompolski Author: Felix Cheung Author: jerryshao Author: Li Jin Author: Gera Shegalov Author: chetkhatri Author: Joseph K. Bradley Author: Bago Amirbekian Author: Xianjin YE Author: Bruce Robbins Author: zuotingbing Author: Kent Yao Author: hyukjinkwon Author: Adrian Ionescu Closes #20087 from fjh100456/HiveTableWriting. 
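As a usage sketch only (hedged: it assumes a SparkSession `spark` built with Hive support and an existing source table named `src`; the table names are hypothetical), the behavior this change targets looks like:

```scala
// Session-level codec; with this change it should be passed down to 'parquet.compression'
// when writing the Hive table below, instead of not taking effect as described above.
spark.conf.set("spark.sql.parquet.compression.codec", "gzip")

spark.sql("CREATE TABLE tbl_parquet(a INT) STORED AS PARQUET")
spark.sql("INSERT INTO TABLE tbl_parquet SELECT a FROM src")

// A table-level property ('parquet.compression' here, 'orc.compress' for ORC) is the
// per-table setting the new CompressionCodecSuite below exercises alongside the session conf.
spark.sql(
  """CREATE TABLE tbl_parquet_snappy(a INT) STORED AS PARQUET
    |TBLPROPERTIES('parquet.compression'='SNAPPY')""".stripMargin)
```
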
--- .../datasources/orc/OrcOptions.scala | 2 + .../datasources/parquet/ParquetOptions.scala | 6 +- .../sql/hive/execution/HiveOptions.scala | 22 ++ .../sql/hive/execution/SaveAsHiveFile.scala | 20 +- .../sql/hive/CompressionCodecSuite.scala | 353 ++++++++++++++++++ 5 files changed, 397 insertions(+), 6 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a525..0ad3862f6cf01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,4 +67,6 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + def getORCCompressionCodecName(name: String): String = shortOrcCompressionCodecNames(name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index ef67ea7d17cea..f36a89a4c3c5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -82,4 +82,8 @@ object ParquetOptions { "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) + + def getParquetCompressionCodecName(name: String): String = { + shortParquetCompressionCodecNames(name).name() + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 5c515515b9b9c..802ddafdbee4d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -19,7 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf /** * Options for the Hive data source. 
Note that rule `DetermineHiveSerde` will extract Hive @@ -102,4 +111,17 @@ object HiveOptions { "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } + + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { + val tableProps = tableInfo.getProperties.asScala.toMap + tableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("parquetoutputformat") => + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName + Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) + case formatName if formatName.endsWith("orcoutputformat") => + val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec + Option((COMPRESS.getAttribute, compressionCodec)) + case _ => None + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 9a6607f2f2c6c..e484356906e87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -55,18 +55,28 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + val isCompressed = + fileSinkConf.getTableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("orcoutputformat") => + // For ORC,"mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact because it uses table properties to store compression information. + false + case _ => hadoopConf.get("hive.exec.compress.output", "false").toBoolean + } + if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) fileSinkConf.setCompressCodec(hadoopConf .get("mapreduce.output.fileoutputformat.compress.codec")) fileSinkConf.setCompressType(hadoopConf .get("mapreduce.output.fileoutputformat.compress.type")) + } else { + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } val committer = FileCommitProtocol.instantiate( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala new file mode 100644 index 0000000000000..d10a6f25c64fc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} +import org.apache.spark.sql.hive.orc.OrcFileOperator +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with BeforeAndAfterAll { + import spark.implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + } + + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("table_source") + } finally { + super.afterAll() + } + } + + private val maxRecordNum = 50 + + private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } + + private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } + + private def getHiveCompressPropName(format: String): String = format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + + private def normalizeCodecName(format: String, name: String): String = { + format.toLowerCase match { + case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) + case "orc" => OrcOptions.getORCCompressionCodecName(name) + } + } + + private def getTableCompressionCodec(path: String, format: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = format.toLowerCase match { + case "parquet" => for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + case "orc" => new File(path).listFiles().filter { file => + file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" + }.map { orcFile => + OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString + }.toSeq + } + codecs.distinct + } + + private def createTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String]): Unit = { + val tblProperties = compressionCodec match { + case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" + case _ => "" 
+    }
+    val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else ""
+    sql(
+      s"""
+        |CREATE TABLE $tableName(a int)
+        |$partitionCreate
+        |STORED AS $format
+        |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName'
+        |$tblProperties
+      """.stripMargin)
+  }
+
+  private def writeDataToTable(
+      tableName: String,
+      partitionValue: Option[String]): Unit = {
+    val partitionInsert = partitionValue.map(p => s"partition (p='$p')").mkString
+    sql(
+      s"""
+        |INSERT INTO TABLE $tableName
+        |$partitionInsert
+        |SELECT * FROM table_source
+      """.stripMargin)
+  }
+
+  private def writeDataToTableUsingCTAS(
+      rootDir: File,
+      tableName: String,
+      partitionValue: Option[String],
+      format: String,
+      compressionCodec: Option[String]): Unit = {
+    val partitionCreate = partitionValue.map(_ => "PARTITIONED BY (p)").mkString
+    val compressionOption = compressionCodec.map { codec =>
+      s",'${getHiveCompressPropName(format)}'='$codec'"
+    }.mkString
+    val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString
+    sql(
+      s"""
+        |CREATE TABLE $tableName
+        |USING $format
+        |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' $compressionOption)
+        |$partitionCreate
+        |AS SELECT * $partitionSelect FROM table_source
+      """.stripMargin)
+  }
+
+  private def getPreparedTablePath(
+      tmpDir: File,
+      tableName: String,
+      isPartitioned: Boolean,
+      format: String,
+      compressionCodec: Option[String],
+      usingCTAS: Boolean): String = {
+    val partitionValue = if (isPartitioned) Some("test") else None
+    if (usingCTAS) {
+      writeDataToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec)
+    } else {
+      createTable(tmpDir, tableName, isPartitioned, format, compressionCodec)
+      writeDataToTable(tableName, partitionValue)
+    }
+    getTablePartitionPath(tmpDir, tableName, partitionValue)
+  }
+
+  private def getTableSize(path: String): Long = {
+    val dir = new File(path)
+    val files = dir.listFiles().filter(_.getName.startsWith("part-"))
+    files.map(_.length()).sum
+  }
+
+  private def getTablePartitionPath(
+      dir: File,
+      tableName: String,
+      partitionValue: Option[String]) = {
+    val partitionPath = partitionValue.map(p => s"p=$p").mkString
+    s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath"
+  }
+
+  private def getUncompressedDataSizeByFormat(
+      format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = {
+    var totalSize = 0L
+    val tableName = s"tbl_$format"
+    val codecName = normalizeCodecName(format, "uncompressed")
+    withSQLConf(getSparkCompressionConfName(format) -> codecName) {
+      withTempDir { tmpDir =>
+        withTable(tableName) {
+          val compressionCodec = Option(codecName)
+          val path = getPreparedTablePath(
+            tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS)
+          totalSize = getTableSize(path)
+        }
+      }
+    }
+    assert(totalSize > 0L)
+    totalSize
+  }
+
+  private def checkCompressionCodecForTable(
+      format: String,
+      isPartitioned: Boolean,
+      compressionCodec: Option[String],
+      usingCTAS: Boolean)
+      (assertion: (String, Long) => Unit): Unit = {
+    val tableName =
+      if (usingCTAS) s"tbl_$format$isPartitioned" else s"tbl_$format${isPartitioned}_CAST"
+    withTempDir { tmpDir =>
+      withTable(tableName) {
+        val path = getPreparedTablePath(
+          tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS)
+        val relCompressionCodecs = getTableCompressionCodec(path, format)
+        assert(relCompressionCodecs.length == 1)
+        val tableSize = getTableSize(path)
+        assertion(relCompressionCodecs.head, tableSize)
+      }
+    }
+  }
+
+  private def checkTableCompressionCodecForCodecs(
+      format: String,
+      isPartitioned: Boolean,
+      convertMetastore: Boolean,
+      usingCTAS: Boolean,
+      compressionCodecs: List[String],
+      tableCompressionCodecs: List[String])
+      (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = {
+    withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) {
+      tableCompressionCodecs.foreach { tableCompression =>
+        compressionCodecs.foreach { sessionCompressionCodec =>
+          withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) {
+            // 'tableCompression = null' means no table-level compression
+            val compression = Option(tableCompression)
+            checkCompressionCodecForTable(format, isPartitioned, compression, usingCTAS) {
+              case (realCompressionCodec, tableSize) =>
+                assertionCompressionCodec(
+                  compression, sessionCompressionCodec, realCompressionCodec, tableSize)
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // When the amount of data is small, the compressed size may be larger than the uncompressed
+  // one, so we only check that the sizes differ when compressionCodec is not NONE or UNCOMPRESSED.
+  private def checkTableSize(
+      format: String,
+      compressionCodec: String,
+      isPartitioned: Boolean,
+      convertMetastore: Boolean,
+      usingCTAS: Boolean,
+      tableSize: Long): Boolean = {
+    val uncompressedSize = getUncompressedDataSizeByFormat(format, isPartitioned, usingCTAS)
+    compressionCodec match {
+      case "UNCOMPRESSED" if format == "parquet" => tableSize == uncompressedSize
+      case "NONE" if format == "orc" => tableSize == uncompressedSize
+      case _ => tableSize != uncompressedSize
+    }
+  }
+
+  def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = {
+    Seq(true, false).foreach { isPartitioned =>
+      Seq(true, false).foreach { convertMetastore =>
+        // TODO: Also verify CTAS (usingCTAS = true) cases once SPARK-22926 is fixed.
+        Seq(false).foreach { usingCTAS =>
+          checkTableCompressionCodecForCodecs(
+            format,
+            isPartitioned,
+            convertMetastore,
+            usingCTAS,
+            compressionCodecs = compressCodecs,
+            tableCompressionCodecs = compressCodecs) {
+            case (tableCodec, sessionCodec, realCodec, tableSize) =>
+              // For a non-partitioned table with convertMetastore enabled, the session-level
+              // codec is expected to take effect; in all other cases the table-level codec
+              // should take effect.
+              // TODO: The table-level codec should always take effect once SPARK-22926 is fixed.
+              val expectCodec =
+                if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get
+              assert(expectCodec == realCodec)
+              assert(checkTableSize(
+                format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize))
+          }
+        }
+      }
+    }
+  }
+
+  def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = {
+    Seq(true, false).foreach { isPartitioned =>
+      Seq(true, false).foreach { convertMetastore =>
+        // TODO: Also verify CTAS (usingCTAS = true) cases once SPARK-22926 is fixed.
+        Seq(false).foreach { usingCTAS =>
+          checkTableCompressionCodecForCodecs(
+            format,
+            isPartitioned,
+            convertMetastore,
+            usingCTAS,
+            compressionCodecs = compressCodecs,
+            tableCompressionCodecs = List(null)) {
+            case (tableCodec, sessionCodec, realCodec, tableSize) =>
+              // The session-level codec is always expected to take effect.
+              assert(sessionCodec == realCodec)
+              assert(checkTableSize(
+                format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize))
+          }
+        }
+      }
+    }
+  }
+
+  test("both table-level and session-level compression are set") {
+    checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP"))
+    checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB"))
+  }
+
+  test("table-level compression is not set but session-level compression is set") {
+    checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP"))
+    checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB"))
+  }
+
+  def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = {
+    Seq(true, false).foreach { isPartitioned =>
+      Seq(true, false).foreach { convertMetastore =>
+        withTempDir { tmpDir =>
+          val tableName = s"tbl_$format$isPartitioned"
+          createTable(tmpDir, tableName, isPartitioned, format, None)
+          withTable(tableName) {
+            compressCodecs.foreach { compressionCodec =>
+              val partitionValue = if (isPartitioned) Some(compressionCodec) else None
+              withSQLConf(
+                getConvertMetastoreConfName(format) -> convertMetastore.toString,
+                getSparkCompressionConfName(format) -> compressionCodec) {
+                writeDataToTable(tableName, partitionValue)
+              }
+            }
+            val tablePath = getTablePartitionPath(tmpDir, tableName, None)
+            val realCompressionCodecs = if (isPartitioned) {
+              compressCodecs.flatMap { codec =>
+                getTableCompressionCodec(s"$tablePath/p=$codec", format)
+              }
+            } else {
+              getTableCompressionCodec(tablePath, format)
+            }
+
+            assert(realCompressionCodecs.distinct.sorted == compressCodecs.sorted)
+            val recordsNum = sql(s"SELECT * FROM $tableName").count()
+            assert(recordsNum == maxRecordNum * compressCodecs.length)
+          }
+        }
+      }
+    }
+  }
+
+  test("table containing mixed compression codecs") {
+    checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP"))
+    checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB"))
+  }
+}